Shizu0n committed on
Commit
47affa0
·
1 Parent(s): bc39556

refactor: split chat flow from SQL routing

Browse files
Files changed (10) hide show
  1. .gitignore +3 -4
  2. README.md +14 -3
  3. app.py +365 -468
  4. chat_state.py +124 -0
  5. intent.py +115 -0
  6. model_io.py +132 -0
  7. scripts/model_probe.py +115 -0
  8. sql_tools.py +454 -0
  9. tests/test_chatbot_behavior.py +21 -1
  10. tests/test_chatbot_core.py +105 -0
.gitignore CHANGED
@@ -45,11 +45,10 @@ logs/
45
  *.ckpt
46
  *.gguf
47
 
48
- # AI-generated code artifacts
49
- *.gen.py
50
- .claude
51
-
52
  # Local agent/workspace notes
53
  /AGENTS.md
54
  /CLAUDE.md
55
  /PROGRESS.md
 
 
 
 
45
  *.ckpt
46
  *.gguf
47
 
 
 
 
 
48
  # Local agent/workspace notes
49
  /AGENTS.md
50
  /CLAUDE.md
51
  /PROGRESS.md
52
+ .claude
53
+ .gstack/
54
+ docs
README.md CHANGED
@@ -24,7 +24,7 @@ Transforms simple table descriptions and questions into SQL using the fine-tuned
24
  1. Click **Load fine-tuned model**.
25
  - Loading is lazy: the model is only downloaded and loaded when you request it.
26
  - On CPU, the first load can take a few minutes.
27
- 2. Enter or edit the **SQL table schema**.
28
  - You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
29
  - You can also write your own schema manually.
30
  3. Enter the question in the chat input.
@@ -38,6 +38,7 @@ Transforms simple table descriptions and questions into SQL using the fine-tuned
38
  - Fine-tuned merged model used in the app: [Shizu0n/phi3-mini-sql-generator-merged](https://huggingface.co/Shizu0n/phi3-mini-sql-generator-merged)
39
  - Offline baseline model used for evaluation: [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
40
 
 
41
  ## Metrics
42
 
43
  | Model | Exact match |
@@ -49,14 +50,24 @@ Reported gain: **+71.5 percentage points** over the base model.
49
 
50
  ## Current Features
51
 
52
- - Gradio UI with a step-by-step flow: load the fine-tuned model, enter schema/question, and generate SQL.
53
- - Offline baseline metrics shown in the UI without loading a second 3.8B model on the CPU Space.
54
  - Lazy loading to reduce startup cost.
55
  - Preserved Phi-3 patches for local/Spaces compatibility.
56
  - Schema presets without blocking manual input.
57
  - SQL output separated from errors/status so booleans, integers, and error messages do not appear inside the SQL block.
58
  - Centered loading overlay to make the loading state obvious.
59
 
 
 
 
 
 
 
 
 
 
 
60
  ## Run Locally
61
 
62
  ```bash
 
24
  1. Click **Load fine-tuned model**.
25
  - Loading is lazy: the model is only downloaded and loaded when you request it.
26
  - On CPU, the first load can take a few minutes.
27
+ 2. Chat normally or enter/edit the **SQL table schema**.
28
  - You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
29
  - You can also write your own schema manually.
30
  3. Enter the question in the chat input.
 
38
  - Fine-tuned merged model used in the app: [Shizu0n/phi3-mini-sql-generator-merged](https://huggingface.co/Shizu0n/phi3-mini-sql-generator-merged)
39
  - Offline baseline model used for evaluation: [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
40
 
41
+
42
  ## Metrics
43
 
44
  | Model | Exact match |
 
50
 
51
  ## Current Features
52
 
53
+ - Gradio UI with a step-by-step flow: load the fine-tuned model, chat, and inspect SQL artifacts.
54
+ - Intent routing that keeps normal conversation separate from SQL generation.
55
  - Lazy loading to reduce startup cost.
56
  - Preserved Phi-3 patches for local/Spaces compatibility.
57
  - Schema presets without blocking manual input.
58
  - SQL output separated from errors/status so booleans, integers, and error messages do not appear inside the SQL block.
59
  - Centered loading overlay to make the loading state obvious.
60
 
61
+ ## Model Probe
62
+
63
+ The normal pytest suite does not load the 3.8B model. To manually verify the real model behavior:
64
+
65
+ ```bash
66
+ python scripts/model_probe.py
67
+ ```
68
+
69
+ The probe prints JSON with pass/fail checks for greeting, schema proposal, CREATE TABLE confirmation, schema edit, SQL query, and smalltalk while a schema is active.
70
+
71
  ## Run Locally
72
 
73
  ```bash
app.py CHANGED
@@ -12,6 +12,11 @@ import unicodedata
12
  import gradio as gr
13
  import sqlparse
14
 
 
 
 
 
 
15
 
16
  BASE_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
17
  FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
@@ -447,166 +452,18 @@ def is_sql_like(text):
447
  }
448
 
449
 
450
- def is_sql_intent(message, schema):
451
- message = normalize_text(message)
452
- schema = (schema or "").strip()
453
- if not message:
454
- return False
455
- # P1 fix: if schema exists and message has substance, treat as SQL intent
456
- # (user is likely asking a question about the known schema)
457
- # Exclude short greetings/acknowledgments that could accompany a schema setup
458
- short_greetings = {
459
- "oi", "olá", "ola", "hi", "hello", "hey", "bom", "boa",
460
- "obrigado", "thanks", "ok", "sim", "claro", "de nada",
461
- }
462
- # Extended exclusions for FAQ/off-topic with schema active
463
- off_topic_patterns = {
464
- "obrigado", "thanks", "thank you", "muito obrigado", "obrigada",
465
- "como você funciona", "como voce funciona", "como funciona",
466
- "o que você faz", "o que voce faz", "o que faz",
467
- "como foi treinado", "como voce foi treinado", "treinado",
468
- "quais habilidades", "o que consegue", "o que pode fazer",
469
- "me ajude", "help me", "ajuda", "help",
470
- # Edit/table manipulation terms — prevent blanket-catch from routing to model
471
- "troca", "trocar", "renomeia", "renomear", "renomeie",
472
- "muda", "mudar", "altera", "alterar", "edita", "editar",
473
- "adiciona", "adicionar", "adicione", "remove", "remover",
474
- "apaga", "apagar", "delete column", "drop column",
475
- "coluna nova", "nova coluna", "novo campo", "campo novo",
476
- "trocando", "mudando", "alterando", "editando",
477
- }
478
- words = message.split()
479
- # Check if message is off-topic even with 2+ words
480
- if schema and len(words) >= 2:
481
- # Check exact matches and patterns
482
- if message in short_greetings or message in off_topic_patterns:
483
- return False
484
- # Check partial matches for common off-topic phrases
485
- for pattern in off_topic_patterns:
486
- if pattern in message:
487
- return False
488
- if schema and len(words) >= 2 and message not in short_greetings:
489
- return True
490
- sql_terms = {
491
- "all",
492
- "average",
493
- "count",
494
- "columns",
495
- "database",
496
- "find",
497
- "get",
498
- "group by",
499
- "join",
500
- "list",
501
- "order by",
502
- "query",
503
- "rows",
504
- "schema",
505
- "select",
506
- "show",
507
- "sql",
508
- "sum",
509
- "table",
510
- "where",
511
- "consulta",
512
- "consultar",
513
- "contar",
514
- "colunas",
515
- "linhas",
516
- "liste",
517
- "listar",
518
- "maior",
519
- "mais caro",
520
- "menor",
521
- "media",
522
- "média",
523
- "mostre",
524
- "mostrar",
525
- "ordene",
526
- "por departamento",
527
- "selecione",
528
- "sql",
529
- "some",
530
- "soma",
531
- "tabela",
532
- }
533
- return any(
534
- re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", message)
535
- for term in sql_terms
536
- )
537
-
538
-
539
- def build_generation_prompt(schema, message, chat_history=None):
540
- schema = (schema or "").strip()
541
- message = (message or "").strip()
542
- if is_sql_intent(message, schema):
543
- table_schema = schema or "CREATE TABLE unknown (id INTEGER)"
544
- # Inject last 3 conversation exchanges for multi-turn context
545
- history_context = ""
546
- if chat_history:
547
- trimmed = trim_chat_history(chat_history, max_exchanges=3)
548
- if trimmed:
549
- lines = []
550
- for i in range(0, len(trimmed), 2):
551
- entry1 = trimmed[i]
552
- entry2 = trimmed[i + 1] if i + 1 < len(trimmed) else None
553
- user_msg = entry1.get("content", "") if isinstance(entry1, dict) else (entry1[1] if isinstance(entry1, tuple) else str(entry1))
554
- asst_msg = entry2.get("content", "") if isinstance(entry2, dict) else (entry2[1] if isinstance(entry2, tuple) else str(entry2)) if entry2 else ""
555
- lines.append(f"User: {user_msg}")
556
- if asst_msg:
557
- lines.append(f"Assistant: {asst_msg}")
558
- if lines:
559
- history_context = "\n\nPrevious conversation:\n" + "\n".join(lines) + "\n"
560
- return PROMPT_TEMPLATE.format(schema=table_schema, question=message) + history_context
561
- return GENERAL_PROMPT_TEMPLATE.format(message=message)
562
-
563
-
564
- def format_generation_result(text):
565
- cleaned = extract_sql_candidate(text)
566
- if is_sql_like(cleaned):
567
- return str(cleaned), EMPTY_CHAT_OUTPUT, validate_sql(cleaned)
568
- return "", str(cleaned), CHAT_VALIDATOR
569
-
570
-
571
- def validate_sql(sql_text):
572
- sql_text = (sql_text or "").strip()
573
- if not sql_text:
574
- return EMPTY_VALIDATOR
575
- try:
576
- statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()]
577
- except Exception as exc:
578
- error_type = html.escape(type(exc).__name__)
579
- return (
580
- '<span class="validator-badge validator-warn">Check syntax</span>'
581
- f'<span class="validator-detail">sqlparse error: {error_type}</span>'
582
- )
583
- if not statements:
584
- return (
585
- '<span class="validator-badge validator-warn">Check syntax</span>'
586
- '<span class="validator-detail">No parsed SQL statement.</span>'
587
- )
588
- first_token = statements[0].token_first(skip_cm=True)
589
- token_value = first_token.value.strip().upper() if first_token is not None else "UNKNOWN"
590
- allowed_starters = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}
591
- if token_value not in allowed_starters:
592
- escaped_token = html.escape(token_value)
593
- return (
594
- '<span class="validator-badge validator-warn">Check syntax</span>'
595
- f'<span class="validator-detail">First token: {escaped_token}</span>'
596
- )
597
- return '<span class="validator-badge validator-ok">Valid SQL</span>'
598
 
599
 
600
  def render_header():
601
  return """
602
  <section class="top-panel">
603
  <div>
604
- <h1>Phi-3 Mini SQL Generator</h1>
605
- <p>QLoRA fine-tuned - b-mc2/sql-create-context</p>
606
  </div>
607
  <div class="top-badges">
608
- <span class="badge badge-green">73.5% exact match</span>
609
- <span class="badge badge-cream">+71.5pp vs base</span>
610
  <span class="badge badge-light">CPU lazy load</span>
611
  </div>
612
  </section>
@@ -677,10 +534,10 @@ def render_loading_overlay(model_key=None, visible=False):
677
  def model_metadata(model_key=None):
678
  return """
679
  <section class="stats-row">
680
- <div class="stat-card"><strong>73.5%</strong><span>exact match</span></div>
681
- <div class="stat-card"><strong>+71.5pp</strong><span>vs base</span></div>
682
- <div class="stat-card"><strong>1,000</strong><span>examples</span></div>
683
- <div class="stat-card"><strong>21 min</strong><span>T4 training</span></div>
684
  </section>
685
  """
686
 
@@ -731,7 +588,7 @@ def is_create_table_intent(message):
731
 
732
  def is_table_edit_intent(message):
733
  message = (message or "").strip().lower()
734
- edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|delete|deletar|exclua|excluir|novo|nova)\b"
735
  direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente)\b"
736
  direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
737
  target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
@@ -757,7 +614,7 @@ def is_table_edit_intent(message):
757
  or re.search(direct_remove_terms, message)
758
  or is_rename_intent(message)
759
  or re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", message)
760
- or (re.search(edit_terms, message) and (re.search(target_terms, message) or ":" in message))
761
  )
762
 
763
 
@@ -915,26 +772,6 @@ def format_create_table(table_name, columns):
915
  return f"CREATE TABLE {table_name} (\n" + ",\n".join(column_lines) + "\n);"
916
 
917
 
918
- def create_table_from_message(message):
919
- message = (message or "").strip()
920
- patterns = (
921
- r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
922
- r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
923
- )
924
- for pattern in patterns:
925
- match = re.search(pattern, message, flags=re.IGNORECASE)
926
- if not match:
927
- continue
928
- table_name = normalize_identifier(match.group(1))
929
- columns = [
930
- parsed
931
- for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2)))
932
- if parsed
933
- ]
934
- return format_create_table(table_name, columns)
935
- return ""
936
-
937
-
938
  def parse_create_table_schema(schema):
939
  schema = (schema or "").strip()
940
  match = re.match(
@@ -953,11 +790,6 @@ def parse_create_table_schema(schema):
953
  return table_name, columns
954
 
955
 
956
- def create_table_from_schema(schema):
957
- table_name, columns = parse_create_table_schema(schema)
958
- return format_create_table(table_name, columns)
959
-
960
-
961
  def extract_create_table_statement(text):
962
  cleaned = extract_sql_candidate(text)
963
  match = re.search(
@@ -1018,7 +850,7 @@ def is_rename_intent(message):
1018
  message = (message or "").strip().lower()
1019
  return bool(
1020
  re.search(
1021
- r"\b(rename|edit|change|renomeie|renomear|altere|mude)\s+\w+\s+(to|para|as|como)\s+\w+",
1022
  message,
1023
  flags=re.IGNORECASE,
1024
  )
@@ -1031,9 +863,16 @@ def extract_renamed_columns(message):
1031
  r"(\w+)\s+(?:to|para|as|como)\s+(\w+)"
1032
  )
1033
  matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
 
 
 
 
 
 
 
1034
  return [
1035
  (normalize_identifier(old), normalize_identifier(new))
1036
- for old, new in matches
1037
  if normalize_identifier(old) and normalize_identifier(new)
1038
  ]
1039
 
@@ -1044,7 +883,7 @@ def parse_compound_edit(message):
1044
  r"\s+(?:and|e)\s+"
1045
  r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
1046
  r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
1047
- r"exclua|renomeie|renomear|altere|mude)\b)"
1048
  )
1049
  segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
1050
 
@@ -1068,29 +907,6 @@ def parse_compound_edit(message):
1068
  return added, removed, renamed
1069
 
1070
 
1071
- def edit_create_table_from_message(message, chat_history, active_schema):
1072
- if not is_table_edit_intent(message) and not is_rename_intent(message):
1073
- return ""
1074
- base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
1075
- table_name, existing_columns = parse_create_table_schema(base_sql)
1076
- if not table_name:
1077
- return ""
1078
-
1079
- added_columns, removed_columns_list, renamed_columns = parse_compound_edit(message)
1080
- removed_set = set(extract_removed_columns(message)) | {r for r in removed_columns_list}
1081
-
1082
- if not added_columns and not removed_set and not renamed_columns:
1083
- return ""
1084
-
1085
- rename_map = dict(renamed_columns)
1086
- kept_columns = [
1087
- (rename_map.get(col_name, col_name), col_type)
1088
- for col_name, col_type in existing_columns
1089
- if col_name not in removed_set
1090
- ]
1091
- return format_create_table(table_name, [*kept_columns, *added_columns])
1092
-
1093
-
1094
  def render_schema_context(schema=""):
1095
  schema = (schema or "").strip()
1096
  if not schema:
@@ -1150,9 +966,8 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
1150
  *query_control_updates(False),
1151
  "",
1152
  EMPTY_VALIDATOR,
1153
- gr.update(interactive=False, visible=False),
1154
  render_message(),
1155
- gr.update(visible=False),
1156
  )
1157
  started = time.time()
1158
  try:
@@ -1181,9 +996,8 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
1181
  *query_control_updates(False),
1182
  "",
1183
  EMPTY_VALIDATOR,
1184
- gr.update(interactive=False, visible=False),
1185
  render_message(error),
1186
- gr.update(visible=False),
1187
  )
1188
  return
1189
 
@@ -1197,19 +1011,18 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
1197
  *query_control_updates(True),
1198
  "",
1199
  EMPTY_VALIDATOR,
1200
- gr.update(interactive=False, visible=False),
1201
  render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
1202
- gr.update(visible=False),
1203
  )
1204
 
1205
 
1206
  def set_preset(name):
1207
  schema = PRESETS[name]
1208
- return schema, render_schema_context(schema), gr.update(visible=True)
1209
 
1210
 
1211
  def clear_schema_context():
1212
- return "", render_schema_context(""), gr.update(visible=False)
1213
 
1214
 
1215
  def trim_chat_history(chat_history, max_exchanges=10):
@@ -1239,12 +1052,54 @@ def render_compare_label(prefix, model_label, metric):
1239
  )
1240
 
1241
 
1242
- def deterministic_response(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1243
  chat_history,
1244
  message,
1245
- active_schema,
1246
- loaded_key,
1247
- saved_state,
1248
  assistant_content,
1249
  status_message,
1250
  *,
@@ -1252,262 +1107,342 @@ def deterministic_response(
1252
  validator=CHAT_VALIDATOR,
1253
  status_kind="ok",
1254
  ):
1255
- new_history = trim_chat_history(
1256
- [
1257
- *list(chat_history or []),
1258
- {"role": "user", "content": message},
1259
- {"role": "assistant", "content": assistant_content},
1260
- ]
1261
- )
1262
- # If sql_text is a CREATE TABLE, promote it to active_schema for subsequent queries
1263
- new_schema = active_schema
1264
  if sql_text and "CREATE TABLE" in sql_text.upper():
1265
- new_schema = sql_text
1266
- compare = comparison_updates(saved_state, sql_text, loaded_key)
1267
  return (
1268
  new_history,
1269
  "",
1270
- new_schema,
1271
  message,
1272
  sql_text,
1273
  validator,
1274
- gr.update(interactive=False, visible=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1275
  render_message(status_message, kind=status_kind),
1276
- *compare,
1277
  )
1278
 
1279
 
1280
- def generate_response(message, chat_history, active_schema, loaded_key, saved_state):
1281
  message = (message or "").strip()
1282
- active_schema = (active_schema or "").strip()
1283
  chat_history = list(chat_history or [])
 
1284
  if not message:
1285
- compare = comparison_updates(saved_state, "", loaded_key)
1286
  return (
1287
  chat_history,
1288
  "",
1289
- active_schema,
1290
  "",
1291
  "",
1292
  EMPTY_VALIDATOR,
1293
- gr.update(interactive=False, visible=False),
1294
  render_message("Type a message before sending."),
1295
- *compare,
1296
  )
1297
 
1298
- # Routing debug log — shows which intent matched
1299
- _routing = []
1300
- edited_table = edit_create_table_from_message(message, chat_history, active_schema)
1301
- if edited_table:
1302
- _routing.append("edit_create_table")
1303
- elif is_table_edit_intent(message):
1304
- _routing.append("is_table_edit_intent")
1305
- elif is_create_table_intent(message):
1306
- _routing.append("is_create_table_intent")
1307
- elif is_sql_intent(message, active_schema):
1308
- _routing.append("is_sql_intent")
1309
- else:
1310
- _routing.append("no_match")
1311
- print(f"[ROUTING] \"{message[:60]}\" → {_routing}")
1312
 
1313
- if edited_table:
1314
- display_response = f"```sql\n{edited_table}\n```"
1315
- return deterministic_response(
1316
- chat_history,
1317
- message,
1318
- active_schema,
1319
- loaded_key,
1320
- saved_state,
1321
- display_response,
1322
- "Edited CREATE TABLE without calling the model.",
1323
- sql_text=edited_table,
1324
- validator=validate_sql(edited_table),
1325
- )
1326
- if is_table_edit_intent(message):
1327
- compare = comparison_updates(saved_state, "", loaded_key)
1328
- return (
 
 
 
 
 
 
 
 
 
 
 
 
1329
  chat_history,
1330
  message,
1331
- active_schema,
1332
- "",
1333
- "",
1334
- EMPTY_VALIDATOR,
1335
- gr.update(interactive=False, visible=False),
1336
- render_message("I need an existing CREATE TABLE in the chat or an active schema before editing columns."),
1337
- *compare,
1338
  )
1339
 
1340
- if is_create_table_intent(message):
1341
- sql_text = create_table_from_message(message) or create_table_from_schema(active_schema)
1342
  if sql_text:
1343
  display_response = f"```sql\n{sql_text}\n```"
1344
- return deterministic_response(
1345
  chat_history,
1346
  message,
1347
- active_schema,
1348
- loaded_key,
1349
- saved_state,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1350
  display_response,
1351
  "Generated CREATE TABLE without calling the model.",
1352
  sql_text=sql_text,
1353
- validator=validate_sql(sql_text),
1354
  )
1355
- compare = comparison_updates(saved_state, "", loaded_key)
1356
- return (
1357
  chat_history,
1358
  message,
1359
- active_schema,
1360
- "",
1361
- "",
1362
- EMPTY_VALIDATOR,
1363
- gr.update(interactive=False, visible=False),
1364
- render_message("CREATE TABLE needs a table name and columns, or an active schema context."),
1365
- *compare,
1366
  )
1367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1368
 
1369
- if not is_sql_intent(message, active_schema):
1370
- fallback = safe_chat_fallback()
1371
- return deterministic_response(
1372
- chat_history,
1373
- message,
1374
- active_schema,
1375
- loaded_key,
1376
- saved_state,
1377
- fallback,
1378
- "No SQL intent or active schema detected.",
1379
- )
1380
-
1381
- if not loaded_key or _model is None or _tokenizer is None:
1382
- compare = comparison_updates(saved_state, "", loaded_key)
1383
- return (
1384
- chat_history,
1385
- message,
1386
- active_schema,
1387
- "",
1388
- "",
1389
- EMPTY_VALIDATOR,
1390
- gr.update(interactive=False, visible=False),
1391
- render_message("Load a model before generating SQL."),
1392
- *compare,
1393
- )
1394
 
1395
- model_def = model_by_key(loaded_key)
1396
- if _current_model_id != model_def["model_id"]:
1397
- compare = comparison_updates(saved_state, "", loaded_key)
1398
- return (
1399
  chat_history,
1400
  message,
1401
- active_schema,
1402
- "",
1403
- "",
1404
- EMPTY_VALIDATOR,
1405
- gr.update(interactive=False, visible=False),
1406
- render_message("Loaded model state is inconsistent. Reload the selected model."),
1407
- *compare,
1408
  )
1409
 
1410
- started = time.time()
1411
  try:
1412
- import_model_runtime()
1413
- with _model_lock:
1414
- prompt = build_generation_prompt(active_schema, message, chat_history)
1415
- inputs = _tokenizer(prompt, return_tensors="pt")
1416
- input_length = inputs["input_ids"].shape[-1]
1417
- gen_kwargs = {
1418
- "max_new_tokens": 80,
1419
- "max_time": GENERATION_MAX_TIME_SECONDS,
1420
- "do_sample": False,
1421
- "use_cache": False,
1422
- "repetition_penalty": 1.1,
1423
- "eos_token_id": getattr(_model.generation_config, "eos_token_id", _tokenizer.eos_token_id),
1424
- "pad_token_id": _tokenizer.pad_token_id or _tokenizer.eos_token_id,
1425
- }
1426
- executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
1427
- future = executor.submit(_run_generation, _model, inputs, gen_kwargs)
1428
- try:
1429
- output_ids = future.result(timeout=GENERATION_TIMEOUT_SECONDS)
1430
- except concurrent.futures.TimeoutError:
1431
- # Timeout reached - do NOT call future.result() without timeout as it can block indefinitely.
1432
- # The thread may continue in background but we won't wait for it.
1433
- # Return error to user and release the slot.
1434
- executor.shutdown(wait=False, cancel_futures=False)
1435
- raise TimeoutError(f"Generation timed out after {GENERATION_TIMEOUT_SECONDS}s")
1436
- finally:
1437
- executor.shutdown(wait=False, cancel_futures=True)
1438
- generated_ids = output_ids[0][input_length:]
1439
- generated_text = _tokenizer.decode(generated_ids, skip_special_tokens=True)
1440
  except Exception as exc:
1441
- compare = comparison_updates(saved_state, "", loaded_key)
1442
- return (
1443
  chat_history,
1444
  message,
1445
- active_schema,
1446
- "",
1447
- "",
1448
- EMPTY_VALIDATOR,
1449
- gr.update(interactive=False, visible=False),
1450
- render_message(f"Generation failed: {type(exc).__name__}: {exc}"),
1451
- *compare,
1452
  )
1453
 
1454
- elapsed = int(time.time() - started)
1455
- sql_text, chat_text, validator = format_generation_result(generated_text)
1456
  display_response = f"```sql\n{sql_text}\n```" if sql_text else chat_text
1457
- new_history = trim_chat_history(
1458
- [
1459
- *chat_history,
1460
- {"role": "user", "content": message},
1461
- {"role": "assistant", "content": display_response},
1462
- ]
1463
- )
1464
- compare = comparison_updates(saved_state, sql_text, loaded_key)
1465
  response_kind = "SQL" if sql_text.strip() else "chat response"
1466
- return (
1467
- new_history,
1468
- "",
1469
- active_schema,
1470
  message,
1471
- str(sql_text),
1472
- validator,
1473
- gr.update(interactive=False, visible=False),
1474
- render_message(f"Generated {response_kind} with {model_def['model_id']} in {elapsed}s.", kind="ok"),
1475
- *compare,
1476
  )
1477
 
1478
 
1479
- def save_for_comparison(sql_text, loaded_key, active_schema, last_message):
1480
- sql_text = (sql_text or "").strip()
1481
- if not sql_text or not loaded_key:
1482
- return (
1483
- None,
1484
- gr.update(visible=False),
1485
- "",
1486
- "",
1487
- "",
1488
- "",
1489
- gr.update(interactive=False, visible=False),
1490
- render_message("Generate SQL before saving a comparison."),
1491
- )
1492
 
1493
- model_def = model_by_key(loaded_key)
1494
- saved = {
1495
- "sql": sql_text,
1496
- "model_label": model_def["short_label"],
1497
- "match": model_def["exact_match"],
1498
- "schema_context": active_schema or "",
1499
- "user_message": last_message or "",
1500
- }
1501
- return (
1502
- saved,
1503
- gr.update(visible=True),
1504
- render_compare_label("Saved", model_def["short_label"], model_def["exact_match"]),
1505
- sql_text,
1506
- render_compare_label("Current", model_def["short_label"], model_def["exact_match"]),
1507
- sql_text,
1508
- gr.update(interactive=True),
1509
- render_message("Saved output for comparison.", kind="ok"),
1510
- )
 
 
 
 
 
1511
 
1512
 
1513
  def sync_on_load():
@@ -1523,9 +1458,8 @@ def sync_on_load():
1523
  *query_control_updates(True),
1524
  "",
1525
  EMPTY_VALIDATOR,
1526
- gr.update(interactive=False, visible=False),
1527
  render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
1528
- gr.update(visible=False),
1529
  )
1530
  return (
1531
  None,
@@ -1536,9 +1470,8 @@ def sync_on_load():
1536
  *query_control_updates(False),
1537
  "",
1538
  EMPTY_VALIDATOR,
1539
- gr.update(interactive=False, visible=False),
1540
  render_message(),
1541
- gr.update(visible=False),
1542
  )
1543
 
1544
 
@@ -2296,10 +2229,11 @@ textarea {
2296
  }
2297
  """
2298
 
2299
- with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2300
  loaded_key_state = gr.State(value=None)
2301
- saved_output = gr.State(value=None)
2302
  active_schema = gr.State(value="")
 
 
2303
  last_user_message = gr.State(value="")
2304
 
2305
  with gr.Column(elem_classes=["app-shell"]):
@@ -2312,7 +2246,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2312
  load_button = gr.Button("Load fine-tuned model", variant="primary", elem_id="load-button")
2313
  model_status = gr.HTML(render_status(DEFAULT_MODEL_KEY, None))
2314
  model_info = gr.HTML(model_metadata(DEFAULT_MODEL_KEY))
2315
- gr.HTML(render_baseline_evidence())
2316
 
2317
  with gr.Column(elem_id="query-section", elem_classes=["query-section"]):
2318
  gr.HTML(render_step("02", "Chat"))
@@ -2366,23 +2299,8 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2366
  interactive=False,
2367
  show_label=False,
2368
  )
2369
- save_button = gr.Button(
2370
- "Save output",
2371
- interactive=False,
2372
- visible=False,
2373
- elem_id="save-button",
2374
- )
2375
  error_output = gr.HTML(render_message())
2376
 
2377
- with gr.Column(visible=False, elem_classes=["comparison-panel"]) as comparison_panel:
2378
- with gr.Row(elem_classes=["compare-grid"]):
2379
- with gr.Column(elem_classes=["compare-card"]):
2380
- saved_model_label = gr.HTML("")
2381
- saved_sql = gr.Code(label="", language="sql", lines=6, show_label=False)
2382
- with gr.Column(elem_classes=["compare-card", "current"]):
2383
- current_model_label = gr.HTML("")
2384
- current_sql = gr.Code(label="", language="sql", lines=6, show_label=False)
2385
-
2386
  model_state_outputs = [
2387
  fine_tuned_model_card,
2388
  model_status,
@@ -2395,7 +2313,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2395
  clear_schema_button,
2396
  message_input,
2397
  send_button,
2398
- save_button,
2399
  error_output,
2400
  ]
2401
 
@@ -2418,14 +2335,13 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2418
  send_button,
2419
  sql_output,
2420
  validator_output,
2421
- save_button,
2422
  error_output,
2423
- comparison_panel,
2424
  ],
2425
  js=LOAD_SCROLL_JS,
2426
  )
2427
 
2428
- schema_context_outputs = [active_schema, active_schema_pill, clear_schema_button]
2429
  employees_preset.click(set_preset, inputs=gr.State("employees"), outputs=schema_context_outputs)
2430
  orders_preset.click(set_preset, inputs=gr.State("orders"), outputs=schema_context_outputs)
2431
  students_preset.click(set_preset, inputs=gr.State("students"), outputs=schema_context_outputs)
@@ -2440,38 +2356,20 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2440
  last_user_message,
2441
  sql_output,
2442
  validator_output,
2443
- save_button,
2444
  error_output,
2445
- comparison_panel,
2446
- saved_model_label,
2447
- saved_sql,
2448
- current_model_label,
2449
- current_sql,
2450
  ]
2451
  send_button.click(
2452
  generate_response,
2453
- inputs=[message_input, chatbot, active_schema, loaded_key_state, saved_output],
2454
  outputs=chat_generation_outputs,
2455
  )
2456
  message_input.submit(
2457
  generate_response,
2458
- inputs=[message_input, chatbot, active_schema, loaded_key_state, saved_output],
2459
  outputs=chat_generation_outputs,
2460
  )
2461
- save_button.click(
2462
- save_for_comparison,
2463
- inputs=[sql_output, loaded_key_state, active_schema, last_user_message],
2464
- outputs=[
2465
- saved_output,
2466
- comparison_panel,
2467
- saved_model_label,
2468
- saved_sql,
2469
- current_model_label,
2470
- current_sql,
2471
- save_button,
2472
- error_output,
2473
- ],
2474
- )
2475
  demo.load(
2476
  sync_on_load,
2477
  outputs=[
@@ -2490,9 +2388,8 @@ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
2490
  send_button,
2491
  sql_output,
2492
  validator_output,
2493
- save_button,
2494
  error_output,
2495
- comparison_panel,
2496
  ],
2497
  )
2498
 
 
12
  import gradio as gr
13
  import sqlparse
14
 
15
+ import chat_state as chat_core
16
+ import intent as intent_core
17
+ import model_io as model_core
18
+ import sql_tools as sql_core
19
+
20
 
21
  BASE_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
22
  FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
 
452
  }
453
 
454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
 
457
  def render_header():
458
  return """
459
  <section class="top-panel">
460
  <div>
461
+ <h1>Phi-3 Mini SQL Chatbot</h1>
462
+ <p>Conversational SQL assistant powered by a fine-tuned Phi-3 Mini model</p>
463
  </div>
464
  <div class="top-badges">
465
+ <span class="badge badge-green">Natural chat + SQL</span>
466
+ <span class="badge badge-cream">Context-aware schema</span>
467
  <span class="badge badge-light">CPU lazy load</span>
468
  </div>
469
  </section>
 
534
  def model_metadata(model_key=None):
535
  return """
536
  <section class="stats-row">
537
+ <div class="stat-card"><strong>Chat</strong><span>normal conversation</span></div>
538
+ <div class="stat-card"><strong>Schema</strong><span>table proposals</span></div>
539
+ <div class="stat-card"><strong>SQL</strong><span>query generation</span></div>
540
+ <div class="stat-card"><strong>Probe</strong><span>manual model gate</span></div>
541
  </section>
542
  """
543
 
 
588
 
589
  def is_table_edit_intent(message):
590
  message = (message or "").strip().lower()
591
# Verbs (English and Portuguese) that signal a table edit. "troque" and
# "coloque" must be separate alternatives: fused together they form a word
# that never occurs, so neither verb could match.
edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|troque|coloque|colocar)\b"
592
  direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente)\b"
593
  direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
594
  target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
 
614
  or re.search(direct_remove_terms, message)
615
  or is_rename_intent(message)
616
  or re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", message)
617
+ or (re.search(edit_terms, message) and (re.search(target_terms, message) or ":" in message or re.search(r"\bpor\b", message)))
618
  )
619
 
620
 
 
772
  return f"CREATE TABLE {table_name} (\n" + ",\n".join(column_lines) + "\n);"
773
 
774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
  def parse_create_table_schema(schema):
776
  schema = (schema or "").strip()
777
  match = re.match(
 
790
  return table_name, columns
791
 
792
 
 
 
 
 
 
793
  def extract_create_table_statement(text):
794
  cleaned = extract_sql_candidate(text)
795
  match = re.search(
 
850
  message = (message or "").strip().lower()
851
  return bool(
852
  re.search(
853
+ r"\b(rename|edit|change|renomeie|renomear|renomeia|renomeia|altere|mude|muda|troca|trocar)\s+\w+\s+(to|para|as|como|por)\s+\w+",
854
  message,
855
  flags=re.IGNORECASE,
856
  )
 
863
  r"(\w+)\s+(?:to|para|as|como)\s+(\w+)"
864
  )
865
  matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
866
# Also handle the Portuguese "troca X por Y" phrasing. The verb forms
# "trocar" and "troque" are accepted too, so that a "trocar X por Y" request
# recognised by is_rename_intent actually yields an (old, new) pair here.
troca_matches = re.findall(
    r"\b(?:troca|trocar|troque)\b\s+(\w+)\s+\bpor\b\s+(\w+)",
    message or "",
    flags=re.IGNORECASE,
)
872
+ all_matches = matches + troca_matches
873
  return [
874
  (normalize_identifier(old), normalize_identifier(new))
875
+ for old, new in all_matches
876
  if normalize_identifier(old) and normalize_identifier(new)
877
  ]
878
 
 
883
  r"\s+(?:and|e)\s+"
884
  r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
885
  r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
886
+ r"exclua|renomeie|renomear|altere|mude|troca|trocar)\b)"
887
  )
888
  segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
889
 
 
907
  return added, removed, renamed
908
 
909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910
  def render_schema_context(schema=""):
911
  schema = (schema or "").strip()
912
  if not schema:
 
966
  *query_control_updates(False),
967
  "",
968
  EMPTY_VALIDATOR,
969
+ gr.update(value=None),
970
  render_message(),
 
971
  )
972
  started = time.time()
973
  try:
 
996
  *query_control_updates(False),
997
  "",
998
  EMPTY_VALIDATOR,
999
+ gr.update(value=None),
1000
  render_message(error),
 
1001
  )
1002
  return
1003
 
 
1011
  *query_control_updates(True),
1012
  "",
1013
  EMPTY_VALIDATOR,
1014
+ gr.update(value=None),
1015
  render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
 
1016
  )
1017
 
1018
 
1019
def set_preset(name):
    """Activate the preset schema registered under *name*.

    Returns a 4-tuple: the preset schema text, the result of
    ``render_schema_context`` for that schema, a Gradio update with
    ``visible=True``, and a fresh ``chat_core.default_state`` seeded with the
    schema (so any previous conversation state is replaced).

    The lookup ``PRESETS[name]`` is not guarded, so an unknown *name* fails
    at that line.
    """
    schema = PRESETS[name]
    return schema, render_schema_context(schema), gr.update(visible=True), chat_core.default_state(schema)
1022
 
1023
 
1024
def clear_schema_context():
    """Drop the active schema.

    Returns a 4-tuple: an empty schema string, the result of
    ``render_schema_context("")``, a Gradio update with ``visible=False``,
    and a fresh ``chat_core.default_state`` with no schema.
    """
    return "", render_schema_context(""), gr.update(visible=False), chat_core.default_state("")
1026
 
1027
 
1028
  def trim_chat_history(chat_history, max_exchanges=10):
 
1052
  )
1053
 
1054
 
1055
def save_for_comparison(sql_text, loaded_key, active_schema, last_message):
    """Snapshot the current SQL output so it can be compared later.

    Returns an 8-tuple: the saved snapshot dict (or ``None``), a visibility
    update, the "Saved" label, the saved SQL, the "Current" label, the
    current SQL, a button update, and a status message.

    When there is no SQL text or no loaded model key, nothing is saved: the
    snapshot is ``None``, the text slots are empty and the status asks the
    user to generate SQL first.
    """
    cleaned_sql = (sql_text or "").strip()
    nothing_to_save = not cleaned_sql or not loaded_key
    if nothing_to_save:
        hide_panel = gr.update(visible=False)
        disable_button = gr.update(interactive=False, visible=False)
        warning = render_message("Generate SQL before saving a comparison.")
        return None, hide_panel, "", "", "", "", disable_button, warning

    model_def = model_by_key(loaded_key)
    label = model_def["short_label"]
    score = model_def["exact_match"]
    snapshot = dict(
        sql=cleaned_sql,
        model_label=label,
        match=score,
        schema_context=active_schema or "",
        user_message=last_message or "",
    )
    show_panel = gr.update(visible=True)
    # Right after saving, the "Saved" and "Current" slots show the same SQL.
    saved_heading = render_compare_label("Saved", label, score)
    current_heading = render_compare_label("Current", label, score)
    enable_button = gr.update(interactive=True)
    confirmation = render_message("Saved output for comparison.", kind="ok")
    return (
        snapshot,
        show_panel,
        saved_heading,
        cleaned_sql,
        current_heading,
        cleaned_sql,
        enable_button,
        confirmation,
    )
1087
+
1088
+
1089
def _append_chat_turn(chat_history, message, assistant_content):
    """Return the history extended with one user/assistant exchange.

    The input history is copied, not mutated; the extended copy is passed
    through ``trim_chat_history`` before being returned.
    """
    turns = list(chat_history or [])
    turns.append({"role": "user", "content": message})
    turns.append({"role": "assistant", "content": assistant_content})
    return trim_chat_history(turns)
1097
+
1098
+
1099
+ def _response_tuple(
1100
  chat_history,
1101
  message,
1102
+ state,
 
 
1103
  assistant_content,
1104
  status_message,
1105
  *,
 
1107
  validator=CHAT_VALIDATOR,
1108
  status_kind="ok",
1109
  ):
1110
+ state = chat_core.ConversationState.from_value(state)
 
 
 
 
 
 
 
 
1111
  if sql_text and "CREATE TABLE" in sql_text.upper():
1112
+ state = state.with_active_schema(sql_text).clear_pending_schema()
1113
+ new_history = _append_chat_turn(chat_history, message, assistant_content)
1114
  return (
1115
  new_history,
1116
  "",
1117
+ state.active_schema,
1118
  message,
1119
  sql_text,
1120
  validator,
1121
+ gr.update(value=None),
1122
+ render_message(status_message, kind=status_kind),
1123
+ state.to_dict(),
1124
+ )
1125
+
1126
+
1127
def deterministic_response(
    chat_history,
    message,
    active_schema,
    loaded_key,
    saved_state,
    assistant_content,
    status_message,
    *,
    sql_text="",
    validator=CHAT_VALIDATOR,
    status_kind="ok",
    conversation_state=None,
):
    """Build a chat response tuple for an answer produced without the model.

    ``conversation_state`` (dict, ``ConversationState`` or ``None``) is
    resolved together with ``active_schema`` into a ``ConversationState``,
    and everything is then delegated to ``_response_tuple``.

    ``loaded_key`` and ``saved_state`` are accepted but not used by this
    function.
    """
    resolved_state = chat_core.ConversationState.from_value(
        conversation_state,
        active_schema=active_schema,
    )
    options = {
        "sql_text": sql_text,
        "validator": validator,
        "status_kind": status_kind,
    }
    return _response_tuple(
        chat_history,
        message,
        resolved_state,
        assistant_content,
        status_message,
        **options,
    )
1152
+
1153
+
1154
def _model_ready(loaded_key):
    """Report whether the model selected in the UI is loaded and usable.

    Returns ``(ready, error_message)``; ``error_message`` is empty when
    ``ready`` is true.
    """
    runtime_loaded = bool(loaded_key) and _model is not None and _tokenizer is not None
    if not runtime_loaded:
        return False, "Load the fine-tuned model before chatting or generating SQL."

    expected_model_id = model_by_key(loaded_key)["model_id"]
    if _current_model_id == expected_model_id:
        return True, ""
    # The UI key and the module-level runtime disagree about which model is live.
    return False, "Loaded model state is inconsistent. Reload the selected model."
1161
+
1162
+
1163
def _generate_model_text(prompt, generation_kind=model_core.SQL_GENERATION):
    """Run one greedy generation and return ``(generated_text, elapsed_seconds)``.

    Args:
        prompt: Fully formatted prompt string passed to the tokenizer.
        generation_kind: Key handed to ``model_core.generation_budget`` to pick
            ``max_new_tokens``.

    Returns:
        The decoded continuation (prompt tokens stripped, special tokens
        skipped) and the whole-second wall time spent in this call.

    Raises:
        RuntimeError: if no model/tokenizer is currently loaded.
        TimeoutError: if generation exceeds ``GENERATION_TIMEOUT_SECONDS``.
    """
    started = time.time()
    import_model_runtime()
    # Only the reference snapshot is taken under the lock; generation itself
    # runs on the local references.
    with _model_lock:
        model = _model
        tokenizer = _tokenizer
    if model is None or tokenizer is None:
        raise RuntimeError("Model runtime is not loaded.")

    inputs = tokenizer(prompt, return_tensors="pt")
    input_length = inputs["input_ids"].shape[-1]

    generation_config = getattr(model, "generation_config", None)
    # Prefer the model's configured EOS, but fall back to the tokenizer when
    # the config is absent *or* leaves the value unset.
    eos_token_id = getattr(generation_config, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    # Compare against None explicitly: 0 is a valid pad id and must not be
    # replaced by the EOS id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    gen_kwargs = {
        "max_new_tokens": model_core.generation_budget(generation_kind),
        "max_time": GENERATION_MAX_TIME_SECONDS,
        "do_sample": False,
        # NOTE(review): KV cache is disabled; presumably a compatibility
        # workaround for this model/runtime - confirm before changing.
        "use_cache": False,
        "repetition_penalty": 1.1,
        "eos_token_id": eos_token_id,
        "pad_token_id": pad_token_id,
    }

    # A dedicated single worker lets us stop waiting after the timeout; the
    # worker thread itself cannot be interrupted, so shutdown never blocks.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(_run_generation, model, inputs, gen_kwargs)
    try:
        output_ids = future.result(timeout=GENERATION_TIMEOUT_SECONDS)
    except concurrent.futures.TimeoutError as exc:
        raise TimeoutError(f"Generation timed out after {GENERATION_TIMEOUT_SECONDS}s") from exc
    finally:
        executor.shutdown(wait=False, cancel_futures=True)

    generated_ids = output_ids[0][input_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated_text, int(time.time() - started)
1196
+
1197
+
1198
+ def _schema_suggestion_message(suggestion):
1199
+ columns = ", ".join(f"{name} {column_type}" for name, column_type in suggestion.columns)
1200
+ rationale = f"\n\n{suggestion.rationale}" if suggestion.rationale else ""
1201
+ return f"Posso montar a tabela `{suggestion.table_name}` com: {columns}.{rationale}\n\nSe quiser, diga `gera`."
1202
+
1203
+
1204
def _empty_generation_response(chat_history, message, state, status_message, *, status_kind="error"):
    """Build the 9-value chat output for a turn that produced no answer.

    The history is returned untouched and the original *message* is handed
    back in the second slot; the last-user-message and SQL slots are blank,
    the validator is reset to ``EMPTY_VALIDATOR``, and *status_message* is
    rendered with *status_kind*.
    """
    schema = state.active_schema
    cleared_meta = gr.update(value=None)
    status_html = render_message(status_message, kind=status_kind)
    serialized_state = state.to_dict()
    return (
        chat_history,
        message,
        schema,
        "",
        "",
        EMPTY_VALIDATOR,
        cleared_meta,
        status_html,
        serialized_state,
    )
1216
 
1217
 
1218
def generate_response(message, chat_history, active_schema, loaded_key, saved_state=None, conversation_state=None):
    """Route one chat message by intent and build the chat event outputs.

    The message is classified with ``intent_core.classify_intent`` and then
    handled by the first matching branch:

    * EDIT_TABLE - edit the pending proposal or the active CREATE TABLE
      deterministically (no model call).
    * CREATE_TABLE_CONFIRM - turn the pending proposal into CREATE TABLE.
    * CREATE_TABLE - build CREATE TABLE from the message or active schema,
      falling back to a model proposal when a model is loaded.
    * SCHEMA_SUGGESTION - ask the model for a JSON table proposal (one repair
      attempt) and store it as pending.
    * SMALLTALK / CLARIFICATION / UNKNOWN - free-form chat with the model.
    * anything else - SQL generation against the active schema.

    Args:
        message: Raw user input.
        chat_history: Existing list of chat messages (copied, not mutated).
        active_schema: Schema text currently active in the UI.
        loaded_key: Key of the model the UI believes is loaded.
        saved_state: Accepted but unused here.
        conversation_state: Serialized or live ``ConversationState``.

    Returns:
        A 9-tuple in the order produced by ``_response_tuple`` /
        ``_empty_generation_response``: chat history, message-box value,
        active schema, last user message, SQL text, validator, generation-meta
        update, status HTML, and the conversation state as a dict.
    """
    message = (message or "").strip()
    chat_history = list(chat_history or [])
    state = chat_core.ConversationState.from_value(conversation_state, active_schema=(active_schema or ""))
    if not message:
        return (
            chat_history,
            "",
            state.active_schema,
            "",
            "",
            EMPTY_VALIDATOR,
            gr.update(value=None),
            render_message("Type a message before sending."),
            state.to_dict(),
        )

    intent_result = intent_core.classify_intent(message, state, chat_history)
    state = state.with_intent(intent_result)
    print(
        f"[ROUTING] \"{message[:60]}\" -> intent={intent_result.intent} "
        f"confidence={intent_result.confidence} reason={intent_result.reason}",
        flush=True,
    )

    if intent_result.intent == intent_core.EDIT_TABLE:
        # With only a pending proposal (no active schema yet), the edit is
        # applied to the proposal and it stays pending.
        if state.pending_schema_suggestion and not state.active_schema:
            pending_sql = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
            edited_table = sql_core.edit_create_table_from_message(message, chat_history, pending_sql)
            table_name, columns = sql_core.parse_create_table_schema(edited_table)
            if edited_table and table_name and columns:
                suggestion = chat_core.SchemaSuggestion(table_name=table_name, columns=tuple(columns))
                state = state.with_pending_schema(suggestion)
                return _response_tuple(
                    chat_history,
                    message,
                    state,
                    _schema_suggestion_message(suggestion),
                    "Updated pending table proposal without calling the model.",
                )
        edited_table = sql_core.edit_create_table_from_message(message, chat_history, state.active_schema)
        if edited_table:
            display_response = f"```sql\n{edited_table}\n```"
            return _response_tuple(
                chat_history,
                message,
                state,
                display_response,
                "Edited CREATE TABLE without calling the model.",
                sql_text=edited_table,
                validator=sql_core.validate_sql(edited_table),
            )
        return _empty_generation_response(
            chat_history,
            message,
            state,
            "I need an existing CREATE TABLE in the chat or an active schema before editing columns.",
        )

    if intent_result.intent == intent_core.CREATE_TABLE_CONFIRM:
        sql_text = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
        if sql_text:
            display_response = f"```sql\n{sql_text}\n```"
            return _response_tuple(
                chat_history,
                message,
                state.clear_pending_schema(),
                display_response,
                "Generated CREATE TABLE from the pending proposal.",
                sql_text=sql_text,
                validator=sql_core.validate_sql(sql_text),
            )
        # No SQL could be built from the proposal: fall through to the
        # remaining branches, as before.

    if intent_result.intent == intent_core.CREATE_TABLE:
        sql_text = sql_core.create_table_from_message(message) or sql_core.create_table_from_schema(state.active_schema)
        create_status = "Generated CREATE TABLE without calling the model."
        if not sql_text:
            ready, _error = _model_ready(loaded_key)
            if ready:
                try:
                    prompt = model_core.build_schema_suggestion_prompt(message, state, chat_history)
                    generated_text, elapsed = _generate_model_text(prompt, model_core.SCHEMA_GENERATION)
                    suggestion = model_core.parse_schema_suggestion(generated_text)
                    sql_text = sql_core.create_table_from_suggestion(suggestion)
                    # The table came from the model, so the status must say so.
                    create_status = f"Generated CREATE TABLE from a model proposal in {elapsed}s."
                except Exception as exc:
                    # Best-effort fallback: keep going, but leave a trace
                    # instead of discarding the failure silently.
                    print(
                        f"[CREATE_TABLE] model fallback failed: {type(exc).__name__}: {exc}",
                        flush=True,
                    )
                    sql_text = ""
        if sql_text:
            display_response = f"```sql\n{sql_text}\n```"
            return _response_tuple(
                chat_history,
                message,
                state,
                display_response,
                create_status,
                sql_text=sql_text,
                validator=sql_core.validate_sql(sql_text),
            )
        return _empty_generation_response(
            chat_history,
            message,
            state,
            "CREATE TABLE needs a table name and columns, or a loaded model to propose them.",
        )

    if intent_result.intent == intent_core.SCHEMA_SUGGESTION:
        ready, error = _model_ready(loaded_key)
        if not ready:
            return _response_tuple(chat_history, message, state, error, error, status_kind="error")
        try:
            prompt = model_core.build_schema_suggestion_prompt(message, state, chat_history)
            generated_text, elapsed = _generate_model_text(prompt, model_core.SCHEMA_GENERATION)
            suggestion = model_core.parse_schema_suggestion(generated_text)
            if not suggestion:
                # One repair round: ask the model to restate its own output as JSON.
                repair_prompt = (
                    "Return valid JSON only for this SQL table proposal. "
                    "Use table_name, columns, and rationale.\n\n"
                    f"Previous output:\n{generated_text}"
                )
                repaired_text, elapsed = _generate_model_text(repair_prompt, model_core.SCHEMA_GENERATION)
                suggestion = model_core.parse_schema_suggestion(repaired_text)
            if not suggestion:
                return _response_tuple(
                    chat_history,
                    message,
                    state,
                    "Nao consegui estruturar essa proposta de tabela. Diga o nome da tabela e algumas colunas.",
                    "Schema proposal was not valid JSON.",
                    status_kind="error",
                )
            state = state.with_pending_schema(suggestion)
            return _response_tuple(
                chat_history,
                message,
                state,
                _schema_suggestion_message(suggestion),
                f"Generated schema proposal in {elapsed}s.",
            )
        except Exception as exc:
            return _empty_generation_response(
                chat_history,
                message,
                state,
                f"Generation failed: {type(exc).__name__}: {exc}",
            )

    if intent_result.intent in {intent_core.SMALLTALK, intent_core.CLARIFICATION, intent_core.UNKNOWN}:
        ready, error = _model_ready(loaded_key)
        if not ready:
            return _response_tuple(chat_history, message, state, error, error, status_kind="error")
        try:
            prompt = model_core.build_chat_prompt(message, state, chat_history)
            generated_text, elapsed = _generate_model_text(prompt, model_core.CHAT_GENERATION)
            chat_text = model_core.clean_generation(generated_text)
            return _response_tuple(
                chat_history,
                message,
                state,
                chat_text,
                f"Generated chat response in {elapsed}s.",
            )
        except Exception as exc:
            return _empty_generation_response(
                chat_history,
                message,
                state,
                f"Generation failed: {type(exc).__name__}: {exc}",
            )

    # Default path: SQL generation against the active schema.
    ready, error = _model_ready(loaded_key)
    if not ready:
        return _empty_generation_response(
            chat_history,
            message,
            state,
            error if "inconsistent" in error else "Load a model before generating SQL.",
        )

    try:
        prompt = model_core.build_sql_prompt(state.active_schema, message, chat_history)
        generated_text, elapsed = _generate_model_text(prompt, model_core.SQL_GENERATION)
    except Exception as exc:
        return _empty_generation_response(
            chat_history,
            message,
            state,
            f"Generation failed: {type(exc).__name__}: {exc}",
        )

    sql_text, chat_text, validator = model_core.format_generation_result(generated_text)
    display_response = f"```sql\n{sql_text}\n```" if sql_text else chat_text
    response_kind = "SQL" if sql_text.strip() else "chat response"
    model_def = model_by_key(loaded_key)
    return _response_tuple(
        chat_history,
        message,
        state,
        display_response,
        f"Generated {response_kind} with {model_def['model_id']} in {elapsed}s.",
        sql_text=str(sql_text),
        validator=validator,
    )
1418
 
1419
 
1420
def is_sql_intent(message, schema):
    """Delegate to ``sql_core.is_sql_intent`` with the same arguments."""
    return sql_core.is_sql_intent(message, schema)
 
 
 
 
 
 
 
 
 
 
 
1422
 
1423
+
1424
def build_generation_prompt(schema, message, chat_history=None):
    """Delegate to ``model_core.build_sql_prompt`` (note the different name)."""
    return model_core.build_sql_prompt(schema, message, chat_history)
1426
+
1427
+
1428
def format_generation_result(text):
    """Delegate to ``model_core.format_generation_result``."""
    return model_core.format_generation_result(text)
1430
+
1431
+
1432
def validate_sql(sql_text):
    """Delegate to ``sql_core.validate_sql``."""
    return sql_core.validate_sql(sql_text)
1434
+
1435
+
1436
def create_table_from_message(message):
    """Delegate to ``sql_core.create_table_from_message``."""
    return sql_core.create_table_from_message(message)
1438
+
1439
+
1440
def create_table_from_schema(schema):
    """Delegate to ``sql_core.create_table_from_schema``."""
    return sql_core.create_table_from_schema(schema)
1442
+
1443
+
1444
def edit_create_table_from_message(message, chat_history, active_schema):
    """Delegate to ``sql_core.edit_create_table_from_message``."""
    return sql_core.edit_create_table_from_message(message, chat_history, active_schema)
1446
 
1447
 
1448
  def sync_on_load():
 
1458
  *query_control_updates(True),
1459
  "",
1460
  EMPTY_VALIDATOR,
1461
+ gr.update(value=None),
1462
  render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
 
1463
  )
1464
  return (
1465
  None,
 
1470
  *query_control_updates(False),
1471
  "",
1472
  EMPTY_VALIDATOR,
1473
+ gr.update(value=None),
1474
  render_message(),
 
1475
  )
1476
 
1477
 
 
2229
  }
2230
  """
2231
 
2232
+ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
2233
  loaded_key_state = gr.State(value=None)
 
2234
  active_schema = gr.State(value="")
2235
+ conversation_state = gr.State(value=chat_core.default_state())
2236
+ generation_meta_state = gr.State(value=None)
2237
  last_user_message = gr.State(value="")
2238
 
2239
  with gr.Column(elem_classes=["app-shell"]):
 
2246
  load_button = gr.Button("Load fine-tuned model", variant="primary", elem_id="load-button")
2247
  model_status = gr.HTML(render_status(DEFAULT_MODEL_KEY, None))
2248
  model_info = gr.HTML(model_metadata(DEFAULT_MODEL_KEY))
 
2249
 
2250
  with gr.Column(elem_id="query-section", elem_classes=["query-section"]):
2251
  gr.HTML(render_step("02", "Chat"))
 
2299
  interactive=False,
2300
  show_label=False,
2301
  )
 
 
 
 
 
 
2302
  error_output = gr.HTML(render_message())
2303
 
 
 
 
 
 
 
 
 
 
2304
  model_state_outputs = [
2305
  fine_tuned_model_card,
2306
  model_status,
 
2313
  clear_schema_button,
2314
  message_input,
2315
  send_button,
 
2316
  error_output,
2317
  ]
2318
 
 
2335
  send_button,
2336
  sql_output,
2337
  validator_output,
2338
+ generation_meta_state,
2339
  error_output,
 
2340
  ],
2341
  js=LOAD_SCROLL_JS,
2342
  )
2343
 
2344
+ schema_context_outputs = [active_schema, active_schema_pill, clear_schema_button, conversation_state]
2345
  employees_preset.click(set_preset, inputs=gr.State("employees"), outputs=schema_context_outputs)
2346
  orders_preset.click(set_preset, inputs=gr.State("orders"), outputs=schema_context_outputs)
2347
  students_preset.click(set_preset, inputs=gr.State("students"), outputs=schema_context_outputs)
 
2356
  last_user_message,
2357
  sql_output,
2358
  validator_output,
2359
+ generation_meta_state,
2360
  error_output,
2361
+ conversation_state,
 
 
 
 
2362
  ]
2363
  send_button.click(
2364
  generate_response,
2365
+ inputs=[message_input, chatbot, active_schema, loaded_key_state, generation_meta_state, conversation_state],
2366
  outputs=chat_generation_outputs,
2367
  )
2368
  message_input.submit(
2369
  generate_response,
2370
+ inputs=[message_input, chatbot, active_schema, loaded_key_state, generation_meta_state, conversation_state],
2371
  outputs=chat_generation_outputs,
2372
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2373
  demo.load(
2374
  sync_on_load,
2375
  outputs=[
 
2388
  send_button,
2389
  sql_output,
2390
  validator_output,
2391
+ generation_meta_state,
2392
  error_output,
 
2393
  ],
2394
  )
2395
 
chat_state.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
@dataclass(frozen=True)
class SchemaSuggestion:
    """Immutable table proposal: a table name, typed columns and a rationale."""

    table_name: str = ""
    # Ordered (column_name, column_type) pairs; types are stored upper-cased.
    columns: tuple[tuple[str, str], ...] = ()
    rationale: str = ""

    @staticmethod
    def _column_entries(raw_columns):
        """Return the entries of a raw ``columns`` payload as a list.

        A ``{name: type}`` mapping is expanded to ``(name, type)`` pairs
        (entries whose type is neither a string nor ``None`` are dropped);
        a payload that cannot be iterated at all yields no entries.
        """
        if isinstance(raw_columns, dict):
            return [
                (name, column_type)
                for name, column_type in raw_columns.items()
                if column_type is None or isinstance(column_type, str)
            ]
        try:
            return list(raw_columns)
        except TypeError:
            # e.g. a number where a list was expected in model-produced JSON.
            return []

    @classmethod
    def from_value(cls, value):
        """Coerce *value* into a ``SchemaSuggestion``, or return ``None``.

        Accepts an existing instance (returned as is) or a dict with
        ``table_name``, ``columns`` and optional ``rationale``. Each column
        may be a ``{"name": ..., "type": ...}`` dict or a sequence whose first
        two items are name and type; a missing type defaults to ``TEXT``.
        Malformed column entries are skipped. Returns ``None`` for any other
        input, or when the table name or the column list ends up empty.
        """
        if isinstance(value, cls):
            return value
        if not isinstance(value, dict):
            return None

        columns = []
        for column in cls._column_entries(value.get("columns") or ()):
            if isinstance(column, dict):
                name = str(column.get("name") or "").strip()
                column_type = str(column.get("type") or "TEXT").strip().upper()
            elif isinstance(column, (list, tuple)) and len(column) >= 2:
                name = str(column[0] or "").strip()
                column_type = str(column[1] or "TEXT").strip().upper()
            else:
                continue
            if name:
                # A whitespace-only type strips to "", so default it again.
                columns.append((name, column_type or "TEXT"))

        table_name = str(value.get("table_name") or "").strip()
        rationale = str(value.get("rationale") or "").strip()
        if not table_name or not columns:
            return None
        return cls(table_name=table_name, columns=tuple(columns), rationale=rationale)

    def to_dict(self):
        """Return a plain-dict form that ``from_value`` can read back."""
        return {
            "table_name": self.table_name,
            "columns": [{"name": name, "type": column_type} for name, column_type in self.columns],
            "rationale": self.rationale,
        }
41
+
42
+
43
@dataclass(frozen=True)
class ConversationState:
    """Immutable per-conversation routing context.

    Every ``with_*`` / ``clear_*`` method returns a new instance; the
    ``debug`` dict is copied on each derivation so instances never share it.
    """

    active_schema: str = ""
    pending_schema_suggestion: SchemaSuggestion | None = None
    last_intent: str | None = None
    last_table_topic: str | None = None
    debug: dict = field(default_factory=dict)

    @classmethod
    def from_value(cls, value=None, *, active_schema=""):
        """Coerce *value* (instance, dict or anything else) into a state.

        * instance: returned as is, unless a non-empty *active_schema* that
          differs from the stored one is given, in which case it replaces it.
        * dict: rebuilt field by field; the dict's own ``active_schema`` is
          used when non-empty, otherwise *active_schema*.
        * anything else: a fresh state holding only *active_schema*.

        NOTE(review): the instance path lets the argument win while the dict
        path lets the stored value win - confirm this asymmetry is intended.
        """
        if isinstance(value, cls):
            if active_schema and active_schema != value.active_schema:
                return value.with_active_schema(active_schema)
            return value
        if not isinstance(value, dict):
            return cls(active_schema=(active_schema or "").strip())
        pending = SchemaSuggestion.from_value(value.get("pending_schema_suggestion"))
        state_active_schema = (value.get("active_schema") or active_schema or "").strip()
        return cls(
            active_schema=state_active_schema,
            pending_schema_suggestion=pending,
            last_intent=value.get("last_intent"),
            last_table_topic=value.get("last_table_topic"),
            debug=dict(value.get("debug") or {}),
        )

    def to_dict(self):
        """Return a plain-dict form that ``from_value`` can read back."""
        return {
            "active_schema": self.active_schema,
            "pending_schema_suggestion": (
                self.pending_schema_suggestion.to_dict() if self.pending_schema_suggestion else None
            ),
            "last_intent": self.last_intent,
            "last_table_topic": self.last_table_topic,
            "debug": dict(self.debug or {}),
        }

    def _evolve(self, **changes):
        """Return a copy with *changes* applied and every other field kept.

        Centralising the copy means a field added to the dataclass later is
        carried over automatically instead of being reset by each method.
        """
        # Local import: the module header only imports ``dataclass``/``field``.
        from dataclasses import replace

        changes.setdefault("debug", dict(self.debug or {}))
        return replace(self, **changes)

    def with_active_schema(self, schema):
        """Return a copy whose active schema is *schema*, stripped."""
        return self._evolve(active_schema=(schema or "").strip())

    def with_pending_schema(self, suggestion):
        """Return a copy holding *suggestion* as the pending proposal.

        *suggestion* is coerced with ``SchemaSuggestion.from_value``; when the
        coercion fails the pending slot becomes ``None`` and the previous
        table topic is kept, otherwise the topic becomes the proposal's table.
        """
        suggestion = SchemaSuggestion.from_value(suggestion)
        return self._evolve(
            pending_schema_suggestion=suggestion,
            last_table_topic=(suggestion.table_name if suggestion else self.last_table_topic),
        )

    def clear_pending_schema(self):
        """Return a copy without a pending proposal."""
        return self._evolve(pending_schema_suggestion=None)

    def with_intent(self, intent_result):
        """Return a copy recording *intent_result* in ``last_intent``/``debug``.

        Attributes are read with ``getattr`` defaults, so an object missing
        ``intent``, ``confidence`` or ``reason`` records ``None`` for it.
        """
        debug = dict(self.debug or {})
        debug["intent"] = getattr(intent_result, "intent", None)
        debug["confidence"] = getattr(intent_result, "confidence", None)
        debug["reason"] = getattr(intent_result, "reason", None)
        return self._evolve(
            last_intent=getattr(intent_result, "intent", None),
            debug=debug,
        )
120
+
121
+
122
def default_state(active_schema=""):
    """Return a fresh conversation state, as a plain dict, for *active_schema*."""
    schema = (active_schema or "").strip()
    fresh_state = ConversationState(active_schema=schema)
    return fresh_state.to_dict()
124
+
intent.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from chat_state import ConversationState
4
+ import sql_tools
5
+
6
+
7
# Intent labels. classify_intent returns exactly one of these in
# IntentResult.intent.
SMALLTALK = "smalltalk"
SCHEMA_SUGGESTION = "schema_suggestion"
CREATE_TABLE = "create_table"
CREATE_TABLE_CONFIRM = "create_table_confirm"
EDIT_TABLE = "edit_table"
SQL_QUERY = "sql_query"
CLARIFICATION = "clarification"
UNKNOWN = "unknown"
15
+
16
+
17
@dataclass(frozen=True)
class IntentResult:
    """Immutable outcome of classifying one user message."""

    intent: str  # one of the intent label constants defined in this module
    confidence: float  # heuristic score; classify_intent uses values from 0.0 to 0.95
    reason: str  # short tag naming the rule that produced the intent
22
+
23
+
24
+ def _has_pending_schema(state):
25
+ return bool(getattr(state, "pending_schema_suggestion", None))
26
+
27
+
28
+ def _has_active_schema(state):
29
+ return bool((getattr(state, "active_schema", "") or "").strip())
30
+
31
+
32
def _is_confirmation(message):
    """Return True when the normalized message reads as a go-ahead.

    Matches either one of the exact confirmation phrases or a message that
    starts with one of the "generate it" prefixes.
    """
    text = sql_tools.normalize_text(message)
    if text.startswith(("gera ", "pode gerar", "faz ")):
        return True
    exact_confirmations = {
        "sim", "yes", "ok", "claro", "pode", "pode gerar", "gera", "gerar",
        "gere", "faz", "faca", "cria", "crie", "manda", "confirmo", "isso",
        "isso mesmo", "perfeito",
    }
    return text in exact_confirmations
40
+
41
+
42
def _is_smalltalk(message):
    """Return True when the normalized message is casual chat.

    A message qualifies when it equals one of the known phrases or contains
    one of the small-talk fragments anywhere in it.
    """
    text = sql_tools.normalize_text(message)
    known_phrases = {
        "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
        "obrigado", "obrigada", "valeu", "thanks", "thank you",
        "como voce esta", "como voce esta hoje", "qual seu nome",
        "me conte uma piada", "conte uma piada", "vamos conversar",
        "o que voce faz", "como voce funciona", "como funciona",
    }
    fragments = (
        "como voce esta",
        "qual seu nome",
        "conte uma piada",
        "vamos conversar",
        "obrigado",
    )
    if text in known_phrases:
        return True
    for fragment in fragments:
        if fragment in text:
            return True
    return False
61
+
62
+
63
def _is_schema_suggestion(message):
    """Return True when the message asks for a table idea, not a concrete table.

    Two rules, in order:

    1. The normalized message contains one of the request phrases and
       ``sql_tools.is_create_table_intent`` does not already claim it.
    2. The message mentions "tabela" together with one of the connector
       *words* "sobre" / "para" / "de" and none of the explicit create verbs.

    Rule 2 compares whole words: substring tests would count "onde" or
    "vendedores" as "de" and "gerenciar" as "gere".
    """
    # Local import: this module's header does not import ``re``.
    import re

    normalized = sql_tools.normalize_text(message)
    patterns = (
        "preciso de uma tabela",
        "preciso de um schema",
        "quero uma tabela",
        "quero um schema",
        "sugira uma tabela",
        "sugerir uma tabela",
        "tabela sobre",
        "tabela de",
        "schema sobre",
        "schema de",
        "modelo de tabela",
        "modelar",
    )
    if any(pattern in normalized for pattern in patterns) and not sql_tools.is_create_table_intent(message):
        return True

    # "tabela" stays a substring test so the plural "tabelas" still counts.
    if "tabela" not in normalized:
        return False
    words = set(re.findall(r"\w+", normalized))
    if words.isdisjoint({"sobre", "para", "de"}):
        return False
    return words.isdisjoint({"crie", "criar", "create", "generate", "gerar", "gere"})
84
+
85
+
86
def classify_intent(message, state=None, chat_history=None):
    """Classify *message* into one intent label with a heuristic confidence.

    Rules are tried in a fixed order and the first match wins: empty message,
    confirmation of a pending proposal, small talk, table edit, explicit
    CREATE TABLE, schema suggestion, SQL query, clarification (when a
    proposal is pending), and finally unknown.
    """
    state = ConversationState.from_value(state)
    if not sql_tools.normalize_text(message):
        return IntentResult(UNKNOWN, 0.0, "empty_message")

    pending = _has_pending_schema(state)
    if pending and _is_confirmation(message):
        return IntentResult(CREATE_TABLE_CONFIRM, 0.95, "confirmation_with_pending_schema")

    if _is_smalltalk(message):
        return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")

    edited_table = sql_tools.edit_create_table_from_message(message, chat_history, state.active_schema)
    if edited_table or sql_tools.is_table_edit_intent(message):
        # Confidence is higher when there is something concrete to edit.
        has_target = edited_table or _has_active_schema(state) or pending
        confidence = 0.9 if has_target else 0.7
        return IntentResult(EDIT_TABLE, confidence, "table_edit_terms")

    # Remaining rules as (lazy predicate, result) pairs, checked in order.
    remaining_rules = (
        (
            lambda: sql_tools.is_create_table_intent(message),
            IntentResult(CREATE_TABLE, 0.9, "explicit_create_table"),
        ),
        (
            lambda: _is_schema_suggestion(message),
            IntentResult(SCHEMA_SUGGESTION, 0.86, "schema_suggestion_phrase"),
        ),
        (
            lambda: sql_tools.is_sql_intent(message, state.active_schema),
            IntentResult(SQL_QUERY, 0.86, "sql_query_terms"),
        ),
        (
            lambda: pending,
            IntentResult(CLARIFICATION, 0.55, "pending_schema_context"),
        ),
    )
    for matches, result in remaining_rules:
        if matches():
            return result
    return IntentResult(UNKNOWN, 0.25, "no_intent_match")
115
+
model_io.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+ from chat_state import ConversationState, SchemaSuggestion
5
+ import sql_tools
6
+
7
+
8
# Generation kinds understood by ``generation_budget``.
CHAT_GENERATION = "chat"
SCHEMA_GENERATION = "schema"
SQL_GENERATION = "sql"

# Budget value looked up per generation kind.
GENERATION_BUDGETS = {
    CHAT_GENERATION: 120,
    SCHEMA_GENERATION: 180,
    SQL_GENERATION: 96,
}

# Free-form chat prompt; placeholders: state_summary, history_context, message.
CHAT_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "You are a conversational SQL assistant. Reply naturally in Brazilian Portuguese unless the user writes in English.\n"
    "You can chat normally, discuss table ideas, and help generate SQL, but do not generate SQL unless the user asks for it.\n"
    "Current state:\n{state_summary}\n\n"
    "{history_context}"
    "User message: {message}<|end|>\n"
    "<|assistant|>"
)

# Schema proposal prompt; placeholders: history_context, message.
# Doubled braces are literal braces once ``str.format`` has run.
SCHEMA_SUGGESTION_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "Create a practical SQL table proposal for the user's domain request.\n"
    "Return JSON only with this shape: "
    '{{"table_name":"name","columns":[{{"name":"id","type":"INTEGER"}}],"rationale":"short reason"}}.\n'
    "Use simple SQL types: INTEGER, TEXT, NUMERIC, DATE, BOOLEAN.\n"
    "{history_context}"
    "Request: {message}<|end|>\n"
    "<|assistant|>"
)

# SQL generation prompt; placeholders: schema, history_context, question.
SQL_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "Given the following SQL table, write one SQL query. Output SQL only.\n\n"
    "Table: {schema}\n\n"
    "{history_context}"
    "Question: {question}<|end|>\n"
    "<|assistant|>"
)
47
+
48
+
49
def _history_context(chat_history, max_exchanges=3):
    """Render the tail of the chat history as a prompt fragment.

    Args:
        chat_history: Iterable of message dicts with ``role`` and ``content``
            keys (or ``None``). Non-dict items are ignored.
        max_exchanges: Number of most recent exchanges to keep; each exchange
            is counted as two messages. Zero or a negative value yields no
            context.

    Returns:
        ``"Previous conversation:\\n<Role>: <text>...\\n\\n"`` or ``""`` when
        there is nothing to show.
    """
    # Guard explicitly: ``seq[-0:]`` is the *whole* sequence, so a limit of
    # zero would otherwise include the entire history instead of none of it.
    if max_exchanges <= 0:
        return ""
    recent = list(chat_history or [])[-max_exchanges * 2 :]
    lines = []
    for item in recent:
        if not isinstance(item, dict):
            continue
        # A missing, None or empty role falls back to "user" instead of
        # raising on ``None.title()``.
        role = str(item.get("role") or "user")
        content = sql_tools.content_to_text(item.get("content", "")).strip()
        if content:
            lines.append(f"{role.title()}: {content}")
    if not lines:
        return ""
    return "Previous conversation:\n" + "\n".join(lines) + "\n\n"
64
+
65
+
66
def _state_summary(state):
    """Return a two-line bullet summary of the conversation state.

    Reports whether an active schema is present and the table name of the
    pending schema suggestion (``none`` when there is no suggestion).
    """
    state = ConversationState.from_value(state)
    if state.pending_schema_suggestion:
        pending = state.pending_schema_suggestion.table_name
    else:
        pending = "none"
    active = "none" if not state.active_schema else "present"
    return f"- active_schema: {active}\n- pending_schema_suggestion: {pending}"
71
+
72
+
73
def build_chat_prompt(message, state=None, chat_history=None):
    """Fill ``CHAT_PROMPT_TEMPLATE`` with the message, state summary and history."""
    user_message = (message or "").strip()
    summary = _state_summary(state)
    history = _history_context(chat_history)
    return CHAT_PROMPT_TEMPLATE.format(
        message=user_message,
        state_summary=summary,
        history_context=history,
    )
79
+
80
+
81
def build_schema_suggestion_prompt(message, state=None, chat_history=None):
    """Fill ``SCHEMA_SUGGESTION_PROMPT_TEMPLATE`` with the request and history.

    ``state`` is accepted for signature parity with ``build_chat_prompt`` but
    is not used by this prompt.
    """
    request = (message or "").strip()
    history = _history_context(chat_history)
    return SCHEMA_SUGGESTION_PROMPT_TEMPLATE.format(message=request, history_context=history)
86
+
87
+
88
def build_sql_prompt(schema, message, chat_history=None):
    """Fill ``SQL_PROMPT_TEMPLATE`` with the table schema, question and history.

    A blank or missing schema is replaced by a placeholder table definition.
    """
    table_schema = (schema or "").strip()
    if not table_schema:
        table_schema = "CREATE TABLE unknown (id INTEGER)"
    question = (message or "").strip()
    history = _history_context(chat_history)
    return SQL_PROMPT_TEMPLATE.format(
        schema=table_schema,
        question=question,
        history_context=history,
    )
95
+
96
+
97
def build_generation_prompt(schema, message, chat_history=None):
    """Alias of ``build_sql_prompt`` with the same arguments and result."""
    prompt = build_sql_prompt(schema, message, chat_history)
    return prompt
99
+
100
+
101
def generation_budget(kind):
    """Return the budget for ``kind``, defaulting to the SQL generation budget."""
    fallback = GENERATION_BUDGETS[SQL_GENERATION]
    return GENERATION_BUDGETS.get(kind, fallback)
103
+
104
+
105
def clean_generation(text):
    """Delegate to ``sql_tools.clean_generation``."""
    cleaned = sql_tools.clean_generation(text)
    return cleaned
107
+
108
+
109
def extract_sql_candidate(text):
    """Delegate to ``sql_tools.extract_sql_candidate``."""
    candidate = sql_tools.extract_sql_candidate(text)
    return candidate
111
+
112
+
113
def is_sql_like(text):
    """Delegate to ``sql_tools.is_sql_like``."""
    verdict = sql_tools.is_sql_like(text)
    return verdict
115
+
116
+
117
def format_generation_result(text):
    """Split raw model output into ``(sql, chat_text, validator_html)``.

    When the extracted candidate looks like SQL, it is returned as the SQL
    part together with its validation badge; otherwise it is returned as chat
    text with a fixed "Chat response" badge.
    """
    candidate = extract_sql_candidate(text)
    if not is_sql_like(candidate):
        return "", str(candidate), '<span class="validator-badge validator-empty">Chat response</span>'
    return str(candidate), "", sql_tools.validate_sql(candidate)
122
+
123
+
124
def parse_schema_suggestion(text):
    """Parse a model response into a ``SchemaSuggestion``.

    The response is cleaned, then the widest ``{...}`` span (or the whole
    text when there are no braces) is decoded as JSON. If that fails, every
    ``{`` position is tried with an incremental decoder so that a valid JSON
    object followed by prose containing braces is still recovered.

    Args:
        text: Raw model output.

    Returns:
        ``SchemaSuggestion.from_value(payload)`` for the decoded payload, or
        ``None`` when no JSON payload can be decoded.
    """
    cleaned = clean_generation(text)
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    raw_json = match.group(0) if match else cleaned
    try:
        payload = json.loads(raw_json)
    except json.JSONDecodeError:
        # The greedy span can swallow trailing text such as
        # '{"table_name": ...} (use {curly} braces)'. Fall back to decoding
        # one object at a time and stop at the first plausible suggestion.
        decoder = json.JSONDecoder()
        payload = None
        for brace in re.finditer(r"\{", cleaned):
            try:
                candidate, _end = decoder.raw_decode(cleaned, brace.start())
            except json.JSONDecodeError:
                continue
            # Require the key the prompt asks for, so a nested column object
            # is never mistaken for the suggestion itself.
            if isinstance(candidate, dict) and "table_name" in candidate:
                payload = candidate
                break
        if payload is None:
            return None
    return SchemaSuggestion.from_value(payload)
scripts/model_probe.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ ROOT = Path(__file__).resolve().parents[1]
6
+ if str(ROOT) not in sys.path:
7
+ sys.path.insert(0, str(ROOT))
8
+
9
+ import app # noqa: E402
10
+
11
+
12
+ def _assistant_text(result):
13
+ history = result[0] or []
14
+ return history[-1]["content"] if history else ""
15
+
16
+
17
def _scenario(name, message, history, active_schema, state):
    """Run one chat turn through ``app.generate_response`` and record it.

    Returns a dict with the scenario name, the message, the assistant reply
    and selected positional fields of the response tuple (indices 4, 7, 2, 8
    and 0 map to ``sql``, ``status``, ``active_schema``, ``state`` and
    ``history``).
    """
    result = app.generate_response(message, history, active_schema, app.FINE_TUNED_MODEL_KEY, None, state)
    record = {"name": name, "message": message}
    record["assistant"] = _assistant_text(result)
    record["sql"] = result[4]
    record["status"] = result[7]
    record["active_schema"] = result[2]
    record["state"] = result[8]
    record["history"] = result[0]
    return record
36
+
37
+
38
+ def _contains_any(text, needles):
39
+ text = (text or "").lower()
40
+ return any(needle.lower() in text for needle in needles)
41
+
42
+
43
def _grade(records):
    """Evaluate the recorded scenarios and return a list of check dicts.

    Each check has ``name``, ``pass`` and ``detail`` keys. Records are looked
    up by scenario name; a missing scenario raises ``KeyError``.
    """
    by_name = {record["name"]: record for record in records}

    def chat_only(scenario):
        # Passes when there is assistant text and no SQL was produced.
        record = by_name[scenario]
        return bool(record["assistant"]) and not record["sql"]

    def sql_upper(scenario):
        return (by_name[scenario]["sql"] or "").upper()

    return [
        {
            "name": "smalltalk_is_conversational",
            "pass": chat_only("greeting"),
            "detail": "Greeting should produce chat text and no SQL.",
        },
        {
            "name": "schema_suggestion_sets_pending",
            "pass": bool((by_name["schema_request"]["state"] or {}).get("pending_schema_suggestion")),
            "detail": "Domain table request should create a pending schema proposal.",
        },
        {
            "name": "confirmation_generates_create_table",
            "pass": "CREATE TABLE" in sql_upper("confirm_generate"),
            "detail": "Confirmation should generate CREATE TABLE SQL.",
        },
        {
            "name": "edit_updates_schema",
            "pass": _contains_any(by_name["edit_schema"]["sql"], ["numero_animais", "num_animais"]),
            "detail": "Edit should replace capacidade with an animal-count column.",
        },
        {
            "name": "query_generates_select",
            "pass": "SELECT" in sql_upper("query_schema"),
            "detail": "Natural query should generate SELECT SQL.",
        },
        {
            "name": "smalltalk_with_schema_stays_chat",
            "pass": chat_only("smalltalk_with_schema"),
            "detail": "Smalltalk with active schema should not become SQL.",
        },
    ]
78
+
79
+
80
def main():
    """Load the fine-tuned model, replay a scripted chat, and print a JSON report.

    Each turn feeds the previous turn's history, active schema and state into
    the next one. Returns 0 when every check passes and 1 otherwise.
    """
    app.load_model(app.FINE_TUNED_MODEL_ID)

    script = [
        ("greeting", "oi"),
        ("schema_request", "preciso de uma tabela sobre zoologico"),
        ("confirm_generate", "gera"),
        ("edit_schema", "troca capacidade por numero_animais"),
        ("query_schema", "liste zoologicos de Sao Paulo"),
        ("smalltalk_with_schema", "como voce esta hoje?"),
    ]

    history, active_schema = [], ""
    state = app.chat_core.default_state()
    records = []
    for name, message in script:
        turn = _scenario(name, message, history, active_schema, state)
        # The full history is carried forward but left out of the report.
        records.append({key: value for key, value in turn.items() if key != "history"})
        history, active_schema, state = turn["history"], turn["active_schema"], turn["state"]

    checks = _grade(records)
    passed = all(check["pass"] for check in checks)
    report = {
        "model": app.FINE_TUNED_MODEL_ID,
        "passed": passed,
        "checks": checks,
        "records": records,
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))
    return 0 if report["passed"] else 1
111
+
112
+
113
+ if __name__ == "__main__":
114
+ raise SystemExit(main())
115
+
sql_tools.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ import re
3
+ import unicodedata
4
+
5
+ import sqlparse
6
+
7
+
8
+ SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}
9
+
10
+
11
def content_to_text(value):
    """Flatten a chat content value into plain text.

    ``None`` becomes ``""`` and strings pass through. A dict yields the
    flattened value of its first present key among ``text``, ``content`` and
    ``value``, or else all of its values joined by spaces. Lists and tuples
    are flattened item by item and joined by newlines. Anything else is
    converted with ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return "\n".join(map(content_to_text, value))
    if not isinstance(value, dict):
        return str(value)
    for preferred in ("text", "content", "value"):
        if preferred in value:
            return content_to_text(value[preferred])
    return " ".join(map(content_to_text, value.values()))
24
+
25
+
26
def normalize_text(value):
    """Lowercase, strip accents and collapse whitespace in ``value``.

    The value is first flattened with ``content_to_text``. Accents are
    removed by NFKD decomposition followed by dropping combining marks.
    """
    lowered = content_to_text(value).lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    without_marks = "".join(filter(lambda char: not unicodedata.combining(char), decomposed))
    return re.sub(r"\s+", " ", without_marks).strip()
31
+
32
+
33
def clean_generation(text):
    """Strip code fences, chat markers and an ``SQL:`` prefix from model output.

    When the text starts with a fence, an opening line that is exactly
    three backticks or a ``sql`` fence is dropped, as is a closing fence on the last
    line. The text is then cut at the first occurrence of each chat marker,
    and a leading ``SQL:`` label (any case) is removed.
    """
    result = content_to_text(text).strip()
    if result.startswith("```"):
        body = result.splitlines()
        if body and body[0].strip().lower() in {"```", "```sql"}:
            body = body[1:]
        if body and body[-1].strip() == "```":
            body = body[:-1]
        result = "\n".join(body).strip()
    for marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
        # ``result`` is already stripped here, so stripping again when the
        # marker is absent changes nothing.
        result = result.partition(marker)[0].strip()
    if result.upper().startswith("SQL:"):
        result = result[4:].strip()
    return result
48
+
49
+
50
def extract_sql_candidate(text):
    """Return the cleaned text starting at the first SQL statement keyword.

    When no keyword is found, the cleaned text is returned unchanged.
    """
    cleaned = clean_generation(text)
    keyword = re.search(
        r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b",
        cleaned,
        flags=re.IGNORECASE,
    )
    if keyword is None:
        return cleaned
    return cleaned[keyword.start() :].strip()
56
+
57
+
58
def is_sql_like(text):
    """Return True when the first alphabetic word of ``text`` is in ``SQL_STARTERS``."""
    candidate = (text or "").strip()
    leading_word = re.match(r"^\s*([A-Za-z]+)", candidate) if candidate else None
    if leading_word is None:
        return False
    return leading_word.group(1).upper() in SQL_STARTERS
66
+
67
+
68
def is_sql_intent(message, schema=""):
    """Decide whether a message is asking for a SQL query.

    Known smalltalk phrases are rejected first. A message containing any of
    the query vocabulary terms as a whole word is accepted. Otherwise the
    message is accepted only when a schema is given and the message itself
    starts like SQL.
    """
    text = normalize_text(message)
    if not text:
        return False

    smalltalk_patterns = {
        "oi", "ola", "olá", "hi", "hello", "hey", "obrigado", "obrigada", "thanks",
        "thank you", "como voce esta", "como você esta", "qual seu nome", "me conte uma piada",
        "vamos conversar", "como voce funciona", "como funciona", "o que voce faz", "o que faz",
    }
    if any(text == normalize_text(phrase) for phrase in smalltalk_patterns):
        return False
    for fragment in ("como voce esta", "qual seu nome", "conte uma piada"):
        if fragment in text:
            return False

    sql_terms = {
        "all", "average", "count", "columns", "database", "find", "get", "group by",
        "join", "list", "order by", "query", "rows", "select", "show", "sum", "where",
        "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar",
        "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene",
        "selecione", "some", "soma", "quantos", "filtre", "filtrar",
    }
    for term in sql_terms:
        # Lookarounds enforce whole-word matches, including multi-word terms.
        if re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", text):
            return True
    return bool(schema and is_sql_like(text))
91
+
92
+
93
def validate_sql(sql_text):
    """Return an HTML badge describing a light syntax check of ``sql_text``.

    The text is parsed with ``sqlparse``. The result is "No SQL yet" for
    blank input, "Check syntax" (with a detail span) when parsing raises,
    yields no statement, or the first token is not in ``SQL_STARTERS``, and
    "Valid SQL" otherwise. Dynamic detail text is HTML-escaped.
    """
    warn_badge = '<span class="validator-badge validator-warn">Check syntax</span>'

    query = (sql_text or "").strip()
    if not query:
        return '<span class="validator-badge validator-empty">No SQL yet</span>'

    try:
        statements = [stmt for stmt in sqlparse.parse(query) if str(stmt).strip()]
    except Exception as exc:
        error_type = html.escape(type(exc).__name__)
        return warn_badge + f'<span class="validator-detail">sqlparse error: {error_type}</span>'

    if not statements:
        return warn_badge + '<span class="validator-detail">No parsed SQL statement.</span>'

    leading = statements[0].token_first(skip_cm=True)
    keyword = "UNKNOWN" if leading is None else leading.value.strip().upper()
    if keyword in SQL_STARTERS:
        return '<span class="validator-badge validator-ok">Valid SQL</span>'
    return warn_badge + f'<span class="validator-detail">First token: {html.escape(keyword)}</span>'
119
+
120
+
121
def is_create_table_intent(message):
    """Return True when the message has both a creation verb and a table noun.

    Matching is done on the lowercased message with whole-word patterns in
    English and Portuguese.
    """
    text = (message or "").strip().lower()
    verb_pattern = r"\b(create|make|build|generate|criar|crie|cria|criando|gerar|gere|gera|gerando|faz|faça|fazendo|monta|montar|monte)\b"
    noun_pattern = r"\b(table|schema|tabela)\b"
    if re.search(verb_pattern, text) is None:
        return False
    return re.search(noun_pattern, text) is not None
127
+
128
+
129
def is_rename_intent(message):
    """Return True for messages shaped like "<rename verb> X <connector> Y".

    X and Y are single words; verbs and connectors cover English and
    Portuguese and are matched case-insensitively.
    """
    text = (message or "").strip().lower()
    rename_pattern = r"\b(rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)\s+\w+\s+(to|para|as|como|por)\s+\w+"
    return re.search(rename_pattern, text, flags=re.IGNORECASE) is not None
138
+
139
+
140
def is_table_edit_intent(message):
    """Return True when the message asks to add, remove, rename or edit columns.

    An add verb counts unless the word right after it is an aggregation term
    (so "add up ..." or "add total ..." is not treated as an edit). Remove
    verbs, rename phrasing, "altere/mude ... ter", and a generic edit verb
    combined with a column noun, a colon or the word "por" also count.
    """
    text = (message or "").strip().lower()
    edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|coloque|colocar)\b"
    direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar)\b"
    direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
    target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
    sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}

    add_match = re.search(direct_add_terms, text)
    if add_match:
        words_after = text[add_match.end() :].split()
        next_word = words_after[0] if words_after else ""
        if next_word not in sql_aggregation_terms:
            return True

    if re.search(direct_remove_terms, text):
        return True
    if is_rename_intent(text):
        return True
    if re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", text):
        return True
    if not re.search(edit_terms, text):
        return False
    # A generic edit verb needs an explicit target hint to count.
    return bool(re.search(target_terms, text) or ":" in text or re.search(r"\bpor\b", text))
161
+
162
+
163
def infer_column_type(column_name):
    """Guess a SQL type from a column name.

    Returns ``INTEGER`` for ids and count-like names, ``NUMERIC`` for
    measure-like names, ``DATE`` for date-like names, and ``TEXT`` otherwise.
    The name is compared after stripping and lowercasing.
    """
    key = column_name.strip().lower()
    integer_names = {"quantity", "quantidade", "stock", "estoque", "year"}
    numeric_names = {
        "salary", "price", "preco", "amount", "total", "grade", "peso", "weight",
        "idade", "age", "altura", "height", "largura", "width", "comprimento",
        "length", "desconto", "discount",
    }
    date_names = {"date", "created_at", "updated_at"}

    if key == "id" or key.endswith("_id") or key in integer_names:
        return "INTEGER"
    if key in numeric_names:
        return "NUMERIC"
    if key in date_names or key.endswith("_date"):
        return "DATE"
    return "TEXT"
176
+
177
+
178
def normalize_identifier(value):
    """Turn free text into a SQL-friendly identifier.

    The text is normalized with ``normalize_text``, runs of non-word
    characters become single underscores, and edge underscores are trimmed.
    Identifiers starting with a digit get a ``col_`` prefix. Returns ``""``
    when nothing usable remains.
    """
    slug = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
    if slug and slug[0].isdigit():
        return f"col_{slug}"
    return slug
185
+
186
+
187
def parse_column_definition(raw_column):
    """Parse one column description into ``(name, type)``.

    Filler words are removed and edge punctuation trimmed. If a SQL type word
    appears, the last one is the explicit type and the text before it is the
    name; type aliases are mapped to canonical names. When nothing precedes
    the type word, the whole text is treated as the name and the type is
    inferred from it instead.

    Returns:
        ``(column_name, column_type)`` or ``None`` when no name can be derived.
    """
    text = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
    text = text.strip(" .;:")
    if not text:
        return None

    type_hits = list(
        re.finditer(
            r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b",
            text,
            flags=re.IGNORECASE,
        )
    )
    aliases = {
        "INT": "INTEGER",
        "BOOL": "BOOLEAN",
        "DECIMAL": "NUMERIC",
        "FLOAT": "REAL",
        "DOUBLE": "REAL",
    }

    name_part, column_type = text, None
    if type_hits:
        last_hit = type_hits[-1]
        prefix = text[: last_hit.start()].strip()
        # A bare type word such as "date" is a column *name*, not a type.
        if prefix:
            declared = last_hit.group(1).upper()
            name_part, column_type = prefix, aliases.get(declared, declared)

    name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
    column_name = normalize_identifier(name_part)
    if not column_name:
        return None
    return column_name, column_type or infer_column_type(column_name)
222
+
223
+
224
def split_column_list(columns_text):
    """Split a free-form column list into one string per column.

    " and " / " e " are treated like commas. Within each comma-separated
    part, stopwords are ignored when counting tokens. A part with a SQL type
    word and more than two tokens is walked token by token, pairing each name
    with a following type word. A part with a type word and at most two
    tokens is kept whole. A part of several plain identifiers is split into
    one column per word. Anything else is kept whole.
    """
    text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
    type_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
    type_tokens = {
        "integer", "int", "numeric", "decimal", "real", "float", "double",
        "text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool",
    }
    stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"}
    inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
    temporal_types = {"date", "datetime", "timestamp"}

    parts = []
    for chunk in columns_text_chunks(text) if False else (piece.strip() for piece in text.split(",") if piece.strip()):
        tokens = [word for word in chunk.split() if word.lower() not in stopwords]
        if not tokens:
            continue
        mentions_type = re.search(type_pattern, chunk, flags=re.IGNORECASE) is not None

        if mentions_type and len(tokens) > 2:
            position = 0
            while position < len(tokens):
                current = tokens[position]
                has_next = position + 1 < len(tokens)
                following = tokens[position + 1].lower() if has_next else ""
                # "total date" and similar stay separate names: both words
                # are plausible column names rather than a name/type pair.
                ambiguous = current.lower() in inferrable_names and following in temporal_types
                if following in type_tokens and not ambiguous:
                    parts.append(f"{current} {tokens[position + 1]}")
                    position += 2
                else:
                    parts.append(current)
                    position += 1
        elif mentions_type:
            parts.append(chunk)
        elif len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", token) for token in tokens):
            parts.extend(tokens)
        else:
            parts.append(chunk)
    return parts
261
+
262
+
263
def format_create_table(table_name, columns):
    """Render a ``CREATE TABLE`` statement with one column per line.

    Args:
        table_name: Table identifier.
        columns: Sequence of ``(column_name, column_type)`` pairs. When a
            name repeats, only its first occurrence is kept.

    Returns:
        The statement text, or ``""`` when the name or the columns are empty.
    """
    if not table_name or not columns:
        return ""
    first_seen = {}
    for column_name, column_type in columns:
        # setdefault keeps the first definition and the original order.
        first_seen.setdefault(column_name, column_type)
    if not first_seen:
        return ""
    body = ",\n".join(f" {name} {kind}" for name, kind in first_seen.items())
    return f"CREATE TABLE {table_name} (\n" + body + "\n);"
276
+
277
+
278
def create_table_from_message(message):
    """Build a ``CREATE TABLE`` statement from "table NAME with COLUMNS" phrasing.

    Two patterns are tried in order; the first one that matches decides the
    result, even if it yields an empty statement. Returns ``""`` when neither
    pattern matches.
    """
    text = (message or "").strip()
    patterns = (
        r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
        r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        raw_name, raw_columns = found.group(1), found.group(2)
        table_name = normalize_identifier(raw_name)
        columns = []
        for piece in split_column_list(raw_columns):
            definition = parse_column_definition(piece)
            if definition:
                columns.append(definition)
        return format_create_table(table_name, columns)
    return ""
296
+
297
+
298
def parse_create_table_schema(schema):
    """Parse ``[CREATE TABLE] name (columns)`` into ``(table_name, columns)``.

    The ``CREATE TABLE`` prefix and a trailing semicolon are optional.
    Returns ``("", [])`` when the text does not have that shape.
    """
    found = re.match(
        r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
        (schema or "").strip(),
        flags=re.IGNORECASE | re.DOTALL,
    )
    if found is None:
        return "", []
    table_name = normalize_identifier(found.group(1))
    definitions = map(parse_column_definition, split_column_list(found.group(2)))
    columns = [definition for definition in definitions if definition]
    return table_name, columns
314
+
315
+
316
def create_table_from_schema(schema):
    """Re-render a schema string as a formatted ``CREATE TABLE`` statement.

    Returns ``""`` when the schema cannot be parsed.
    """
    parsed = parse_create_table_schema(schema)
    return format_create_table(*parsed)
319
+
320
+
321
def extract_create_table_statement(text):
    """Return the first ``CREATE TABLE name (...)`` statement found in ``text``.

    The text is first reduced to its SQL candidate. Returns ``""`` when no
    statement is found.
    """
    candidate = extract_sql_candidate(text)
    statement = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
        candidate,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if statement is None:
        return ""
    return clean_generation(statement.group(0))
329
+
330
+
331
def last_create_table_from_history(chat_history):
    """Return the most recent ``CREATE TABLE`` statement from assistant messages.

    The history is scanned from newest to oldest; non-dict items and
    non-assistant messages are ignored. Returns ``""`` when none is found.
    """
    messages = list(chat_history or [])
    for entry in messages[::-1]:
        is_assistant = isinstance(entry, dict) and entry.get("role") == "assistant"
        if not is_assistant:
            continue
        statement = extract_create_table_statement(entry.get("content", ""))
        if statement:
            return statement
    return ""
339
+
340
+
341
def extract_added_columns(message):
    """Extract the columns a message asks to add as ``(name, type)`` pairs.

    Text after a colon is tried first, then text after an add-style verb
    (with optional article, "new" and column noun). The first pattern that
    yields at least one column wins. Returns ``[]`` otherwise.
    """
    text = (message or "").strip()
    patterns = (
        r":\s*(.+)$",
        r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        definitions = map(parse_column_definition, split_column_list(found.group(1)))
        columns = [definition for definition in definitions if definition]
        if columns:
            return columns
    return []
359
+
360
+
361
def extract_removed_columns(message):
    """Extract the column names a message asks to remove.

    Everything after the remove-style verb (and optional article and column
    noun) is split into columns and normalized to identifiers. Returns ``[]``
    when there is no remove verb or no usable name.
    """
    text = (message or "").strip()
    remove_pattern = r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$"
    found = re.search(remove_pattern, text, flags=re.IGNORECASE)
    if found is None:
        return []
    identifiers = map(normalize_identifier, split_column_list(found.group(1)))
    return [identifier for identifier in identifiers if identifier]
375
+
376
+
377
def extract_renamed_columns(message):
    """Extract ``(old_name, new_name)`` pairs from rename phrasing.

    Recognizes "<verb> OLD <connector> NEW" for the same verbs and connectors
    that ``is_rename_intent`` accepts, so every message classified as a
    rename also yields its rename pair. Previously "muda X para Y",
    "trocar X por Y" and "troca X para Y" were detected as renames but
    produced no pair here.

    Args:
        message: User message (``None`` is treated as empty).

    Returns:
        List of normalized ``(old, new)`` identifier pairs: generic rename
        verbs first, then "troca/trocar" matches. Pairs with an empty
        identifier on either side are dropped.
    """
    text = message or ""
    connectors = r"(?:to|para|as|como|por)"
    rename_pattern = (
        r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda)\s+"
        rf"(\w+)\s+{connectors}\s+(\w+)"
    )
    swap_pattern = rf"\btrocar?\b\s+(\w+)\s+{connectors}\s+(\w+)"
    matches = re.findall(rename_pattern, text, flags=re.IGNORECASE)
    swap_matches = re.findall(swap_pattern, text, flags=re.IGNORECASE)

    pairs = []
    for old, new in [*matches, *swap_matches]:
        old_name, new_name = normalize_identifier(old), normalize_identifier(new)
        if old_name and new_name:
            pairs.append((old_name, new_name))
    return pairs
389
+
390
+
391
def parse_compound_edit(message):
    """Split a multi-step edit request and collect its operations.

    The message is split on " and " / " e " only where an edit verb follows.
    Each segment is then treated as a rename, a removal, or (by default) an
    addition.

    Returns:
        ``(added, removed, renamed)`` where ``added`` holds ``(name, type)``
        pairs, ``removed`` holds column names and ``renamed`` holds
        ``(old, new)`` pairs.
    """
    segment_pattern = (
        r"\s+(?:and|e)\s+"
        r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
        r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
        r"exclua|renomeie|renomear|renomeia|altere|mude|troca|trocar)\b)"
    )
    remove_verbs = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"

    added, removed, renamed = [], [], []
    pieces = (piece.strip() for piece in re.split(segment_pattern, message or "", flags=re.IGNORECASE))
    for segment in filter(None, pieces):
        if is_rename_intent(segment):
            renamed += extract_renamed_columns(segment)
        elif re.search(remove_verbs, segment, flags=re.IGNORECASE):
            removed += extract_removed_columns(segment)
        else:
            added += extract_added_columns(segment)
    return added, removed, renamed
413
+
414
+
415
def edit_create_table_from_message(message, chat_history, active_schema):
    """Apply an add/remove/rename request to the current table definition.

    The base table is the latest ``CREATE TABLE`` statement from the
    assistant's history, falling back to ``active_schema``.

    Args:
        message: User message describing the edit.
        chat_history: Prior chat messages.
        active_schema: Current schema text, used when the history holds no
            ``CREATE TABLE`` statement.

    Returns:
        The updated ``CREATE TABLE`` statement, or ``""`` when the message is
        not an edit, no base table is available, or no operation could be
        extracted.
    """
    if not is_table_edit_intent(message) and not is_rename_intent(message):
        return ""
    base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
    table_name, existing_columns = parse_create_table_schema(base_sql)
    if not table_name:
        return ""
    added_columns, segment_removed, renamed_columns = parse_compound_edit(message)
    rename_map = dict(renamed_columns)
    # The whole-message scan captures everything after the remove verb, so in
    # "remove peso and rename nome to titulo" it also reports "nome". Columns
    # that are being renamed must not be dropped by that over-capture; only
    # an explicit per-segment removal may remove them.
    whole_message_removed = set(extract_removed_columns(message)) - set(rename_map)
    removed_set = whole_message_removed | set(segment_removed)
    if not added_columns and not removed_set and not renamed_columns:
        return ""
    kept_columns = [
        (rename_map.get(col_name, col_name), col_type)
        for col_name, col_type in existing_columns
        if col_name not in removed_set
    ]
    return format_create_table(table_name, [*kept_columns, *added_columns])
433
+
434
+
435
def create_table_from_suggestion(suggestion):
    """Render a schema suggestion as a ``CREATE TABLE`` statement.

    Accepts either a dict with ``table_name`` and ``columns`` (a list of
    ``{"name", "type"}`` dicts) or an object exposing ``table_name`` and
    ``columns`` (pairs of name and type). Column names are normalized to
    identifiers, missing types default to ``TEXT``, and types are uppercased.
    Returns ``""`` for an empty suggestion or when nothing can be rendered.
    """
    if not suggestion:
        return ""
    if isinstance(suggestion, dict):
        table_name = suggestion.get("table_name")
        raw_columns = [
            (entry.get("name"), entry.get("type", "TEXT"))
            for entry in suggestion.get("columns", [])
            if isinstance(entry, dict)
        ]
    else:
        table_name = getattr(suggestion, "table_name", "")
        raw_columns = getattr(suggestion, "columns", ())

    normalized = ((normalize_identifier(name), column_type) for name, column_type in raw_columns)
    parsed = [
        (identifier, (column_type or "TEXT").upper())
        for identifier, column_type in normalized
        if identifier
    ]
    return format_create_table(normalize_identifier(table_name), parsed)
454
+
tests/test_chatbot_behavior.py CHANGED
@@ -493,7 +493,7 @@ def test_off_topic_message_returns_fallback(monkeypatch):
493
  result = app.generate_response("me conte uma piada", [], "", None, None)
494
 
495
  assert sql_output(result) == ""
496
- assert "schema" in assistant_text(result).lower() or "tabela" in assistant_text(result).lower()
497
 
498
 
499
  def test_greeting_returns_fallback(monkeypatch):
@@ -670,3 +670,23 @@ def test_build_generation_prompt_no_history_no_context():
670
  assert "comida" in prompt
671
  assert "liste todos" in prompt or "liste" in prompt
672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  result = app.generate_response("me conte uma piada", [], "", None, None)
494
 
495
  assert sql_output(result) == ""
496
+ assert "load the fine-tuned model" in assistant_text(result).lower()
497
 
498
 
499
  def test_greeting_returns_fallback(monkeypatch):
 
670
  assert "comida" in prompt
671
  assert "liste todos" in prompt or "liste" in prompt
672
 
673
+
674
+ # ---------------------------------------------------------------------------
675
+ # Regression: ISSUE-001 — "troca X por Y" rename pattern
676
+ # ---------------------------------------------------------------------------
677
+
678
def test_troca_x_por_y_rename(monkeypatch):
    """"troca X por Y" renames a column without invoking the model."""

    def _forbid_model(*args, **kwargs):
        pytest.fail("model should not run")

    monkeypatch.setattr(app, "_run_generation", _forbid_model)
    created = app.generate_response("crie tabela comida com id nome sabor peso tipo", [], "", None, None)
    renamed = app.generate_response("troca tipo por medida", created[0], created[2], None, None)
    schema = sql_output(renamed)

    assert "medida TEXT" in schema
    assert "tipo TEXT" not in schema
    # Every other column must survive the rename untouched.
    for untouched in ("id INTEGER", "nome TEXT", "sabor TEXT", "peso NUMERIC"):
        assert untouched in schema
tests/test_chatbot_core.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ import app
4
+ from chat_state import ConversationState, SchemaSuggestion
5
+ from intent import (
6
+ CREATE_TABLE_CONFIRM,
7
+ EDIT_TABLE,
8
+ SCHEMA_SUGGESTION,
9
+ SMALLTALK,
10
+ SQL_QUERY,
11
+ classify_intent,
12
+ )
13
+
14
+
15
+ def test_conversation_state_roundtrip_dict():
16
+ suggestion = SchemaSuggestion(
17
+ table_name="zoologico",
18
+ columns=(("id", "INTEGER"), ("nome", "TEXT")),
19
+ rationale="base",
20
+ )
21
+ state = ConversationState(active_schema="", pending_schema_suggestion=suggestion, last_intent=SCHEMA_SUGGESTION)
22
+
23
+ restored = ConversationState.from_value(state.to_dict())
24
+
25
+ pending = restored.pending_schema_suggestion
26
+ assert pending is not None
27
+ assert pending.table_name == "zoologico"
28
+ assert pending.columns == (("id", "INTEGER"), ("nome", "TEXT"))
29
+ assert restored.last_intent == SCHEMA_SUGGESTION
30
+
31
+
32
+ def test_intent_smalltalk_with_active_schema_is_not_sql():
33
+ state = ConversationState(active_schema="CREATE TABLE employees (id INTEGER)")
34
+
35
+ result = classify_intent("como voce esta hoje?", state)
36
+
37
+ assert result.intent == SMALLTALK
38
+
39
+
40
+ def test_intent_schema_suggestion_and_confirmation():
41
+ state = ConversationState()
42
+
43
+ suggestion = classify_intent("preciso de uma tabela sobre zoologico", state)
44
+ pending = state.with_pending_schema(
45
+ SchemaSuggestion(table_name="zoologico", columns=(("id", "INTEGER"), ("nome", "TEXT")))
46
+ )
47
+ confirmation = classify_intent("gera", pending)
48
+
49
+ assert suggestion.intent == SCHEMA_SUGGESTION
50
+ assert confirmation.intent == CREATE_TABLE_CONFIRM
51
+
52
+
53
+ def test_intent_edit_and_sql_query():
54
+ state = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")
55
+
56
+ edit = classify_intent("troca cidade por municipio", state)
57
+ query = classify_intent("liste zoologicos por municipio", state)
58
+
59
+ assert edit.intent == EDIT_TABLE
60
+ assert query.intent == SQL_QUERY
61
+
62
+
63
+ def test_zoologico_transcript_with_mocked_model(monkeypatch):
64
+ app._model = types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0))
65
+ app._tokenizer = types.SimpleNamespace(eos_token_id=0, pad_token_id=0)
66
+ app._current_model_id = app.FINE_TUNED_MODEL_ID
67
+
68
+ def fake_generate(prompt, generation_kind):
69
+ if generation_kind == app.model_core.CHAT_GENERATION:
70
+ return "Oi, posso ajudar com conversa comum ou SQL.", 1
71
+ if generation_kind == app.model_core.SCHEMA_GENERATION:
72
+ return (
73
+ '{"table_name":"zoologico","columns":['
74
+ '{"name":"id","type":"INTEGER"},'
75
+ '{"name":"nome","type":"TEXT"},'
76
+ '{"name":"cidade","type":"TEXT"},'
77
+ '{"name":"capacidade","type":"INTEGER"}],'
78
+ '"rationale":"Tabela inicial para zoologicos."}',
79
+ 1,
80
+ )
81
+ return "SELECT * FROM zoologico WHERE cidade = 'Sao Paulo';", 1
82
+
83
+ monkeypatch.setattr(app, "_generate_model_text", fake_generate)
84
+
85
+ r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY, None)
86
+ assert app.EMPTY_CHAT_OUTPUT == ""
87
+ assert r1[4] == ""
88
+ assert "Oi" in r1[0][-1]["content"]
89
+
90
+ r2 = app.generate_response("preciso de uma tabela sobre zoologico", r1[0], r1[2], app.FINE_TUNED_MODEL_KEY, None, r1[8])
91
+ assert r2[4] == ""
92
+ assert r2[8]["pending_schema_suggestion"] is not None
93
+ assert r2[8]["pending_schema_suggestion"]["table_name"] == "zoologico"
94
+
95
+ r3 = app.generate_response("gera", r2[0], r2[2], app.FINE_TUNED_MODEL_KEY, None, r2[8])
96
+ assert "CREATE TABLE zoologico" in r3[4]
97
+ assert "CREATE TABLE zoologico" in r3[2]
98
+ assert r3[8]["pending_schema_suggestion"] is None
99
+
100
+ r4 = app.generate_response("troca capacidade por numero_animais", r3[0], r3[2], app.FINE_TUNED_MODEL_KEY, None, r3[8])
101
+ assert "numero_animais INTEGER" in r4[4]
102
+ assert "capacidade" not in r4[4]
103
+
104
+ r5 = app.generate_response("liste zoologicos de Sao Paulo", r4[0], r4[2], app.FINE_TUNED_MODEL_KEY, None, r4[8])
105
+ assert "SELECT * FROM zoologico" in r5[4]