| import html |
| import re |
| import unicodedata |
|
|
| import sqlparse |
|
|
|
|
# Keywords that may legally open a top-level SQL statement; used both for
# SQL-likeness/intent checks and for validating parsed statements.
SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}
|
|
|
|
def content_to_text(value):
    """Coerce an arbitrary message payload into plain text.

    Strings pass through unchanged; dicts prefer a "text"/"content"/"value"
    entry (recursively), falling back to all values joined with spaces;
    lists/tuples are joined with newlines; None becomes ""; anything else
    is str()-ified.
    """
    if isinstance(value, str):
        return value
    if value is None:
        return ""
    if isinstance(value, dict):
        preferred = next((k for k in ("text", "content", "value") if k in value), None)
        if preferred is not None:
            return content_to_text(value[preferred])
        return " ".join(map(content_to_text, value.values()))
    if isinstance(value, (list, tuple)):
        return "\n".join(map(content_to_text, value))
    return str(value)
|
|
|
|
def normalize_text(value):
    """Lowercase *value*, strip diacritics via NFKD decomposition, and
    collapse all whitespace runs to single spaces (trimmed)."""
    decomposed = unicodedata.normalize("NFKD", content_to_text(value).lower())
    without_accents = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return " ".join(without_accents.split())
|
|
|
|
def clean_generation(text):
    """Normalize raw model output into a bare SQL string.

    Strips a surrounding Markdown code fence, truncates at the first chat
    special token, and removes a leading "SQL:" label.

    Bug fix: the opening fence was only removed when tagged exactly ``` or
    ```sql, so fences like ```SQL-variants or ```postgresql leaked into the
    result.  Any single-token language tag on the fence is now accepted.
    """
    cleaned = content_to_text(text).strip()
    if cleaned.startswith("```"):
        lines = cleaned.splitlines()
        # Drop the opening fence when the line is a pure fence marker with an
        # optional language tag (no spaces), e.g. ```sql, ```postgresql, ```.
        if lines and re.fullmatch(r"```[\w.+-]*", lines[0].strip()):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        cleaned = "\n".join(lines).strip()
    # Cut everything after the first chat-template control token.
    for marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
        if marker in cleaned:
            cleaned = cleaned.split(marker, 1)[0].strip()
    if cleaned.upper().startswith("SQL:"):
        cleaned = cleaned[4:].strip()
    return cleaned
|
|
|
|
def extract_sql_candidate(text):
    """Return the cleaned generation starting at the first SQL starter
    keyword; if none is found, return the whole cleaned text."""
    candidate = clean_generation(text)
    keyword = re.search(
        r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b",
        candidate,
        flags=re.IGNORECASE,
    )
    if keyword is None:
        return candidate
    return candidate[keyword.start():].strip()
|
|
|
|
def is_sql_like(text):
    """True when *text* begins with a recognized SQL statement keyword."""
    stripped = (text or "").strip()
    leading_word = re.match(r"^\s*([A-Za-z]+)", stripped)
    return bool(leading_word) and leading_word.group(1).upper() in SQL_STARTERS
|
|
|
|
def is_sql_intent(message, schema=""):
    """Heuristically decide whether *message* asks for a SQL query.

    Small-talk greetings (EN/PT) are rejected outright; otherwise the
    message is scanned for query vocabulary, and as a last resort a message
    that itself looks like SQL counts when a schema is supplied.
    """
    message = normalize_text(message)
    if not message:
        return False
    smalltalk = {
        "oi", "ola", "olá", "hi", "hello", "hey", "obrigado", "obrigada", "thanks",
        "thank you", "como voce esta", "como você esta", "qual seu nome", "me conte uma piada",
        "vamos conversar", "como voce funciona", "como funciona", "o que voce faz", "o que faz",
    }
    if message in {normalize_text(phrase) for phrase in smalltalk}:
        return False
    for chit_chat in ("como voce esta", "qual seu nome", "conte uma piada"):
        if chit_chat in message:
            return False
    query_vocabulary = {
        "all", "average", "count", "columns", "database", "find", "get", "group by",
        "join", "list", "order by", "query", "rows", "select", "show", "sum", "where",
        "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar",
        "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene",
        "selecione", "some", "soma", "quantos", "filtre", "filtrar",
    }
    for term in query_vocabulary:
        # Word-boundary match against the accent-stripped message.
        if re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", message):
            return True
    return bool(schema and is_sql_like(message))
|
|
|
|
def validate_sql(sql_text):
    """Render a small HTML badge describing whether *sql_text* parses as SQL.

    Empty input, parser failures, no statements, and an unrecognized first
    token each get a distinct badge; otherwise "Valid SQL" is reported.
    """
    sql_text = (sql_text or "").strip()
    if not sql_text:
        return '<span class="validator-badge validator-empty">No SQL yet</span>'
    try:
        statements = [s for s in sqlparse.parse(sql_text) if str(s).strip()]
    except Exception as exc:  # report parser failures instead of raising
        error_type = html.escape(type(exc).__name__)
        return (
            '<span class="validator-badge validator-warn">Check syntax</span>'
            f'<span class="validator-detail">sqlparse error: {error_type}</span>'
        )
    if not statements:
        return (
            '<span class="validator-badge validator-warn">Check syntax</span>'
            '<span class="validator-detail">No parsed SQL statement.</span>'
        )
    leading = statements[0].token_first(skip_cm=True)
    leading_value = "UNKNOWN" if leading is None else leading.value.strip().upper()
    if leading_value in SQL_STARTERS:
        return '<span class="validator-badge validator-ok">Valid SQL</span>'
    escaped_token = html.escape(leading_value)
    return (
        '<span class="validator-badge validator-warn">Check syntax</span>'
        f'<span class="validator-detail">First token: {escaped_token}</span>'
    )
|
|
|
|
def is_create_table_intent(message):
    """True when *message* (EN/PT) asks to create/build a table or schema."""
    text = (message or "").strip().lower()
    verb = re.search(r"\b(create|make|build|generate|criar|crie|cria|criando|gerar|gere|gera|gerando|faz|faça|fazendo|monta|montar|monte)\b", text)
    noun = re.search(r"\b(table|schema|tabela)\b", text)
    return verb is not None and noun is not None
|
|
|
|
def is_rename_intent(message):
    """True when *message* matches "rename <old> to/para/as/como/por <new>"
    with any of the EN/PT rename verbs."""
    text = (message or "").strip().lower()
    pattern = (
        r"\b(rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)"
        r"\s+\w+\s+(to|para|as|como|por)\s+\w+"
    )
    return re.search(pattern, text, flags=re.IGNORECASE) is not None
|
|
|
|
def is_table_edit_intent(message):
    """True when *message* (EN/PT) asks to add/remove/rename/alter columns.

    An "add"-style verb followed by an aggregation word ("sum up", "add by")
    is not treated as an add intent, to avoid misreading query phrasing.
    """
    text = (message or "").strip().lower()
    edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|coloque|colocar)\b"
    direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar)\b"
    direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
    target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
    aggregation_words = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}

    is_add_intent = False
    add_match = re.search(direct_add_terms, text)
    if add_match:
        trailing = text[add_match.end():].split()
        following_word = trailing[0] if trailing else ""
        is_add_intent = following_word not in aggregation_words

    if is_add_intent:
        return True
    if re.search(direct_remove_terms, text):
        return True
    if is_rename_intent(text):
        return True
    if re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", text):
        return True
    if re.search(edit_terms, text):
        return bool(re.search(target_terms, text) or ":" in text or re.search(r"\bpor\b", text))
    return False
|
|
|
|
def infer_column_type(column_name):
    """Guess a SQL type for *column_name* from EN/PT naming conventions;
    defaults to TEXT when nothing matches."""
    name = column_name.strip().lower()
    integer_names = {"quantity", "quantidade", "stock", "estoque", "year"}
    numeric_names = {
        "salary", "price", "preco", "amount", "total", "grade", "peso", "weight",
        "idade", "age", "altura", "height", "largura", "width", "comprimento",
        "length", "desconto", "discount",
    }
    date_names = {"date", "created_at", "updated_at"}
    if name == "id" or name.endswith("_id") or name in integer_names:
        return "INTEGER"
    if name in numeric_names:
        return "NUMERIC"
    if name in date_names or name.endswith("_date"):
        return "DATE"
    return "TEXT"
|
|
|
|
def normalize_identifier(value):
    """Turn arbitrary text into a safe SQL identifier: accent-stripped and
    lowercased, non-word runs collapsed to "_", and prefixed with "col_"
    when it would otherwise start with a digit.  Returns "" when nothing
    usable remains."""
    identifier = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
    if identifier and identifier[0].isdigit():
        return f"col_{identifier}"
    return identifier
|
|
|
|
def parse_column_definition(raw_column):
    """Parse one column description like "price numeric" into (name, TYPE).

    The last explicit type keyword wins and synonyms are canonicalized
    (int->INTEGER, bool->BOOLEAN, decimal->NUMERIC, float/double->REAL).
    When the type keyword is the whole description, it is treated as the
    column name instead.  Without an explicit type, the type is inferred
    from the name.  Returns None when no usable name remains.
    """
    raw_column = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
    raw_column = raw_column.strip(" .;:")
    if not raw_column:
        return None
    type_keyword_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
    canonical_types = {
        "INT": "INTEGER",
        "BOOL": "BOOLEAN",
        "DECIMAL": "NUMERIC",
        "FLOAT": "REAL",
        "DOUBLE": "REAL",
    }
    last_type_match = None
    for last_type_match in re.finditer(type_keyword_pattern, raw_column, flags=re.IGNORECASE):
        pass  # keep only the final match
    column_type = None
    name_part = raw_column
    if last_type_match is not None:
        candidate_name = raw_column[: last_type_match.start()].strip()
        if candidate_name:
            keyword = last_type_match.group(1).upper()
            column_type = canonical_types.get(keyword, keyword)
            name_part = candidate_name
    name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
    column_name = normalize_identifier(name_part)
    if not column_name:
        return None
    return column_name, column_type or infer_column_type(column_name)
|
|
|
|
def split_column_list(columns_text):
    """Split a free-form column listing into individual column definitions.

    Conjunctions ("and"/"e") become commas, filler stopwords are dropped,
    and a single comma segment containing several "name type" pairs is
    broken into one entry per column.  Names like "total date" where the
    date-ish word is likely part of the name are kept separate rather than
    paired as name+type.
    """
    normalized = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
    type_keyword_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
    type_tokens = {
        "integer", "int", "numeric", "decimal", "real", "float", "double",
        "text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool",
    }
    stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"}
    results = []
    for segment in (piece.strip() for piece in normalized.split(",")):
        if not segment:
            continue
        words = [w for w in segment.split() if w.lower() not in stopwords]
        if not words:
            continue
        has_type_keyword = re.search(type_keyword_pattern, segment, flags=re.IGNORECASE) is not None
        if has_type_keyword and len(words) > 2:
            # Several name/type pairs in one segment: walk and pair them up.
            ambiguous_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
            i = 0
            while i < len(words):
                word = words[i]
                lookahead = words[i + 1].lower() if i + 1 < len(words) else ""
                pair_up = lookahead in type_tokens and not (
                    word.lower() in ambiguous_names and lookahead in {"date", "datetime", "timestamp"}
                )
                if pair_up:
                    results.append(f"{word} {words[i + 1]}")
                    i += 2
                else:
                    results.append(word)
                    i += 1
        elif has_type_keyword:
            results.append(segment)
        elif len(words) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", w) for w in words):
            results.extend(words)
        else:
            results.append(segment)
    return results
|
|
|
|
def format_create_table(table_name, columns):
    """Render a CREATE TABLE statement from (name, type) pairs.

    Duplicate column names keep their first occurrence.  Returns "" when
    the table name is missing or *columns* is empty.
    """
    if not table_name or not columns:
        return ""
    unique = {}
    for name, declared_type in columns:
        # setdefault keeps the first type seen for a repeated name.
        unique.setdefault(name, declared_type)
    body = ",\n".join(f" {name} {declared_type}" for name, declared_type in unique.items())
    return f"CREATE TABLE {table_name} (\n{body}\n);"
|
|
|
|
def create_table_from_message(message):
    """Build a CREATE TABLE statement from a natural-language request.

    Tries two phrasings in order: "table <name> with <columns>" and
    "create ... table <name> with <columns>" (EN/PT).  Returns "" when no
    pattern yields a usable statement.

    Bug fix: the original returned as soon as a pattern matched, even when
    it produced no usable columns (so format_create_table returned "");
    we now fall through and try the next pattern instead.
    """
    message = (message or "").strip()
    patterns = (
        r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
        r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
    )
    for pattern in patterns:
        match = re.search(pattern, message, flags=re.IGNORECASE)
        if not match:
            continue
        table_name = normalize_identifier(match.group(1))
        columns = [
            parsed
            for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2)))
            if parsed
        ]
        statement = format_create_table(table_name, columns)
        if statement:
            return statement
    return ""
|
|
|
|
def parse_create_table_schema(schema):
    """Split a "CREATE TABLE name (cols)" string (the CREATE TABLE prefix is
    optional) into (table_name, [(column, type), ...]).

    Returns ("", []) when *schema* does not match the expected shape.
    """
    match = re.match(
        r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
        (schema or "").strip(),
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match is None:
        return "", []
    parsed_columns = []
    for raw in split_column_list(match.group(2)):
        definition = parse_column_definition(raw)
        if definition:
            parsed_columns.append(definition)
    return normalize_identifier(match.group(1)), parsed_columns
|
|
|
|
def create_table_from_schema(schema):
    """Re-render *schema* as a canonical CREATE TABLE statement ("" when it
    cannot be parsed)."""
    return format_create_table(*parse_create_table_schema(schema))
|
|
|
|
def extract_create_table_statement(text):
    """Pull the first CREATE TABLE statement out of a model generation,
    returning "" when none is present."""
    candidate = extract_sql_candidate(text)
    found = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
        candidate,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if found is None:
        return ""
    return clean_generation(found.group(0))
|
|
|
|
def last_create_table_from_history(chat_history):
    """Return the most recent CREATE TABLE statement found in assistant
    messages of *chat_history* (scanned newest-first), or ""."""
    for entry in reversed(list(chat_history or [])):
        if not isinstance(entry, dict):
            continue
        if entry.get("role") != "assistant":
            continue
        statement = extract_create_table_statement(entry.get("content", ""))
        if statement:
            return statement
    return ""
|
|
|
|
def extract_added_columns(message):
    """Extract column definitions the user wants to add.

    First tries an explicit "...: col1, col2" suffix, then EN/PT add-verb
    phrasing.  Returns a list of (name, type) pairs, or [] when nothing
    usable parses.
    """
    message = (message or "").strip()
    patterns = (
        r":\s*(.+)$",
        r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        match = re.search(pattern, message, flags=re.IGNORECASE)
        if match is None:
            continue
        parsed = []
        for raw in split_column_list(match.group(1)):
            definition = parse_column_definition(raw)
            if definition:
                parsed.append(definition)
        if parsed:
            return parsed
    return []
|
|
|
|
def extract_removed_columns(message):
    """Extract the column names the user wants to remove (EN/PT phrasing).
    Returns normalized identifiers, or [] when nothing usable parses."""
    message = (message or "").strip()
    pattern = r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$"
    match = re.search(pattern, message, flags=re.IGNORECASE)
    if match is None:
        return []
    names = []
    for raw in split_column_list(match.group(1)):
        identifier = normalize_identifier(raw)
        if identifier:
            names.append(identifier)
    return names
|
|
|
|
def extract_renamed_columns(message):
    """Extract (old, new) column rename pairs from *message*.

    Bug fix: the verb list was out of sync with is_rename_intent() — "muda",
    "trocar", and "troca" with separators other than "por" were detected as
    rename intent but extracted no pairs.  One unified pattern now covers the
    same EN/PT verbs and separators (subsuming the old dedicated
    "troca ... por" regex).  Only pairs where both sides normalize to a
    non-empty identifier are returned.
    """
    pattern = (
        r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)\s+"
        r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
    )
    pairs = []
    for old, new in re.findall(pattern, message or "", flags=re.IGNORECASE):
        old_id = normalize_identifier(old)
        new_id = normalize_identifier(new)
        if old_id and new_id:
            pairs.append((old_id, new_id))
    return pairs
|
|
|
|
def parse_compound_edit(message):
    """Split a compound edit request on "and"/"e" followed by an edit verb
    and classify each segment as a rename, removal, or addition.

    Returns (added, removed, renamed): added as (name, type) pairs, removed
    as names, renamed as (old, new) pairs.
    """
    boundary = (
        r"\s+(?:and|e)\s+"
        r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
        r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
        r"exclua|renomeie|renomear|renomeia|altere|mude|troca|trocar)\b)"
    )
    remove_verbs = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
    added, removed, renamed = [], [], []
    for segment in re.split(boundary, message or "", flags=re.IGNORECASE):
        segment = segment.strip()
        if not segment:
            continue
        if is_rename_intent(segment):
            renamed += extract_renamed_columns(segment)
        elif re.search(remove_verbs, segment, flags=re.IGNORECASE):
            removed += extract_removed_columns(segment)
        else:
            added += extract_added_columns(segment)
    return added, removed, renamed
|
|
|
|
def edit_create_table_from_message(message, chat_history, active_schema):
    """Apply a natural-language column edit to the latest known CREATE TABLE.

    The base statement comes from the assistant's chat history, falling back
    to *active_schema*.  Returns the rebuilt statement, or "" when there is
    no edit intent, no base table, or no actionable change.
    """
    if not (is_table_edit_intent(message) or is_rename_intent(message)):
        return ""
    base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
    table_name, existing_columns = parse_create_table_schema(base_sql)
    if not table_name:
        return ""
    added, removed_from_segments, renamed = parse_compound_edit(message)
    # Removals can appear either in compound segments or in the raw message.
    removed = set(extract_removed_columns(message)) | set(removed_from_segments)
    if not (added or removed or renamed):
        return ""
    rename_map = dict(renamed)
    rebuilt = []
    for name, col_type in existing_columns:
        if name in removed:
            continue
        rebuilt.append((rename_map.get(name, name), col_type))
    return format_create_table(table_name, rebuilt + added)
|
|
|
|
def create_table_from_suggestion(suggestion):
    """Build a CREATE TABLE statement from a table suggestion.

    Accepts a dict ({"table_name": ..., "columns": [{"name": ..., "type": ...}]})
    or any object exposing .table_name and .columns as (name, type) pairs.
    Missing types default to TEXT; names that do not normalize to a usable
    identifier are dropped.
    """
    if not suggestion:
        return ""
    if isinstance(suggestion, dict):
        table_name = suggestion.get("table_name")
        raw_columns = [
            (entry.get("name"), entry.get("type", "TEXT"))
            for entry in suggestion.get("columns", [])
            if isinstance(entry, dict)
        ]
    else:
        table_name = getattr(suggestion, "table_name", "")
        raw_columns = getattr(suggestion, "columns", ())
    usable = []
    for raw_name, raw_type in raw_columns:
        identifier = normalize_identifier(raw_name)
        if identifier:
            usable.append((identifier, (raw_type or "TEXT").upper()))
    return format_create_table(normalize_identifier(table_name), usable)
|
|
|
|