# Source: Shizu0n — commit 47affa0
# refactor: split chat flow from SQL routing
import html
import re
import unicodedata
import sqlparse
# Leading keywords that mark a string as the start of a SQL statement.
SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}
def content_to_text(value):
    """Flatten an arbitrarily nested message payload into plain text.

    Strings pass through unchanged; dicts prefer a "text"/"content"/"value"
    key (falling back to joining all their values with spaces); lists and
    tuples are joined with newlines; None becomes ""; anything else is
    converted with str().
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for preferred_key in ("text", "content", "value"):
            if preferred_key in value:
                return content_to_text(value[preferred_key])
        return " ".join(content_to_text(entry) for entry in value.values())
    if isinstance(value, (list, tuple)):
        return "\n".join(content_to_text(entry) for entry in value)
    return str(value)
def normalize_text(value):
    """Lowercase, strip accents (via NFKD decomposition), and collapse runs
    of whitespace into single spaces."""
    lowered = content_to_text(value).lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    without_accents = "".join(
        ch for ch in decomposed if not unicodedata.combining(ch)
    )
    return re.sub(r"\s+", " ", without_accents).strip()
def clean_generation(text):
    """Strip code fences, chat-template stop markers, and a leading "SQL:"
    label from a model generation."""
    cleaned = content_to_text(text).strip()
    if cleaned.startswith("```"):
        fence_lines = cleaned.splitlines()
        if fence_lines and fence_lines[0].strip().lower() in {"```", "```sql"}:
            del fence_lines[0]
        if fence_lines and fence_lines[-1].strip() == "```":
            del fence_lines[-1]
        cleaned = "\n".join(fence_lines).strip()
    # Cut everything after the first chat-template marker, if any.
    for stop_marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
        marker_pos = cleaned.find(stop_marker)
        if marker_pos != -1:
            cleaned = cleaned[:marker_pos].strip()
    if cleaned.upper().startswith("SQL:"):
        cleaned = cleaned[len("SQL:"):].strip()
    return cleaned
def extract_sql_candidate(text):
    """Return the cleaned generation trimmed to start at its first SQL
    starter keyword; the whole cleaned text when no keyword is found."""
    cleaned = clean_generation(text)
    keyword = re.search(
        r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b",
        cleaned,
        flags=re.IGNORECASE,
    )
    if keyword is None:
        return cleaned
    return cleaned[keyword.start():].strip()
def is_sql_like(text):
    """True when the first alphabetic word of *text* is a SQL starter
    keyword (SELECT, WITH, INSERT, ...)."""
    stripped = (text or "").strip()
    if not stripped:
        return False
    leading_word = re.match(r"^\s*([A-Za-z]+)", stripped)
    if leading_word is None:
        return False
    return leading_word.group(1).upper() in SQL_STARTERS
def is_sql_intent(message, schema=""):
    """Heuristically decide whether *message* is asking for a SQL query.

    Returns False for greetings/small talk (English and Portuguese),
    True when any SQL-ish term appears as a whole word, and otherwise
    falls back to is_sql_like() when a schema is loaded.
    """
    message = normalize_text(message)
    if not message:
        return False
    # The literals below are written pre-normalized (lowercase, accent-free),
    # so the original per-call normalize_text() pass over every constant is
    # unnecessary; accented variants ("olá", "você") normalize to entries
    # already present here.
    smalltalk_exact = {
        "oi", "ola", "hi", "hello", "hey", "obrigado", "obrigada", "thanks",
        "thank you", "como voce esta", "qual seu nome", "me conte uma piada",
        "vamos conversar", "como voce funciona", "como funciona",
        "o que voce faz", "o que faz",
    }
    if message in smalltalk_exact:
        return False
    if any(pattern in message for pattern in ("como voce esta", "qual seu nome", "conte uma piada")):
        return False
    sql_terms = (
        "all", "average", "count", "columns", "database", "find", "get", "group by",
        "join", "list", "order by", "query", "rows", "select", "show", "sum", "where",
        "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar",
        "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene",
        "selecione", "some", "soma", "quantos", "filtre", "filtrar",
    )
    # One alternation with the same (?<!\w)...(?!\w) word boundaries instead
    # of one re.search per term; presence semantics are identical.
    term_pattern = r"(?<!\w)(?:" + "|".join(re.escape(term) for term in sql_terms) + r")(?!\w)"
    if re.search(term_pattern, message):
        return True
    return bool(schema and is_sql_like(message))
def validate_sql(sql_text):
    """Render an HTML badge describing whether *sql_text* parses as SQL
    (via sqlparse) and starts with a known statement keyword."""
    sql_text = (sql_text or "").strip()
    if not sql_text:
        return '<span class="validator-badge validator-empty">No SQL yet</span>'
    try:
        parsed = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()]
    except Exception as exc:  # surface parser failures as a warning badge
        error_name = html.escape(type(exc).__name__)
        return (
            '<span class="validator-badge validator-warn">Check syntax</span>'
            f'<span class="validator-detail">sqlparse error: {error_name}</span>'
        )
    if not parsed:
        return (
            '<span class="validator-badge validator-warn">Check syntax</span>'
            '<span class="validator-detail">No parsed SQL statement.</span>'
        )
    leading = parsed[0].token_first(skip_cm=True)
    leading_value = "UNKNOWN" if leading is None else leading.value.strip().upper()
    if leading_value in SQL_STARTERS:
        return '<span class="validator-badge validator-ok">Valid SQL</span>'
    escaped_token = html.escape(leading_value)
    return (
        '<span class="validator-badge validator-warn">Check syntax</span>'
        f'<span class="validator-detail">First token: {escaped_token}</span>'
    )
def is_create_table_intent(message):
    """True when the message asks to create a table/schema, in English or
    Portuguese (a create-verb plus a table/schema noun must both appear)."""
    lowered = (message or "").strip().lower()
    has_verb = re.search(r"\b(create|make|build|generate|criar|crie|cria|criando|gerar|gere|gera|gerando|faz|faça|fazendo|monta|montar|monte)\b", lowered)
    has_noun = re.search(r"\b(table|schema|tabela)\b", lowered)
    return bool(has_verb and has_noun)
def is_rename_intent(message):
    """True when the message matches "<rename-verb> X (to|para|as|como|por) Y"."""
    lowered = (message or "").strip().lower()
    rename_pattern = (
        r"\b(rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)"
        r"\s+\w+\s+(to|para|as|como|por)\s+\w+"
    )
    return re.search(rename_pattern, lowered, flags=re.IGNORECASE) is not None
def is_table_edit_intent(message):
    """Heuristically detect whether *message* asks to edit a table's schema.

    Covers add/remove/rename wording in English and Portuguese. An add-verb
    only counts when the word right after it is not a SQL aggregation term,
    so phrases like "add up the totals" are not treated as schema edits.
    """
    message = (message or "").strip().lower()
    # Broad verb list that can signal any kind of table edit.
    edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|coloque|colocar)\b"
    # Verbs that directly imply adding or removing a column.
    direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar)\b"
    direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
    target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
    # Words that, immediately after an add-verb, suggest a query ("add up",
    # "sum ... by") rather than a schema change.
    sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}
    add_match = re.search(direct_add_terms, message)
    if add_match:
        # Look at the first word following the add-verb to disambiguate.
        after_add = message[add_match.start() + len(add_match.group()) :].strip()
        first_word_after = after_add.split()[0] if after_add.split() else ""
        is_add_intent = first_word_after not in sql_aggregation_terms
    else:
        is_add_intent = False
    return bool(
        is_add_intent
        or re.search(direct_remove_terms, message)
        or is_rename_intent(message)
        # Portuguese "altere/mude ... ter ..." phrasing ("change X to have Y").
        or re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", message)
        # Generic edit verb plus an explicit target noun, a "name: type"
        # colon, or Portuguese "por" (as in "troca X por Y").
        or (re.search(edit_terms, message) and (re.search(target_terms, message) or ":" in message or re.search(r"\bpor\b", message)))
    )
def infer_column_type(column_name):
    """Guess a SQL column type from a column name (default TEXT)."""
    name = column_name.strip().lower()
    integer_names = {"quantity", "quantidade", "stock", "estoque", "year"}
    if name == "id" or name.endswith("_id") or name in integer_names:
        return "INTEGER"
    numeric_names = {
        "salary", "price", "preco", "amount", "total", "grade", "peso", "weight",
        "idade", "age", "altura", "height", "largura", "width", "comprimento",
        "length", "desconto", "discount",
    }
    if name in numeric_names:
        return "NUMERIC"
    date_names = {"date", "created_at", "updated_at"}
    if name in date_names or name.endswith("_date"):
        return "DATE"
    return "TEXT"
def normalize_identifier(value):
    """Reduce *value* to a safe snake_case SQL identifier ("" when nothing
    usable remains)."""
    identifier = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
    if not identifier:
        return ""
    # Identifiers cannot start with a digit; prefix rather than drop.
    return f"col_{identifier}" if identifier[0].isdigit() else identifier
def parse_column_definition(raw_column):
    """Parse free-text "name [TYPE]" into an (identifier, sql_type) pair.

    The type, when present, is taken from the LAST type keyword in the text
    and canonicalized (INT->INTEGER, BOOL->BOOLEAN, DECIMAL->NUMERIC,
    FLOAT/DOUBLE->REAL); otherwise it is inferred from the column name.
    Returns None when no usable name remains.
    """
    raw_column = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
    raw_column = raw_column.strip(" .;:")
    if not raw_column:
        return None
    type_hits = list(
        re.finditer(
            r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b",
            raw_column,
            flags=re.IGNORECASE,
        )
    )
    name_part = raw_column
    column_type = None
    if type_hits:
        last_hit = type_hits[-1]
        candidate_name = raw_column[: last_hit.start()].strip()
        if candidate_name:
            aliases = {"INT": "INTEGER", "BOOL": "BOOLEAN", "DECIMAL": "NUMERIC", "FLOAT": "REAL", "DOUBLE": "REAL"}
            spelled = last_hit.group(1).upper()
            column_type = aliases.get(spelled, spelled)
            name_part = candidate_name
        # A bare type word with no preceding name is treated as the name.
    name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
    column_name = normalize_identifier(name_part)
    if not column_name:
        return None
    return column_name, column_type or infer_column_type(column_name)
def split_column_list(columns_text):
    """Split free-text column listings into per-column fragments.

    Handles comma lists, "and"/Portuguese "e" conjunctions, and comma-less
    runs of "name type name type ..." tokens.
    """
    # Treat "and" / "e" as commas so a single split pass suffices.
    columns_text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
    parts = []
    type_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
    type_tokens = {
        "integer", "int", "numeric", "decimal", "real", "float", "double",
        "text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool",
    }
    # Filler words (English prepositions, Portuguese articles) dropped from
    # each fragment before token analysis.
    stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"}
    for part in (item.strip() for item in columns_text.split(",") if item.strip()):
        tokens = [token.strip() for token in re.split(r"\s+", part) if token.strip()]
        tokens = [token for token in tokens if token.lower() not in stopwords]
        if not tokens:
            continue
        if re.search(type_pattern, part, flags=re.IGNORECASE) and len(tokens) > 2:
            # Fragment holds several "name type" pairs without commas: pair
            # each name with the type token immediately following it.
            index = 0
            # Names that are themselves type-like words; do not let them
            # consume a following date-ish token as their type.
            inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
            while index < len(tokens):
                current = tokens[index]
                next_token = tokens[index + 1].lower() if index + 1 < len(tokens) else ""
                if next_token in type_tokens and not (
                    current.lower() in inferrable_names and next_token in {"date", "datetime", "timestamp"}
                ):
                    parts.append(f"{current} {tokens[index + 1]}")
                    index += 2
                else:
                    parts.append(current)
                    index += 1
            continue
        if re.search(type_pattern, part, flags=re.IGNORECASE):
            # Single "name type" fragment: keep it as one part.
            parts.append(part)
            continue
        if len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", token) for token in tokens):
            # Several bare identifiers with no types: one column per token.
            parts.extend(tokens)
        else:
            parts.append(part)
    return parts
def format_create_table(table_name, columns):
    """Render a CREATE TABLE statement from (name, type) pairs, keeping only
    the first occurrence of each column name; "" when nothing to render."""
    if not table_name or not columns:
        return ""
    seen_names = set()
    body_lines = []
    for name, sql_type in columns:
        if name not in seen_names:
            seen_names.add(name)
            body_lines.append(f" {name} {sql_type}")
    if not body_lines:
        return ""
    joined_body = ",\n".join(body_lines)
    return f"CREATE TABLE {table_name} (\n{joined_body}\n);"
def create_table_from_message(message):
    """Build a CREATE TABLE statement from a natural-language request like
    "create a table users with name and age"; "" when no pattern matches."""
    message = (message or "").strip()
    patterns = (
        r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
        r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, message, flags=re.IGNORECASE)
        if found:
            table_name = normalize_identifier(found.group(1))
            parsed_columns = []
            for chunk in split_column_list(found.group(2)):
                parsed = parse_column_definition(chunk)
                if parsed:
                    parsed_columns.append(parsed)
            return format_create_table(table_name, parsed_columns)
    return ""
def parse_create_table_schema(schema):
    """Split a (loose) CREATE TABLE string into (table_name, [(name, type)]).

    The "CREATE TABLE" prefix is optional; returns ("", []) on no match.
    """
    schema = (schema or "").strip()
    found = re.match(
        r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
        schema,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if found is None:
        return "", []
    parsed_columns = []
    for chunk in split_column_list(found.group(2)):
        parsed = parse_column_definition(chunk)
        if parsed:
            parsed_columns.append(parsed)
    return normalize_identifier(found.group(1)), parsed_columns
def create_table_from_schema(schema):
    """Round-trip a schema string through the parser into canonical SQL."""
    name, columns = parse_create_table_schema(schema)
    return format_create_table(name, columns)
def extract_create_table_statement(text):
    """Pull the first CREATE TABLE (...) statement out of a generation,
    cleaned of fences/markers; "" when none is found."""
    candidate = extract_sql_candidate(text)
    found = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
        candidate,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if found is None:
        return ""
    return clean_generation(found.group(0))
def last_create_table_from_history(chat_history):
    """Return the most recent CREATE TABLE statement emitted by the
    assistant in *chat_history* (newest first); "" when none exists."""
    for turn in reversed(list(chat_history or [])):
        if isinstance(turn, dict) and turn.get("role") == "assistant":
            statement = extract_create_table_statement(turn.get("content", ""))
            if statement:
                return statement
    return ""
def extract_added_columns(message):
    """Parse the columns a user wants added, either after a ':' or after an
    add-verb; returns [(name, type)] or [] when nothing parses."""
    message = (message or "").strip()
    patterns = (
        r":\s*(.+)$",
        r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, message, flags=re.IGNORECASE)
        if found is None:
            continue
        parsed_columns = []
        for chunk in split_column_list(found.group(1)):
            parsed = parse_column_definition(chunk)
            if parsed:
                parsed_columns.append(parsed)
        if parsed_columns:
            return parsed_columns
    return []
def extract_removed_columns(message):
    """Parse the column identifiers a user wants dropped; [] when none."""
    message = (message or "").strip()
    patterns = (
        r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, message, flags=re.IGNORECASE)
        if found is None:
            continue
        identifiers = []
        for chunk in split_column_list(found.group(1)):
            identifier = normalize_identifier(chunk)
            if identifier:
                identifiers.append(identifier)
        if identifiers:
            return identifiers
    return []
def extract_renamed_columns(message):
    """Parse (old, new) column rename pairs from the message.

    Handles "<verb> old to/para/as/como/por new" plus the Portuguese
    "troca old por new" word order. Pairs where either side normalizes to
    an empty identifier are dropped.
    """
    pattern = (
        r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|mude)\s+"
        r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
    )
    matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
    troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE)
    pairs = []
    for old, new in [*matches, *troca_matches]:
        # Normalize each side exactly once; the original comprehension
        # called normalize_identifier twice per name (filter + tuple).
        old_id = normalize_identifier(old)
        new_id = normalize_identifier(new)
        if old_id and new_id:
            pairs.append((old_id, new_id))
    return pairs
def parse_compound_edit(message):
    """Split a multi-part edit request on "and"/"e" before an edit verb and
    route each segment.

    Returns (added, removed, renamed): added is [(name, type)], removed is
    [name], renamed is [(old, new)].
    """
    segment_pattern = (
        r"\s+(?:and|e)\s+"
        r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
        r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
        r"exclua|renomeie|renomear|renomeia|altere|mude|troca|trocar)\b)"
    )
    added, removed, renamed = [], [], []
    for raw_segment in re.split(segment_pattern, message or "", flags=re.IGNORECASE):
        segment = raw_segment.strip()
        if not segment:
            continue
        if is_rename_intent(segment):
            renamed.extend(extract_renamed_columns(segment))
        elif re.search(r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b", segment, flags=re.IGNORECASE):
            removed.extend(extract_removed_columns(segment))
        else:
            # extend() with an empty parse result is a no-op, matching the
            # original's "only extend when columns were found" behavior.
            added.extend(extract_added_columns(segment))
    return added, removed, renamed
def edit_create_table_from_message(message, chat_history, active_schema):
    """Apply add/remove/rename edits from *message* to the latest known
    CREATE TABLE (from history, else the active schema); "" when the
    message is not an edit or nothing changes."""
    if not (is_table_edit_intent(message) or is_rename_intent(message)):
        return ""
    base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
    table_name, existing_columns = parse_create_table_schema(base_sql)
    if not table_name:
        return ""
    added, removed_list, renamed = parse_compound_edit(message)
    # Union of per-segment removals and a whole-message sweep.
    removed = set(extract_removed_columns(message)) | set(removed_list)
    if not (added or removed or renamed):
        return ""
    rename_map = dict(renamed)
    surviving = []
    for col_name, col_type in existing_columns:
        if col_name in removed:
            continue
        surviving.append((rename_map.get(col_name, col_name), col_type))
    return format_create_table(table_name, surviving + added)
def create_table_from_suggestion(suggestion):
    """Turn a table suggestion (dict with "table_name"/"columns" or an
    object with matching attributes) into a CREATE TABLE string."""
    if not suggestion:
        return ""
    if isinstance(suggestion, dict):
        table_name = suggestion.get("table_name")
        raw_columns = [
            (entry.get("name"), entry.get("type", "TEXT"))
            for entry in suggestion.get("columns", [])
            if isinstance(entry, dict)
        ]
    else:
        table_name = getattr(suggestion, "table_name", "")
        raw_columns = getattr(suggestion, "columns", ())
    normalized_columns = []
    for raw_name, raw_type in raw_columns:
        identifier = normalize_identifier(raw_name)
        if identifier:
            normalized_columns.append((identifier, (raw_type or "TEXT").upper()))
    return format_create_table(normalize_identifier(table_name), normalized_columns)