Shizu0n commited on
Commit
ad5be9b
·
1 Parent(s): 47affa0

refactor: model scope for sql query only

Browse files
.gitignore CHANGED
@@ -37,6 +37,10 @@ Thumbs.db
37
  logs/
38
  *.log
39
 
 
 
 
 
40
  # Large local model artifacts should stay on the Hub, not in this Space repo
41
  *.safetensors
42
  *.bin
@@ -48,7 +52,10 @@ logs/
48
  # Local agent/workspace notes
49
  /AGENTS.md
50
  /CLAUDE.md
51
- /PROGRESS.md
 
 
 
52
  .claude
53
  .gstack/
54
  docs
 
37
  logs/
38
  *.log
39
 
40
+ # Graphify local artifacts
41
+ graphify-out/
42
+ .graphifyignore
43
+
44
  # Large local model artifacts should stay on the Hub, not in this Space repo
45
  *.safetensors
46
  *.bin
 
52
  # Local agent/workspace notes
53
  /AGENTS.md
54
  /CLAUDE.md
55
+ /HANDOFF.md
56
+ /FLOWS.md
57
+ /ERRORS.md
58
+ /ROADMAP.md
59
  .claude
60
  .gstack/
61
  docs
README.md CHANGED
@@ -17,16 +17,16 @@ Generates SQL queries from a table schema and a natural-language question using
17
 
18
  ## What the App Does
19
 
20
- Transforms simple table descriptions and questions into SQL using the fine-tuned Phi-3 Mini model. The base model is shown as offline evaluation evidence instead of a second live CPU-loaded model.
21
 
22
  ## How to Use
23
 
24
  1. Click **Load fine-tuned model**.
25
  - Loading is lazy: the model is only downloaded and loaded when you request it.
26
  - On CPU, the first load can take a few minutes.
27
- 2. Chat normally or enter/edit the **SQL table schema**.
28
  - You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
29
- - You can also write your own schema manually.
30
  3. Enter the question in the chat input.
31
  4. Click **Send**.
32
  5. Review the result in `gr.Code(language="sql")`.
@@ -50,8 +50,9 @@ Reported gain: **+71.5 percentage points** over the base model.
50
 
51
  ## Current Features
52
 
53
- - Gradio UI with a step-by-step flow: load the fine-tuned model, chat, and inspect SQL artifacts.
54
- - Intent routing that keeps normal conversation separate from SQL generation.
 
55
  - Lazy loading to reduce startup cost.
56
  - Preserved Phi-3 patches for local/Spaces compatibility.
57
  - Schema presets without blocking manual input.
@@ -66,7 +67,7 @@ The normal pytest suite does not load the 3.8B model. To manually verify the rea
66
  python scripts/model_probe.py
67
  ```
68
 
69
- The probe prints JSON with pass/fail checks for greeting, schema proposal, CREATE TABLE confirmation, schema edit, SQL query, and smalltalk while a schema is active.
70
 
71
  ## Run Locally
72
 
 
17
 
18
  ## What the App Does
19
 
20
+ Transforms explicit table schemas and schema-edit requests deterministically, then uses the fine-tuned Phi-3 Mini model only for SQL query generation. The base model is shown as offline evaluation evidence instead of a second live CPU-loaded model.
21
 
22
  ## How to Use
23
 
24
  1. Click **Load fine-tuned model**.
25
  - Loading is lazy: the model is only downloaded and loaded when you request it.
26
  - On CPU, the first load can take a few minutes.
27
+ 2. Select, create, or edit the **SQL table schema**.
28
  - You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
29
+ - You can also ask for explicit schema operations, such as `create table products with id name price` or `add stock`.
30
  3. Enter the question in the chat input.
31
  4. Click **Send**.
32
  5. Review the result in `gr.Code(language="sql")`.
 
50
 
51
  ## Current Features
52
 
53
+ - Gradio UI with a step-by-step flow: load the fine-tuned model, define schema context, and inspect SQL artifacts.
54
+ - Intent routing with 5 supported routes: `CREATE_TABLE`, `EDIT_TABLE`, `SQL_QUERY`, `SMALLTALK`, `UNKNOWN`.
55
+ - Model calls only for `SQL_QUERY`; smalltalk and unknown messages use a static fallback.
56
  - Lazy loading to reduce startup cost.
57
  - Preserved Phi-3 patches for local/Spaces compatibility.
58
  - Schema presets without blocking manual input.
 
67
  python scripts/model_probe.py
68
  ```
69
 
70
+ The probe prints JSON with pass/fail checks for static fallback, deterministic CREATE TABLE, deterministic schema edit, SQL query generation, and smalltalk while a schema is active.
71
 
72
  ## Run Locally
73
 
app.py CHANGED
@@ -18,29 +18,18 @@ import model_io as model_core
18
  import sql_tools as sql_core
19
 
20
 
21
- BASE_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
22
  FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
23
 
24
- BASE_MODEL_KEY = "base"
25
  FINE_TUNED_MODEL_KEY = "fine_tuned"
26
  DEFAULT_MODEL_KEY = FINE_TUNED_MODEL_KEY
 
 
 
 
 
 
27
 
28
  MODEL_CATALOG = {
29
- BASE_MODEL_KEY: {
30
- "label": "Base Phi-3 Mini",
31
- "short_label": "Base",
32
- "tag": "Base",
33
- "title": "Phi-3 Mini base",
34
- "model_id": BASE_MODEL_ID,
35
- "exact_match": "2.0%",
36
- "trust_remote_code": False,
37
- "ready_text": "Base model ready",
38
- "metadata": (
39
- "Model: microsoft/Phi-3-mini-4k-instruct\n"
40
- "Role: unfine-tuned baseline\n"
41
- "Metric: 2.0% exact match on the comparison setup"
42
- ),
43
- },
44
  FINE_TUNED_MODEL_KEY: {
45
  "label": "Fine-tuned QLoRA model",
46
  "short_label": "Fine-tuned",
@@ -400,12 +389,7 @@ def normalize_text(value):
400
 
401
 
402
  def safe_chat_fallback(_message=""):
403
- return (
404
- "Selecione um schema e faça uma pergunta SQL, "
405
- "ou peça para criar ou editar uma tabela. "
406
- "Exemplo: 'crie tabela produtos com id nome preco' "
407
- "ou 'qual o produto mais caro?'."
408
- )
409
 
410
 
411
  def clean_generation(text):
@@ -459,11 +443,11 @@ def render_header():
459
  <section class="top-panel">
460
  <div>
461
  <h1>Phi-3 Mini SQL Chatbot</h1>
462
- <p>Conversational SQL assistant powered by a fine-tuned Phi-3 Mini model</p>
463
  </div>
464
  <div class="top-badges">
465
- <span class="badge badge-green">Natural chat + SQL</span>
466
- <span class="badge badge-cream">Context-aware schema</span>
467
  <span class="badge badge-light">CPU lazy load</span>
468
  </div>
469
  </section>
@@ -534,10 +518,10 @@ def render_loading_overlay(model_key=None, visible=False):
534
  def model_metadata(model_key=None):
535
  return """
536
  <section class="stats-row">
537
- <div class="stat-card"><strong>Chat</strong><span>normal conversation</span></div>
538
- <div class="stat-card"><strong>Schema</strong><span>table proposals</span></div>
539
- <div class="stat-card"><strong>SQL</strong><span>query generation</span></div>
540
- <div class="stat-card"><strong>Probe</strong><span>manual model gate</span></div>
541
  </section>
542
  """
543
 
@@ -935,21 +919,6 @@ def render_message(message="", kind="error"):
935
  return f'<div class="message-box {class_name}">{html.escape(str(message))}</div>'
936
 
937
 
938
- def select_model(model_key, loaded_key):
939
- selected_key = model_key if model_key in MODEL_CATALOG else DEFAULT_MODEL_KEY
940
- query_is_active = loaded_key == selected_key
941
- return (
942
- selected_key,
943
- render_model_card(BASE_MODEL_KEY, selected_key),
944
- render_model_card(FINE_TUNED_MODEL_KEY, selected_key),
945
- render_status(selected_key, loaded_key),
946
- model_metadata(selected_key),
947
- *query_control_updates(query_is_active),
948
- gr.update(interactive=False),
949
- render_message(),
950
- )
951
-
952
-
953
  def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
954
  selected_key = FINE_TUNED_MODEL_KEY
955
  model_def = model_by_key(selected_key)
@@ -966,7 +935,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
966
  *query_control_updates(False),
967
  "",
968
  EMPTY_VALIDATOR,
969
- gr.update(value=None),
970
  render_message(),
971
  )
972
  started = time.time()
@@ -996,7 +964,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
996
  *query_control_updates(False),
997
  "",
998
  EMPTY_VALIDATOR,
999
- gr.update(value=None),
1000
  render_message(error),
1001
  )
1002
  return
@@ -1011,7 +978,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
1011
  *query_control_updates(True),
1012
  "",
1013
  EMPTY_VALIDATOR,
1014
- gr.update(value=None),
1015
  render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
1016
  )
1017
 
@@ -1030,62 +996,6 @@ def trim_chat_history(chat_history, max_exchanges=10):
1030
  return history[-max_exchanges * 2 :]
1031
 
1032
 
1033
- def comparison_updates(saved_state, current_sql, loaded_key):
1034
- if not saved_state or not (current_sql or "").strip():
1035
- return gr.update(visible=False), "", "", "", ""
1036
-
1037
- loaded_def = model_by_key(loaded_key) if loaded_key else model_by_key(DEFAULT_MODEL_KEY)
1038
- return (
1039
- gr.update(visible=True),
1040
- render_compare_label("Saved", saved_state.get("model_label", "Unknown"), saved_state.get("match", "")),
1041
- saved_state.get("sql", ""),
1042
- render_compare_label("Current", loaded_def["short_label"], loaded_def["exact_match"]),
1043
- current_sql or "",
1044
- )
1045
-
1046
-
1047
- def render_compare_label(prefix, model_label, metric):
1048
- metric_html = f"<strong>{html.escape(metric)} match</strong>" if metric else ""
1049
- return (
1050
- f'<div class="compare-head"><span>{html.escape(prefix)} - '
1051
- f"{html.escape(model_label)}</span>{metric_html}</div>"
1052
- )
1053
-
1054
-
1055
- def save_for_comparison(sql_text, loaded_key, active_schema, last_message):
1056
- sql_text = (sql_text or "").strip()
1057
- if not sql_text or not loaded_key:
1058
- return (
1059
- None,
1060
- gr.update(visible=False),
1061
- "",
1062
- "",
1063
- "",
1064
- "",
1065
- gr.update(interactive=False, visible=False),
1066
- render_message("Generate SQL before saving a comparison."),
1067
- )
1068
-
1069
- model_def = model_by_key(loaded_key)
1070
- saved = {
1071
- "sql": sql_text,
1072
- "model_label": model_def["short_label"],
1073
- "match": model_def["exact_match"],
1074
- "schema_context": active_schema or "",
1075
- "user_message": last_message or "",
1076
- }
1077
- return (
1078
- saved,
1079
- gr.update(visible=True),
1080
- render_compare_label("Saved", model_def["short_label"], model_def["exact_match"]),
1081
- sql_text,
1082
- render_compare_label("Current", model_def["short_label"], model_def["exact_match"]),
1083
- sql_text,
1084
- gr.update(interactive=True),
1085
- render_message("Saved output for comparison.", kind="ok"),
1086
- )
1087
-
1088
-
1089
  def _append_chat_turn(chat_history, message, assistant_content):
1090
  return trim_chat_history(
1091
  [
@@ -1109,7 +1019,7 @@ def _response_tuple(
1109
  ):
1110
  state = chat_core.ConversationState.from_value(state)
1111
  if sql_text and "CREATE TABLE" in sql_text.upper():
1112
- state = state.with_active_schema(sql_text).clear_pending_schema()
1113
  new_history = _append_chat_turn(chat_history, message, assistant_content)
1114
  return (
1115
  new_history,
@@ -1118,7 +1028,6 @@ def _response_tuple(
1118
  message,
1119
  sql_text,
1120
  validator,
1121
- gr.update(value=None),
1122
  render_message(status_message, kind=status_kind),
1123
  state.to_dict(),
1124
  )
@@ -1129,7 +1038,6 @@ def deterministic_response(
1129
  message,
1130
  active_schema,
1131
  loaded_key,
1132
- saved_state,
1133
  assistant_content,
1134
  status_message,
1135
  *,
@@ -1153,7 +1061,7 @@ def deterministic_response(
1153
 
1154
  def _model_ready(loaded_key):
1155
  if not loaded_key or _model is None or _tokenizer is None:
1156
- return False, "Load the fine-tuned model before chatting or generating SQL."
1157
  model_def = model_by_key(loaded_key)
1158
  if _current_model_id != model_def["model_id"]:
1159
  return False, "Loaded model state is inconsistent. Reload the selected model."
@@ -1195,12 +1103,6 @@ def _generate_model_text(prompt, generation_kind=model_core.SQL_GENERATION):
1195
  return generated_text, int(time.time() - started)
1196
 
1197
 
1198
- def _schema_suggestion_message(suggestion):
1199
- columns = ", ".join(f"{name} {column_type}" for name, column_type in suggestion.columns)
1200
- rationale = f"\n\n{suggestion.rationale}" if suggestion.rationale else ""
1201
- return f"Posso montar a tabela `{suggestion.table_name}` com: {columns}.{rationale}\n\nSe quiser, diga `gera`."
1202
-
1203
-
1204
  def _empty_generation_response(chat_history, message, state, status_message, *, status_kind="error"):
1205
  return (
1206
  chat_history,
@@ -1209,13 +1111,12 @@ def _empty_generation_response(chat_history, message, state, status_message, *,
1209
  "",
1210
  "",
1211
  EMPTY_VALIDATOR,
1212
- gr.update(value=None),
1213
  render_message(status_message, kind=status_kind),
1214
  state.to_dict(),
1215
  )
1216
 
1217
 
1218
- def generate_response(message, chat_history, active_schema, loaded_key, saved_state=None, conversation_state=None):
1219
  message = (message or "").strip()
1220
  chat_history = list(chat_history or [])
1221
  state = chat_core.ConversationState.from_value(conversation_state, active_schema=(active_schema or ""))
@@ -1227,7 +1128,6 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
1227
  "",
1228
  "",
1229
  EMPTY_VALIDATOR,
1230
- gr.update(value=None),
1231
  render_message("Type a message before sending."),
1232
  state.to_dict(),
1233
  )
@@ -1241,20 +1141,6 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
1241
  )
1242
 
1243
  if intent_result.intent == intent_core.EDIT_TABLE:
1244
- if state.pending_schema_suggestion and not state.active_schema:
1245
- pending_sql = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
1246
- edited_table = sql_core.edit_create_table_from_message(message, chat_history, pending_sql)
1247
- table_name, columns = sql_core.parse_create_table_schema(edited_table)
1248
- if edited_table and table_name and columns:
1249
- suggestion = chat_core.SchemaSuggestion(table_name=table_name, columns=tuple(columns))
1250
- state = state.with_pending_schema(suggestion)
1251
- return _response_tuple(
1252
- chat_history,
1253
- message,
1254
- state,
1255
- _schema_suggestion_message(suggestion),
1256
- "Updated pending table proposal without calling the model.",
1257
- )
1258
  edited_table = sql_core.edit_create_table_from_message(message, chat_history, state.active_schema)
1259
  if edited_table:
1260
  display_response = f"```sql\n{edited_table}\n```"
@@ -1274,32 +1160,8 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
1274
  "I need an existing CREATE TABLE in the chat or an active schema before editing columns.",
1275
  )
1276
 
1277
- if intent_result.intent == intent_core.CREATE_TABLE_CONFIRM:
1278
- sql_text = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
1279
- if sql_text:
1280
- display_response = f"```sql\n{sql_text}\n```"
1281
- return _response_tuple(
1282
- chat_history,
1283
- message,
1284
- state.clear_pending_schema(),
1285
- display_response,
1286
- "Generated CREATE TABLE from the pending proposal.",
1287
- sql_text=sql_text,
1288
- validator=sql_core.validate_sql(sql_text),
1289
- )
1290
-
1291
  if intent_result.intent == intent_core.CREATE_TABLE:
1292
  sql_text = sql_core.create_table_from_message(message) or sql_core.create_table_from_schema(state.active_schema)
1293
- if not sql_text:
1294
- ready, _error = _model_ready(loaded_key)
1295
- if ready:
1296
- try:
1297
- prompt = model_core.build_schema_suggestion_prompt(message, state, chat_history)
1298
- generated_text, _elapsed = _generate_model_text(prompt, model_core.SCHEMA_GENERATION)
1299
- suggestion = model_core.parse_schema_suggestion(generated_text)
1300
- sql_text = sql_core.create_table_from_suggestion(suggestion)
1301
- except Exception:
1302
- sql_text = ""
1303
  if sql_text:
1304
  display_response = f"```sql\n{sql_text}\n```"
1305
  return _response_tuple(
@@ -1315,72 +1177,17 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
1315
  chat_history,
1316
  message,
1317
  state,
1318
- "CREATE TABLE needs a table name and columns, or a loaded model to propose them.",
1319
  )
1320
 
1321
- if intent_result.intent == intent_core.SCHEMA_SUGGESTION:
1322
- ready, error = _model_ready(loaded_key)
1323
- if not ready:
1324
- return _response_tuple(chat_history, message, state, error, error, status_kind="error")
1325
- try:
1326
- prompt = model_core.build_schema_suggestion_prompt(message, state, chat_history)
1327
- generated_text, elapsed = _generate_model_text(prompt, model_core.SCHEMA_GENERATION)
1328
- suggestion = model_core.parse_schema_suggestion(generated_text)
1329
- if not suggestion:
1330
- repair_prompt = (
1331
- "Return valid JSON only for this SQL table proposal. "
1332
- "Use table_name, columns, and rationale.\n\n"
1333
- f"Previous output:\n{generated_text}"
1334
- )
1335
- repaired_text, elapsed = _generate_model_text(repair_prompt, model_core.SCHEMA_GENERATION)
1336
- suggestion = model_core.parse_schema_suggestion(repaired_text)
1337
- if not suggestion:
1338
- return _response_tuple(
1339
- chat_history,
1340
- message,
1341
- state,
1342
- "Nao consegui estruturar essa proposta de tabela. Diga o nome da tabela e algumas colunas.",
1343
- "Schema proposal was not valid JSON.",
1344
- status_kind="error",
1345
- )
1346
- state = state.with_pending_schema(suggestion)
1347
- return _response_tuple(
1348
- chat_history,
1349
- message,
1350
- state,
1351
- _schema_suggestion_message(suggestion),
1352
- f"Generated schema proposal in {elapsed}s.",
1353
- )
1354
- except Exception as exc:
1355
- return _empty_generation_response(
1356
- chat_history,
1357
- message,
1358
- state,
1359
- f"Generation failed: {type(exc).__name__}: {exc}",
1360
- )
1361
-
1362
- if intent_result.intent in {intent_core.SMALLTALK, intent_core.CLARIFICATION, intent_core.UNKNOWN}:
1363
- ready, error = _model_ready(loaded_key)
1364
- if not ready:
1365
- return _response_tuple(chat_history, message, state, error, error, status_kind="error")
1366
- try:
1367
- prompt = model_core.build_chat_prompt(message, state, chat_history)
1368
- generated_text, elapsed = _generate_model_text(prompt, model_core.CHAT_GENERATION)
1369
- chat_text = model_core.clean_generation(generated_text)
1370
- return _response_tuple(
1371
- chat_history,
1372
- message,
1373
- state,
1374
- chat_text,
1375
- f"Generated chat response in {elapsed}s.",
1376
- )
1377
- except Exception as exc:
1378
- return _empty_generation_response(
1379
- chat_history,
1380
- message,
1381
- state,
1382
- f"Generation failed: {type(exc).__name__}: {exc}",
1383
- )
1384
 
1385
  ready, error = _model_ready(loaded_key)
1386
  if not ready:
@@ -1458,7 +1265,6 @@ def sync_on_load():
1458
  *query_control_updates(True),
1459
  "",
1460
  EMPTY_VALIDATOR,
1461
- gr.update(value=None),
1462
  render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
1463
  )
1464
  return (
@@ -1470,7 +1276,6 @@ def sync_on_load():
1470
  *query_control_updates(False),
1471
  "",
1472
  EMPTY_VALIDATOR,
1473
- gr.update(value=None),
1474
  render_message(),
1475
  )
1476
 
@@ -1481,7 +1286,6 @@ CSS = """
1481
  /* Prevent Gradio dark theme from overriding text in light-bg components */
1482
  [class*="badge"],
1483
  [class*="validator-"],
1484
- [class*="compare-head"],
1485
  [class*="model-tag"],
1486
  [class*="stat-card"] {
1487
  color: inherit !important;
@@ -1608,7 +1412,6 @@ CSS = """
1608
  }
1609
 
1610
  .model-grid,
1611
- .compare-grid,
1612
  .stats-row {
1613
  display: grid;
1614
  gap: 12px;
@@ -1616,7 +1419,6 @@ CSS = """
1616
  }
1617
 
1618
  .model-grid > div,
1619
- .compare-grid > div,
1620
  .stats-row > div {
1621
  min-width: 0;
1622
  }
@@ -1756,8 +1558,7 @@ CSS = """
1756
  }
1757
 
1758
  #load-button,
1759
- #generate-button,
1760
- #save-button {
1761
  width: 100% !important;
1762
  }
1763
 
@@ -1790,19 +1591,6 @@ CSS = """
1790
  color: var(--bg-base) !important;
1791
  }
1792
 
1793
- #save-button button {
1794
- background: transparent !important;
1795
- border: 0.5px solid var(--border-hi) !important;
1796
- color: var(--text-primary) !important;
1797
- min-height: 38px !important;
1798
- width: 100% !important;
1799
- }
1800
-
1801
- #save-button button:hover {
1802
- border-color: var(--text-primary) !important;
1803
- }
1804
-
1805
- #save-button button:disabled,
1806
  #generate-button button:disabled {
1807
  opacity: 0.4 !important;
1808
  }
@@ -2079,10 +1867,7 @@ textarea {
2079
 
2080
  .output-shell .cm-editor,
2081
  .output-shell pre,
2082
- .output-shell code,
2083
- .compare-card .cm-editor,
2084
- .compare-card pre,
2085
- .compare-card code {
2086
  border: 0 !important;
2087
  font-size: 12px !important;
2088
  font-weight: 400 !important;
@@ -2104,45 +1889,6 @@ textarea {
2104
  color: var(--teal);
2105
  }
2106
 
2107
- .comparison-panel {
2108
- margin-top: 28px;
2109
- }
2110
-
2111
- .compare-card {
2112
- background: var(--bg-surface);
2113
- border: 0.5px solid var(--border);
2114
- border-radius: 6px;
2115
- overflow: hidden;
2116
- }
2117
-
2118
- .compare-card.current {
2119
- border-color: rgba(29, 158, 117, 0.45);
2120
- }
2121
-
2122
- .compare-head {
2123
- align-items: center;
2124
- background: var(--amber-soft);
2125
- color: var(--amber-text) !important;
2126
- display: flex;
2127
- font-size: 11px;
2128
- font-weight: 500;
2129
- gap: 16px;
2130
- justify-content: space-between;
2131
- min-height: 34px;
2132
- padding: 0 12px;
2133
- }
2134
-
2135
- .compare-card.current .compare-head,
2136
- .current-compare-head .compare-head {
2137
- background: var(--teal-soft);
2138
- color: var(--teal-text) !important;
2139
- }
2140
-
2141
- .compare-head strong {
2142
- color: inherit;
2143
- font-weight: 500;
2144
- }
2145
-
2146
  .loading-overlay {
2147
  align-items: center;
2148
  background: rgba(0, 0, 0, 0.6);
@@ -2210,7 +1956,6 @@ textarea {
2210
  @media (max-width: 860px) {
2211
  .top-panel,
2212
  .model-grid,
2213
- .compare-grid,
2214
  .evidence-grid {
2215
  grid-template-columns: 1fr;
2216
  }
@@ -2233,7 +1978,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
2233
  loaded_key_state = gr.State(value=None)
2234
  active_schema = gr.State(value="")
2235
  conversation_state = gr.State(value=chat_core.default_state())
2236
- generation_meta_state = gr.State(value=None)
2237
  last_user_message = gr.State(value="")
2238
 
2239
  with gr.Column(elem_classes=["app-shell"]):
@@ -2335,7 +2079,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
2335
  send_button,
2336
  sql_output,
2337
  validator_output,
2338
- generation_meta_state,
2339
  error_output,
2340
  ],
2341
  js=LOAD_SCROLL_JS,
@@ -2356,18 +2099,17 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
2356
  last_user_message,
2357
  sql_output,
2358
  validator_output,
2359
- generation_meta_state,
2360
  error_output,
2361
  conversation_state,
2362
  ]
2363
  send_button.click(
2364
  generate_response,
2365
- inputs=[message_input, chatbot, active_schema, loaded_key_state, generation_meta_state, conversation_state],
2366
  outputs=chat_generation_outputs,
2367
  )
2368
  message_input.submit(
2369
  generate_response,
2370
- inputs=[message_input, chatbot, active_schema, loaded_key_state, generation_meta_state, conversation_state],
2371
  outputs=chat_generation_outputs,
2372
  )
2373
  demo.load(
@@ -2388,7 +2130,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
2388
  send_button,
2389
  sql_output,
2390
  validator_output,
2391
- generation_meta_state,
2392
  error_output,
2393
  ],
2394
  )
 
18
  import sql_tools as sql_core
19
 
20
 
 
21
  FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
22
 
 
23
  FINE_TUNED_MODEL_KEY = "fine_tuned"
24
  DEFAULT_MODEL_KEY = FINE_TUNED_MODEL_KEY
25
+ FALLBACK_RESPONSE = (
26
+ "Select a schema and ask a SQL question, "
27
+ "or ask to create or edit a table. "
28
+ "Example: 'what is the most expensive product?' or "
29
+ "'create table products with id name price'."
30
+ )
31
 
32
  MODEL_CATALOG = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  FINE_TUNED_MODEL_KEY: {
34
  "label": "Fine-tuned QLoRA model",
35
  "short_label": "Fine-tuned",
 
389
 
390
 
391
  def safe_chat_fallback(_message=""):
392
+ return FALLBACK_RESPONSE
 
 
 
 
 
393
 
394
 
395
  def clean_generation(text):
 
443
  <section class="top-panel">
444
  <div>
445
  <h1>Phi-3 Mini SQL Chatbot</h1>
446
+ <p>SQL generation demo powered by a fine-tuned Phi-3 Mini model</p>
447
  </div>
448
  <div class="top-badges">
449
+ <span class="badge badge-green">SQL_QUERY only</span>
450
+ <span class="badge badge-cream">Deterministic schema edits</span>
451
  <span class="badge badge-light">CPU lazy load</span>
452
  </div>
453
  </section>
 
518
  def model_metadata(model_key=None):
519
  return """
520
  <section class="stats-row">
521
+ <div class="stat-card"><strong>SQL</strong><span>model-generated SELECT queries</span></div>
522
+ <div class="stat-card"><strong>Create</strong><span>deterministic CREATE TABLE parser</span></div>
523
+ <div class="stat-card"><strong>Edit</strong><span>deterministic schema updates</span></div>
524
+ <div class="stat-card"><strong>Fallback</strong><span>static non-SQL response</span></div>
525
  </section>
526
  """
527
 
 
919
  return f'<div class="message-box {class_name}">{html.escape(str(message))}</div>'
920
 
921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
922
  def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
923
  selected_key = FINE_TUNED_MODEL_KEY
924
  model_def = model_by_key(selected_key)
 
935
  *query_control_updates(False),
936
  "",
937
  EMPTY_VALIDATOR,
 
938
  render_message(),
939
  )
940
  started = time.time()
 
964
  *query_control_updates(False),
965
  "",
966
  EMPTY_VALIDATOR,
 
967
  render_message(error),
968
  )
969
  return
 
978
  *query_control_updates(True),
979
  "",
980
  EMPTY_VALIDATOR,
 
981
  render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
982
  )
983
 
 
996
  return history[-max_exchanges * 2 :]
997
 
998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999
  def _append_chat_turn(chat_history, message, assistant_content):
1000
  return trim_chat_history(
1001
  [
 
1019
  ):
1020
  state = chat_core.ConversationState.from_value(state)
1021
  if sql_text and "CREATE TABLE" in sql_text.upper():
1022
+ state = state.with_active_schema(sql_text)
1023
  new_history = _append_chat_turn(chat_history, message, assistant_content)
1024
  return (
1025
  new_history,
 
1028
  message,
1029
  sql_text,
1030
  validator,
 
1031
  render_message(status_message, kind=status_kind),
1032
  state.to_dict(),
1033
  )
 
1038
  message,
1039
  active_schema,
1040
  loaded_key,
 
1041
  assistant_content,
1042
  status_message,
1043
  *,
 
1061
 
1062
  def _model_ready(loaded_key):
1063
  if not loaded_key or _model is None or _tokenizer is None:
1064
+ return False, "Load the fine-tuned model before generating SQL."
1065
  model_def = model_by_key(loaded_key)
1066
  if _current_model_id != model_def["model_id"]:
1067
  return False, "Loaded model state is inconsistent. Reload the selected model."
 
1103
  return generated_text, int(time.time() - started)
1104
 
1105
 
 
 
 
 
 
 
1106
  def _empty_generation_response(chat_history, message, state, status_message, *, status_kind="error"):
1107
  return (
1108
  chat_history,
 
1111
  "",
1112
  "",
1113
  EMPTY_VALIDATOR,
 
1114
  render_message(status_message, kind=status_kind),
1115
  state.to_dict(),
1116
  )
1117
 
1118
 
1119
+ def generate_response(message, chat_history, active_schema, loaded_key, conversation_state=None):
1120
  message = (message or "").strip()
1121
  chat_history = list(chat_history or [])
1122
  state = chat_core.ConversationState.from_value(conversation_state, active_schema=(active_schema or ""))
 
1128
  "",
1129
  "",
1130
  EMPTY_VALIDATOR,
 
1131
  render_message("Type a message before sending."),
1132
  state.to_dict(),
1133
  )
 
1141
  )
1142
 
1143
  if intent_result.intent == intent_core.EDIT_TABLE:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1144
  edited_table = sql_core.edit_create_table_from_message(message, chat_history, state.active_schema)
1145
  if edited_table:
1146
  display_response = f"```sql\n{edited_table}\n```"
 
1160
  "I need an existing CREATE TABLE in the chat or an active schema before editing columns.",
1161
  )
1162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1163
  if intent_result.intent == intent_core.CREATE_TABLE:
1164
  sql_text = sql_core.create_table_from_message(message) or sql_core.create_table_from_schema(state.active_schema)
 
 
 
 
 
 
 
 
 
 
1165
  if sql_text:
1166
  display_response = f"```sql\n{sql_text}\n```"
1167
  return _response_tuple(
 
1177
  chat_history,
1178
  message,
1179
  state,
1180
+ "CREATE TABLE needs a table name and columns.",
1181
  )
1182
 
1183
+ if intent_result.intent in {intent_core.SMALLTALK, intent_core.UNKNOWN}:
1184
+ return _response_tuple(
1185
+ chat_history,
1186
+ message,
1187
+ state,
1188
+ FALLBACK_RESPONSE,
1189
+ "Static fallback - no model call.",
1190
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1191
 
1192
  ready, error = _model_ready(loaded_key)
1193
  if not ready:
 
1265
  *query_control_updates(True),
1266
  "",
1267
  EMPTY_VALIDATOR,
 
1268
  render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
1269
  )
1270
  return (
 
1276
  *query_control_updates(False),
1277
  "",
1278
  EMPTY_VALIDATOR,
 
1279
  render_message(),
1280
  )
1281
 
 
1286
  /* Prevent Gradio dark theme from overriding text in light-bg components */
1287
  [class*="badge"],
1288
  [class*="validator-"],
 
1289
  [class*="model-tag"],
1290
  [class*="stat-card"] {
1291
  color: inherit !important;
 
1412
  }
1413
 
1414
  .model-grid,
 
1415
  .stats-row {
1416
  display: grid;
1417
  gap: 12px;
 
1419
  }
1420
 
1421
  .model-grid > div,
 
1422
  .stats-row > div {
1423
  min-width: 0;
1424
  }
 
1558
  }
1559
 
1560
  #load-button,
1561
+ #generate-button {
 
1562
  width: 100% !important;
1563
  }
1564
 
 
1591
  color: var(--bg-base) !important;
1592
  }
1593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1594
  #generate-button button:disabled {
1595
  opacity: 0.4 !important;
1596
  }
 
1867
 
1868
  .output-shell .cm-editor,
1869
  .output-shell pre,
1870
+ .output-shell code {
 
 
 
1871
  border: 0 !important;
1872
  font-size: 12px !important;
1873
  font-weight: 400 !important;
 
1889
  color: var(--teal);
1890
  }
1891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1892
  .loading-overlay {
1893
  align-items: center;
1894
  background: rgba(0, 0, 0, 0.6);
 
1956
  @media (max-width: 860px) {
1957
  .top-panel,
1958
  .model-grid,
 
1959
  .evidence-grid {
1960
  grid-template-columns: 1fr;
1961
  }
 
1978
  loaded_key_state = gr.State(value=None)
1979
  active_schema = gr.State(value="")
1980
  conversation_state = gr.State(value=chat_core.default_state())
 
1981
  last_user_message = gr.State(value="")
1982
 
1983
  with gr.Column(elem_classes=["app-shell"]):
 
2079
  send_button,
2080
  sql_output,
2081
  validator_output,
 
2082
  error_output,
2083
  ],
2084
  js=LOAD_SCROLL_JS,
 
2099
  last_user_message,
2100
  sql_output,
2101
  validator_output,
 
2102
  error_output,
2103
  conversation_state,
2104
  ]
2105
  send_button.click(
2106
  generate_response,
2107
+ inputs=[message_input, chatbot, active_schema, loaded_key_state, conversation_state],
2108
  outputs=chat_generation_outputs,
2109
  )
2110
  message_input.submit(
2111
  generate_response,
2112
+ inputs=[message_input, chatbot, active_schema, loaded_key_state, conversation_state],
2113
  outputs=chat_generation_outputs,
2114
  )
2115
  demo.load(
 
2130
  send_button,
2131
  sql_output,
2132
  validator_output,
 
2133
  error_output,
2134
  ],
2135
  )
chat_state.py CHANGED
@@ -1,51 +1,10 @@
1
  from dataclasses import dataclass, field
2
 
3
 
4
- @dataclass(frozen=True)
5
- class SchemaSuggestion:
6
- table_name: str = ""
7
- columns: tuple[tuple[str, str], ...] = ()
8
- rationale: str = ""
9
-
10
- @classmethod
11
- def from_value(cls, value):
12
- if isinstance(value, cls):
13
- return value
14
- if not isinstance(value, dict):
15
- return None
16
- raw_columns = value.get("columns") or ()
17
- columns = []
18
- for column in raw_columns:
19
- if isinstance(column, dict):
20
- name = str(column.get("name") or "").strip()
21
- column_type = str(column.get("type") or "TEXT").strip().upper()
22
- elif isinstance(column, (list, tuple)) and len(column) >= 2:
23
- name = str(column[0] or "").strip()
24
- column_type = str(column[1] or "TEXT").strip().upper()
25
- else:
26
- continue
27
- if name:
28
- columns.append((name, column_type or "TEXT"))
29
- table_name = str(value.get("table_name") or "").strip()
30
- rationale = str(value.get("rationale") or "").strip()
31
- if not table_name or not columns:
32
- return None
33
- return cls(table_name=table_name, columns=tuple(columns), rationale=rationale)
34
-
35
- def to_dict(self):
36
- return {
37
- "table_name": self.table_name,
38
- "columns": [{"name": name, "type": column_type} for name, column_type in self.columns],
39
- "rationale": self.rationale,
40
- }
41
-
42
-
43
  @dataclass(frozen=True)
44
  class ConversationState:
45
  active_schema: str = ""
46
- pending_schema_suggestion: SchemaSuggestion | None = None
47
  last_intent: str | None = None
48
- last_table_topic: str | None = None
49
  debug: dict = field(default_factory=dict)
50
 
51
  @classmethod
@@ -56,52 +15,24 @@ class ConversationState:
56
  return value
57
  if not isinstance(value, dict):
58
  return cls(active_schema=(active_schema or "").strip())
59
- pending = SchemaSuggestion.from_value(value.get("pending_schema_suggestion"))
60
  state_active_schema = (value.get("active_schema") or active_schema or "").strip()
61
  return cls(
62
  active_schema=state_active_schema,
63
- pending_schema_suggestion=pending,
64
  last_intent=value.get("last_intent"),
65
- last_table_topic=value.get("last_table_topic"),
66
  debug=dict(value.get("debug") or {}),
67
  )
68
 
69
  def to_dict(self):
70
  return {
71
  "active_schema": self.active_schema,
72
- "pending_schema_suggestion": (
73
- self.pending_schema_suggestion.to_dict() if self.pending_schema_suggestion else None
74
- ),
75
  "last_intent": self.last_intent,
76
- "last_table_topic": self.last_table_topic,
77
  "debug": dict(self.debug or {}),
78
  }
79
 
80
  def with_active_schema(self, schema):
81
  return ConversationState(
82
  active_schema=(schema or "").strip(),
83
- pending_schema_suggestion=self.pending_schema_suggestion,
84
  last_intent=self.last_intent,
85
- last_table_topic=self.last_table_topic,
86
- debug=dict(self.debug or {}),
87
- )
88
-
89
- def with_pending_schema(self, suggestion):
90
- suggestion = SchemaSuggestion.from_value(suggestion)
91
- return ConversationState(
92
- active_schema=self.active_schema,
93
- pending_schema_suggestion=suggestion,
94
- last_intent=self.last_intent,
95
- last_table_topic=(suggestion.table_name if suggestion else self.last_table_topic),
96
- debug=dict(self.debug or {}),
97
- )
98
-
99
- def clear_pending_schema(self):
100
- return ConversationState(
101
- active_schema=self.active_schema,
102
- pending_schema_suggestion=None,
103
- last_intent=self.last_intent,
104
- last_table_topic=self.last_table_topic,
105
  debug=dict(self.debug or {}),
106
  )
107
 
@@ -112,13 +43,10 @@ class ConversationState:
112
  debug["reason"] = getattr(intent_result, "reason", None)
113
  return ConversationState(
114
  active_schema=self.active_schema,
115
- pending_schema_suggestion=self.pending_schema_suggestion,
116
  last_intent=getattr(intent_result, "intent", None),
117
- last_table_topic=self.last_table_topic,
118
  debug=debug,
119
  )
120
 
121
 
122
  def default_state(active_schema=""):
123
- return ConversationState(active_schema=(active_schema or "").strip()).to_dict()
124
-
 
1
  from dataclasses import dataclass, field
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  @dataclass(frozen=True)
5
  class ConversationState:
6
  active_schema: str = ""
 
7
  last_intent: str | None = None
 
8
  debug: dict = field(default_factory=dict)
9
 
10
  @classmethod
 
15
  return value
16
  if not isinstance(value, dict):
17
  return cls(active_schema=(active_schema or "").strip())
 
18
  state_active_schema = (value.get("active_schema") or active_schema or "").strip()
19
  return cls(
20
  active_schema=state_active_schema,
 
21
  last_intent=value.get("last_intent"),
 
22
  debug=dict(value.get("debug") or {}),
23
  )
24
 
25
  def to_dict(self):
26
  return {
27
  "active_schema": self.active_schema,
 
 
 
28
  "last_intent": self.last_intent,
 
29
  "debug": dict(self.debug or {}),
30
  }
31
 
32
  def with_active_schema(self, schema):
33
  return ConversationState(
34
  active_schema=(schema or "").strip(),
 
35
  last_intent=self.last_intent,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  debug=dict(self.debug or {}),
37
  )
38
 
 
43
  debug["reason"] = getattr(intent_result, "reason", None)
44
  return ConversationState(
45
  active_schema=self.active_schema,
 
46
  last_intent=getattr(intent_result, "intent", None),
 
47
  debug=debug,
48
  )
49
 
50
 
51
  def default_state(active_schema=""):
52
+ return ConversationState(active_schema=(active_schema or "").strip()).to_dict()
 
intent.py CHANGED
@@ -5,12 +5,9 @@ import sql_tools
5
 
6
 
7
  SMALLTALK = "smalltalk"
8
- SCHEMA_SUGGESTION = "schema_suggestion"
9
  CREATE_TABLE = "create_table"
10
- CREATE_TABLE_CONFIRM = "create_table_confirm"
11
  EDIT_TABLE = "edit_table"
12
  SQL_QUERY = "sql_query"
13
- CLARIFICATION = "clarification"
14
  UNKNOWN = "unknown"
15
 
16
 
@@ -21,29 +18,17 @@ class IntentResult:
21
  reason: str
22
 
23
 
24
- def _has_pending_schema(state):
25
- return bool(getattr(state, "pending_schema_suggestion", None))
26
-
27
-
28
  def _has_active_schema(state):
29
  return bool((getattr(state, "active_schema", "") or "").strip())
30
 
31
 
32
- def _is_confirmation(message):
33
- normalized = sql_tools.normalize_text(message)
34
- confirmations = {
35
- "sim", "yes", "ok", "claro", "pode", "pode gerar", "gera", "gerar",
36
- "gere", "faz", "faca", "cria", "crie", "manda", "confirmo", "isso",
37
- "isso mesmo", "perfeito",
38
- }
39
- return normalized in confirmations or normalized.startswith(("gera ", "pode gerar", "faz "))
40
-
41
-
42
  def _is_smalltalk(message):
43
  normalized = sql_tools.normalize_text(message)
44
  exact = {
45
  "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
46
  "obrigado", "obrigada", "valeu", "thanks", "thank you",
 
 
47
  "como voce esta", "como voce esta hoje", "qual seu nome",
48
  "me conte uma piada", "conte uma piada", "vamos conversar",
49
  "o que voce faz", "como voce funciona", "como funciona",
@@ -56,60 +41,33 @@ def _is_smalltalk(message):
56
  "conte uma piada",
57
  "vamos conversar",
58
  "obrigado",
 
 
59
  )
60
  return any(fragment in normalized for fragment in smalltalk_fragments)
61
 
62
 
63
- def _is_schema_suggestion(message):
64
- normalized = sql_tools.normalize_text(message)
65
- patterns = (
66
- "preciso de uma tabela",
67
- "preciso de um schema",
68
- "quero uma tabela",
69
- "quero um schema",
70
- "sugira uma tabela",
71
- "sugerir uma tabela",
72
- "tabela sobre",
73
- "tabela de",
74
- "schema sobre",
75
- "schema de",
76
- "modelo de tabela",
77
- "modelar",
78
- )
79
- if any(pattern in normalized for pattern in patterns) and not sql_tools.is_create_table_intent(message):
80
- return True
81
- if "tabela" in normalized and any(term in normalized for term in ("sobre", "para", "de")):
82
- return not any(term in normalized for term in ("crie", "criar", "create", "generate", "gerar", "gere"))
83
- return False
84
-
85
-
86
  def classify_intent(message, state=None, chat_history=None):
87
  state = ConversationState.from_value(state)
88
  normalized = sql_tools.normalize_text(message)
89
  if not normalized:
90
  return IntentResult(UNKNOWN, 0.0, "empty_message")
91
 
92
- if _has_pending_schema(state) and _is_confirmation(message):
93
- return IntentResult(CREATE_TABLE_CONFIRM, 0.95, "confirmation_with_pending_schema")
94
-
95
  if _is_smalltalk(message):
96
  return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")
97
 
98
  edited_table = sql_tools.edit_create_table_from_message(message, chat_history, state.active_schema)
99
  if edited_table or sql_tools.is_table_edit_intent(message):
100
- return IntentResult(EDIT_TABLE, 0.9 if (edited_table or _has_active_schema(state) or _has_pending_schema(state)) else 0.7, "table_edit_terms")
 
 
 
 
101
 
102
  if sql_tools.is_create_table_intent(message):
103
  return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")
104
 
105
- if _is_schema_suggestion(message):
106
- return IntentResult(SCHEMA_SUGGESTION, 0.86, "schema_suggestion_phrase")
107
-
108
  if sql_tools.is_sql_intent(message, state.active_schema):
109
  return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")
110
 
111
- if _has_pending_schema(state):
112
- return IntentResult(CLARIFICATION, 0.55, "pending_schema_context")
113
-
114
  return IntentResult(UNKNOWN, 0.25, "no_intent_match")
115
-
 
5
 
6
 
7
  SMALLTALK = "smalltalk"
 
8
  CREATE_TABLE = "create_table"
 
9
  EDIT_TABLE = "edit_table"
10
  SQL_QUERY = "sql_query"
 
11
  UNKNOWN = "unknown"
12
 
13
 
 
18
  reason: str
19
 
20
 
 
 
 
 
21
  def _has_active_schema(state):
22
  return bool((getattr(state, "active_schema", "") or "").strip())
23
 
24
 
 
 
 
 
 
 
 
 
 
 
25
  def _is_smalltalk(message):
26
  normalized = sql_tools.normalize_text(message)
27
  exact = {
28
  "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
29
  "obrigado", "obrigada", "valeu", "thanks", "thank you",
30
+ "tudo bem", "tudo bom", "tudo", "tchau", "xau", "ate mais", "ate logo",
31
+ "de nada", "por nada", "imagina",
32
  "como voce esta", "como voce esta hoje", "qual seu nome",
33
  "me conte uma piada", "conte uma piada", "vamos conversar",
34
  "o que voce faz", "como voce funciona", "como funciona",
 
41
  "conte uma piada",
42
  "vamos conversar",
43
  "obrigado",
44
+ "tudo bem",
45
+ "tudo bom",
46
  )
47
  return any(fragment in normalized for fragment in smalltalk_fragments)
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def classify_intent(message, state=None, chat_history=None):
51
  state = ConversationState.from_value(state)
52
  normalized = sql_tools.normalize_text(message)
53
  if not normalized:
54
  return IntentResult(UNKNOWN, 0.0, "empty_message")
55
 
 
 
 
56
  if _is_smalltalk(message):
57
  return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")
58
 
59
  edited_table = sql_tools.edit_create_table_from_message(message, chat_history, state.active_schema)
60
  if edited_table or sql_tools.is_table_edit_intent(message):
61
+ return IntentResult(
62
+ EDIT_TABLE,
63
+ 0.9 if (edited_table or _has_active_schema(state)) else 0.7,
64
+ "table_edit_terms",
65
+ )
66
 
67
  if sql_tools.is_create_table_intent(message):
68
  return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")
69
 
 
 
 
70
  if sql_tools.is_sql_intent(message, state.active_schema):
71
  return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")
72
 
 
 
 
73
  return IntentResult(UNKNOWN, 0.25, "no_intent_match")
 
model_io.py CHANGED
@@ -1,41 +1,12 @@
1
- import json
2
- import re
3
-
4
- from chat_state import ConversationState, SchemaSuggestion
5
  import sql_tools
6
 
7
 
8
- CHAT_GENERATION = "chat"
9
- SCHEMA_GENERATION = "schema"
10
  SQL_GENERATION = "sql"
11
 
12
  GENERATION_BUDGETS = {
13
- CHAT_GENERATION: 120,
14
- SCHEMA_GENERATION: 180,
15
  SQL_GENERATION: 96,
16
  }
17
 
18
- CHAT_PROMPT_TEMPLATE = (
19
- "<|user|>\n"
20
- "You are a conversational SQL assistant. Reply naturally in Brazilian Portuguese unless the user writes in English.\n"
21
- "You can chat normally, discuss table ideas, and help generate SQL, but do not generate SQL unless the user asks for it.\n"
22
- "Current state:\n{state_summary}\n\n"
23
- "{history_context}"
24
- "User message: {message}<|end|>\n"
25
- "<|assistant|>"
26
- )
27
-
28
- SCHEMA_SUGGESTION_PROMPT_TEMPLATE = (
29
- "<|user|>\n"
30
- "Create a practical SQL table proposal for the user's domain request.\n"
31
- "Return JSON only with this shape: "
32
- '{{"table_name":"name","columns":[{{"name":"id","type":"INTEGER"}}],"rationale":"short reason"}}.\n'
33
- "Use simple SQL types: INTEGER, TEXT, NUMERIC, DATE, BOOLEAN.\n"
34
- "{history_context}"
35
- "Request: {message}<|end|>\n"
36
- "<|assistant|>"
37
- )
38
-
39
  SQL_PROMPT_TEMPLATE = (
40
  "<|user|>\n"
41
  "Given the following SQL table, write one SQL query. Output SQL only.\n\n"
@@ -63,28 +34,6 @@ def _history_context(chat_history, max_exchanges=3):
63
  return "Previous conversation:\n" + "\n".join(lines) + "\n\n"
64
 
65
 
66
- def _state_summary(state):
67
- state = ConversationState.from_value(state)
68
- pending = state.pending_schema_suggestion.table_name if state.pending_schema_suggestion else "none"
69
- active = "present" if state.active_schema else "none"
70
- return f"- active_schema: {active}\n- pending_schema_suggestion: {pending}"
71
-
72
-
73
- def build_chat_prompt(message, state=None, chat_history=None):
74
- return CHAT_PROMPT_TEMPLATE.format(
75
- message=(message or "").strip(),
76
- state_summary=_state_summary(state),
77
- history_context=_history_context(chat_history),
78
- )
79
-
80
-
81
- def build_schema_suggestion_prompt(message, state=None, chat_history=None):
82
- return SCHEMA_SUGGESTION_PROMPT_TEMPLATE.format(
83
- message=(message or "").strip(),
84
- history_context=_history_context(chat_history),
85
- )
86
-
87
-
88
  def build_sql_prompt(schema, message, chat_history=None):
89
  table_schema = (schema or "").strip() or "CREATE TABLE unknown (id INTEGER)"
90
  return SQL_PROMPT_TEMPLATE.format(
@@ -119,14 +68,3 @@ def format_generation_result(text):
119
  if is_sql_like(cleaned):
120
  return str(cleaned), "", sql_tools.validate_sql(cleaned)
121
  return "", str(cleaned), '<span class="validator-badge validator-empty">Chat response</span>'
122
-
123
-
124
- def parse_schema_suggestion(text):
125
- cleaned = clean_generation(text)
126
- match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
127
- raw_json = match.group(0) if match else cleaned
128
- try:
129
- payload = json.loads(raw_json)
130
- except json.JSONDecodeError:
131
- return None
132
- return SchemaSuggestion.from_value(payload)
 
 
 
 
 
1
  import sql_tools
2
 
3
 
 
 
4
  SQL_GENERATION = "sql"
5
 
6
  GENERATION_BUDGETS = {
 
 
7
  SQL_GENERATION: 96,
8
  }
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  SQL_PROMPT_TEMPLATE = (
11
  "<|user|>\n"
12
  "Given the following SQL table, write one SQL query. Output SQL only.\n\n"
 
34
  return "Previous conversation:\n" + "\n".join(lines) + "\n\n"
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def build_sql_prompt(schema, message, chat_history=None):
38
  table_schema = (schema or "").strip() or "CREATE TABLE unknown (id INTEGER)"
39
  return SQL_PROMPT_TEMPLATE.format(
 
68
  if is_sql_like(cleaned):
69
  return str(cleaned), "", sql_tools.validate_sql(cleaned)
70
  return "", str(cleaned), '<span class="validator-badge validator-empty">Chat response</span>'
 
 
 
 
 
 
 
 
 
 
 
scripts/model_probe.py CHANGED
@@ -20,7 +20,6 @@ def _scenario(name, message, history, active_schema, state):
20
  history,
21
  active_schema,
22
  app.FINE_TUNED_MODEL_KEY,
23
- None,
24
  state,
25
  )
26
  return {
@@ -28,9 +27,9 @@ def _scenario(name, message, history, active_schema, state):
28
  "message": message,
29
  "assistant": _assistant_text(result),
30
  "sql": result[4],
31
- "status": result[7],
32
  "active_schema": result[2],
33
- "state": result[8],
34
  "history": result[0],
35
  }
36
 
@@ -45,19 +44,14 @@ def _grade(records):
45
  by_name = {record["name"]: record for record in records}
46
 
47
  checks.append({
48
- "name": "smalltalk_is_conversational",
49
- "pass": bool(by_name["greeting"]["assistant"]) and not by_name["greeting"]["sql"],
50
- "detail": "Greeting should produce chat text and no SQL.",
51
  })
52
  checks.append({
53
- "name": "schema_suggestion_sets_pending",
54
- "pass": bool((by_name["schema_request"]["state"] or {}).get("pending_schema_suggestion")),
55
- "detail": "Domain table request should create a pending schema proposal.",
56
- })
57
- checks.append({
58
- "name": "confirmation_generates_create_table",
59
- "pass": "CREATE TABLE" in (by_name["confirm_generate"]["sql"] or "").upper(),
60
- "detail": "Confirmation should generate CREATE TABLE SQL.",
61
  })
62
  checks.append({
63
  "name": "edit_updates_schema",
@@ -71,8 +65,8 @@ def _grade(records):
71
  })
72
  checks.append({
73
  "name": "smalltalk_with_schema_stays_chat",
74
- "pass": bool(by_name["smalltalk_with_schema"]["assistant"]) and not by_name["smalltalk_with_schema"]["sql"],
75
- "detail": "Smalltalk with active schema should not become SQL.",
76
  })
77
  return checks
78
 
@@ -87,8 +81,7 @@ def main():
87
 
88
  for name, message in [
89
  ("greeting", "oi"),
90
- ("schema_request", "preciso de uma tabela sobre zoologico"),
91
- ("confirm_generate", "gera"),
92
  ("edit_schema", "troca capacidade por numero_animais"),
93
  ("query_schema", "liste zoologicos de Sao Paulo"),
94
  ("smalltalk_with_schema", "como voce esta hoje?"),
@@ -112,4 +105,3 @@ def main():
112
 
113
  if __name__ == "__main__":
114
  raise SystemExit(main())
115
-
 
20
  history,
21
  active_schema,
22
  app.FINE_TUNED_MODEL_KEY,
 
23
  state,
24
  )
25
  return {
 
27
  "message": message,
28
  "assistant": _assistant_text(result),
29
  "sql": result[4],
30
+ "status": result[6],
31
  "active_schema": result[2],
32
+ "state": result[7],
33
  "history": result[0],
34
  }
35
 
 
44
  by_name = {record["name"]: record for record in records}
45
 
46
  checks.append({
47
+ "name": "smalltalk_is_static_fallback",
48
+ "pass": app.FALLBACK_RESPONSE in by_name["greeting"]["assistant"] and not by_name["greeting"]["sql"],
49
+ "detail": "Greeting should use the static fallback and no SQL.",
50
  })
51
  checks.append({
52
+ "name": "create_table_is_deterministic",
53
+ "pass": "CREATE TABLE ZOOLOGICO" in (by_name["create_schema"]["sql"] or "").upper(),
54
+ "detail": "Explicit CREATE TABLE request should not need model generation.",
 
 
 
 
 
55
  })
56
  checks.append({
57
  "name": "edit_updates_schema",
 
65
  })
66
  checks.append({
67
  "name": "smalltalk_with_schema_stays_chat",
68
+ "pass": app.FALLBACK_RESPONSE in by_name["smalltalk_with_schema"]["assistant"] and not by_name["smalltalk_with_schema"]["sql"],
69
+ "detail": "Smalltalk with active schema should still use static fallback.",
70
  })
71
  return checks
72
 
 
81
 
82
  for name, message in [
83
  ("greeting", "oi"),
84
+ ("create_schema", "crie tabela zoologico com id nome cidade capacidade"),
 
85
  ("edit_schema", "troca capacidade por numero_animais"),
86
  ("query_schema", "liste zoologicos de Sao Paulo"),
87
  ("smalltalk_with_schema", "como voce esta hoje?"),
 
105
 
106
  if __name__ == "__main__":
107
  raise SystemExit(main())
 
sql_tools.py CHANGED
@@ -118,32 +118,37 @@ def validate_sql(sql_text):
118
  return '<span class="validator-badge validator-ok">Valid SQL</span>'
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def is_create_table_intent(message):
122
  message = (message or "").strip().lower()
123
  return bool(
124
- re.search(r"\b(create|make|build|generate|criar|crie|cria|criando|gerar|gere|gera|gerando|faz|faça|fazendo|monta|montar|monte)\b", message)
125
  and re.search(r"\b(table|schema|tabela)\b", message)
126
  )
127
 
128
 
129
  def is_rename_intent(message):
130
- message = (message or "").strip().lower()
131
- return bool(
132
- re.search(
133
- r"\b(rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)\s+\w+\s+(to|para|as|como|por)\s+\w+",
134
- message,
135
- flags=re.IGNORECASE,
136
- )
137
- )
138
 
139
 
140
  def is_table_edit_intent(message):
141
  message = (message or "").strip().lower()
142
- edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|coloque|colocar)\b"
143
- direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar)\b"
144
- direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
145
  target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
146
- sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}
147
  add_match = re.search(direct_add_terms, message)
148
  if add_match:
149
  after_add = message[add_match.start() + len(add_match.group()) :].strip()
@@ -278,8 +283,8 @@ def format_create_table(table_name, columns):
278
  def create_table_from_message(message):
279
  message = (message or "").strip()
280
  patterns = (
281
- r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
282
- r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
283
  )
284
  for pattern in patterns:
285
  match = re.search(pattern, message, flags=re.IGNORECASE)
@@ -361,7 +366,7 @@ def extract_added_columns(message):
361
  def extract_removed_columns(message):
362
  message = (message or "").strip()
363
  patterns = (
364
- r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
365
  )
366
  for pattern in patterns:
367
  match = re.search(pattern, message, flags=re.IGNORECASE)
@@ -376,15 +381,20 @@ def extract_removed_columns(message):
376
 
377
  def extract_renamed_columns(message):
378
  pattern = (
379
- r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|mude)\s+"
380
  r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
381
  )
382
  matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
383
  troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE)
 
 
384
  return [
385
  (normalize_identifier(old), normalize_identifier(new))
386
  for old, new in [*matches, *troca_matches]
387
- if normalize_identifier(old) and normalize_identifier(new)
 
 
 
388
  ]
389
 
390
 
@@ -392,8 +402,12 @@ def parse_compound_edit(message):
392
  segment_pattern = (
393
  r"\s+(?:and|e)\s+"
394
  r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
395
- r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
396
- r"exclua|renomeie|renomear|renomeia|altere|mude|troca|trocar)\b)"
 
 
 
 
397
  )
398
  segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
399
  added, removed, renamed = [], [], []
@@ -451,4 +465,3 @@ def create_table_from_suggestion(suggestion):
451
  if identifier:
452
  parsed.append((identifier, (column_type or "TEXT").upper()))
453
  return format_create_table(normalize_identifier(table_name), parsed)
454
-
 
118
  return '<span class="validator-badge validator-ok">Valid SQL</span>'
119
 
120
 
121
+ # Verb stems for create-table intent — keep in sync with create_table_from_message below.
122
+ _CREATE_VERBS = (
123
+ r"create|make|build|generate"
124
+ r"|criar|crie|cria|criando"
125
+ r"|gerar|gere|gera|gerando"
126
+ r"|faz|faca|fa\u00e7a|fazendo"
127
+ r"|monta|montar|monte"
128
+ r"|construa|construir|constroi"
129
+ r"|elabore|elaborar|elabora"
130
+ )
131
+
132
+
133
  def is_create_table_intent(message):
134
  message = (message or "").strip().lower()
135
  return bool(
136
+ re.search(rf"\b({_CREATE_VERBS})\b", message)
137
  and re.search(r"\b(table|schema|tabela)\b", message)
138
  )
139
 
140
 
141
  def is_rename_intent(message):
142
+ return bool(extract_renamed_columns(message))
 
 
 
 
 
 
 
143
 
144
 
145
  def is_table_edit_intent(message):
146
  message = (message or "").strip().lower()
147
+ edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|exclui|novo|nova|troca|trocar|coloque|colocar|coloca|insira|insere|bota|tira|retire|retira|apaga|apague)\b"
148
+ direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar|coloca|acrescenta|insere|inserir|insira|bota|botar|bote)\b"
149
+ direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
150
  target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
151
+ sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by", "soma", "media", "contagem", "maximo", "minimo"}
152
  add_match = re.search(direct_add_terms, message)
153
  if add_match:
154
  after_add = message[add_match.start() + len(add_match.group()) :].strip()
 
283
  def create_table_from_message(message):
284
  message = (message or "").strip()
285
  patterns = (
286
+ r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
287
+ rf"\b(?:{_CREATE_VERBS})\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
288
  )
289
  for pattern in patterns:
290
  match = re.search(pattern, message, flags=re.IGNORECASE)
 
366
  def extract_removed_columns(message):
367
  message = (message or "").strip()
368
  patterns = (
369
+ r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
370
  )
371
  for pattern in patterns:
372
  match = re.search(pattern, message, flags=re.IGNORECASE)
 
381
 
382
  def extract_renamed_columns(message):
383
  pattern = (
384
+ r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|alterar|altera|mude|mudar|muda|edita|editar)\s+"
385
  r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
386
  )
387
  matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
388
  troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE)
389
+ invalid_old_names = {"ela", "ele", "isso", "isto", "essa", "esse", "this", "it"}
390
+ invalid_new_names = {"ter", "have", "having", "tambem", "tambem"}
391
  return [
392
  (normalize_identifier(old), normalize_identifier(new))
393
  for old, new in [*matches, *troca_matches]
394
+ if normalize_identifier(old)
395
+ and normalize_identifier(new)
396
+ and normalize_identifier(old) not in invalid_old_names
397
+ and normalize_identifier(new) not in invalid_new_names
398
  ]
399
 
400
 
 
402
  segment_pattern = (
403
  r"\s+(?:and|e)\s+"
404
  r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
405
+ r"adicione|adicionar|adicionando|inclua|incluir|acrescente|"
406
+ r"coloca|coloque|bota|insira|insere|"
407
+ r"remova|remover|deletar|exclua|excluir|exclui|"
408
+ r"tira|tirar|tire|retira|retire|apaga|apague|"
409
+ r"renomeie|renomear|renomeia|altere|alterar|altera|"
410
+ r"mude|mudar|muda|edita|editar|troca|trocar)\b)"
411
  )
412
  segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
413
  added, removed, renamed = [], [], []
 
465
  if identifier:
466
  parsed.append((identifier, (column_type or "TEXT").upper()))
467
  return format_create_table(normalize_identifier(table_name), parsed)
 
tests/e2e_flow_test.py CHANGED
@@ -17,7 +17,7 @@ def sql_out(result):
17
  return result[4]
18
 
19
  def status(result):
20
- return result[7]
21
 
22
  def reset_model_state():
23
  app._model = None
@@ -247,4 +247,4 @@ if __name__ == "__main__":
247
  print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
248
  print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
249
  else:
250
- run_all()
 
17
  return result[4]
18
 
19
  def status(result):
20
+ return result[6]
21
 
22
  def reset_model_state():
23
  app._model = None
 
247
  print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
248
  print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
249
  else:
250
+ run_all()
tests/test_chatbot_behavior.py CHANGED
@@ -24,7 +24,7 @@ def sql_output(result):
24
 
25
 
26
  def status_html(result):
27
- return result[7]
28
 
29
 
30
  @pytest.fixture(autouse=True)
@@ -436,7 +436,7 @@ def test_model_id_mismatch_returns_inconsistency_error():
436
  generation_config=types.SimpleNamespace(eos_token_id=0)
437
  )
438
  app._tokenizer = object()
439
- app._current_model_id = app.BASE_MODEL_ID
440
 
441
  try:
442
  result = app.generate_response(
@@ -493,7 +493,7 @@ def test_off_topic_message_returns_fallback(monkeypatch):
493
  result = app.generate_response("me conte uma piada", [], "", None, None)
494
 
495
  assert sql_output(result) == ""
496
- assert "load the fine-tuned model" in assistant_text(result).lower()
497
 
498
 
499
  def test_greeting_returns_fallback(monkeypatch):
@@ -502,6 +502,7 @@ def test_greeting_returns_fallback(monkeypatch):
502
  result = app.generate_response("oi", [], "", None, None)
503
 
504
  assert sql_output(result) == ""
 
505
 
506
 
507
  # ---------------------------------------------------------------------------
 
24
 
25
 
26
  def status_html(result):
27
+ return result[6]
28
 
29
 
30
  @pytest.fixture(autouse=True)
 
436
  generation_config=types.SimpleNamespace(eos_token_id=0)
437
  )
438
  app._tokenizer = object()
439
+ app._current_model_id = "microsoft/Phi-3-mini-4k-instruct"
440
 
441
  try:
442
  result = app.generate_response(
 
493
  result = app.generate_response("me conte uma piada", [], "", None, None)
494
 
495
  assert sql_output(result) == ""
496
+ assert app.FALLBACK_RESPONSE in assistant_text(result)
497
 
498
 
499
  def test_greeting_returns_fallback(monkeypatch):
 
502
  result = app.generate_response("oi", [], "", None, None)
503
 
504
  assert sql_output(result) == ""
505
+ assert app.FALLBACK_RESPONSE in assistant_text(result)
506
 
507
 
508
  # ---------------------------------------------------------------------------
tests/test_chatbot_core.py CHANGED
@@ -1,32 +1,22 @@
1
  import types
2
 
3
  import app
4
- from chat_state import ConversationState, SchemaSuggestion
5
- from intent import (
6
- CREATE_TABLE_CONFIRM,
7
- EDIT_TABLE,
8
- SCHEMA_SUGGESTION,
9
- SMALLTALK,
10
- SQL_QUERY,
11
- classify_intent,
12
- )
13
 
14
 
15
  def test_conversation_state_roundtrip_dict():
16
- suggestion = SchemaSuggestion(
17
- table_name="zoologico",
18
- columns=(("id", "INTEGER"), ("nome", "TEXT")),
19
- rationale="base",
20
  )
21
- state = ConversationState(active_schema="", pending_schema_suggestion=suggestion, last_intent=SCHEMA_SUGGESTION)
22
 
23
  restored = ConversationState.from_value(state.to_dict())
24
 
25
- pending = restored.pending_schema_suggestion
26
- assert pending is not None
27
- assert pending.table_name == "zoologico"
28
- assert pending.columns == (("id", "INTEGER"), ("nome", "TEXT"))
29
- assert restored.last_intent == SCHEMA_SUGGESTION
30
 
31
 
32
  def test_intent_smalltalk_with_active_schema_is_not_sql():
@@ -37,69 +27,66 @@ def test_intent_smalltalk_with_active_schema_is_not_sql():
37
  assert result.intent == SMALLTALK
38
 
39
 
40
- def test_intent_schema_suggestion_and_confirmation():
41
- state = ConversationState()
42
 
43
- suggestion = classify_intent("preciso de uma tabela sobre zoologico", state)
44
- pending = state.with_pending_schema(
45
- SchemaSuggestion(table_name="zoologico", columns=(("id", "INTEGER"), ("nome", "TEXT")))
46
- )
47
- confirmation = classify_intent("gera", pending)
48
-
49
- assert suggestion.intent == SCHEMA_SUGGESTION
50
- assert confirmation.intent == CREATE_TABLE_CONFIRM
51
 
52
 
53
- def test_intent_edit_and_sql_query():
54
- state = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")
 
55
 
56
- edit = classify_intent("troca cidade por municipio", state)
57
- query = classify_intent("liste zoologicos por municipio", state)
 
58
 
 
59
  assert edit.intent == EDIT_TABLE
60
  assert query.intent == SQL_QUERY
61
 
62
 
63
- def test_zoologico_transcript_with_mocked_model(monkeypatch):
64
  app._model = types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0))
65
  app._tokenizer = types.SimpleNamespace(eos_token_id=0, pad_token_id=0)
66
  app._current_model_id = app.FINE_TUNED_MODEL_ID
67
 
68
  def fake_generate(prompt, generation_kind):
69
- if generation_kind == app.model_core.CHAT_GENERATION:
70
- return "Oi, posso ajudar com conversa comum ou SQL.", 1
71
- if generation_kind == app.model_core.SCHEMA_GENERATION:
72
- return (
73
- '{"table_name":"zoologico","columns":['
74
- '{"name":"id","type":"INTEGER"},'
75
- '{"name":"nome","type":"TEXT"},'
76
- '{"name":"cidade","type":"TEXT"},'
77
- '{"name":"capacidade","type":"INTEGER"}],'
78
- '"rationale":"Tabela inicial para zoologicos."}',
79
- 1,
80
- )
81
  return "SELECT * FROM zoologico WHERE cidade = 'Sao Paulo';", 1
82
 
83
  monkeypatch.setattr(app, "_generate_model_text", fake_generate)
84
 
85
- r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY, None)
86
- assert app.EMPTY_CHAT_OUTPUT == ""
87
  assert r1[4] == ""
88
- assert "Oi" in r1[0][-1]["content"]
89
-
90
- r2 = app.generate_response("preciso de uma tabela sobre zoologico", r1[0], r1[2], app.FINE_TUNED_MODEL_KEY, None, r1[8])
91
- assert r2[4] == ""
92
- assert r2[8]["pending_schema_suggestion"] is not None
93
- assert r2[8]["pending_schema_suggestion"]["table_name"] == "zoologico"
94
-
95
- r3 = app.generate_response("gera", r2[0], r2[2], app.FINE_TUNED_MODEL_KEY, None, r2[8])
96
- assert "CREATE TABLE zoologico" in r3[4]
97
- assert "CREATE TABLE zoologico" in r3[2]
98
- assert r3[8]["pending_schema_suggestion"] is None
99
-
100
- r4 = app.generate_response("troca capacidade por numero_animais", r3[0], r3[2], app.FINE_TUNED_MODEL_KEY, None, r3[8])
101
- assert "numero_animais INTEGER" in r4[4]
102
- assert "capacidade" not in r4[4]
103
-
104
- r5 = app.generate_response("liste zoologicos de Sao Paulo", r4[0], r4[2], app.FINE_TUNED_MODEL_KEY, None, r4[8])
105
- assert "SELECT * FROM zoologico" in r5[4]
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import types
2
 
3
  import app
4
+ from chat_state import ConversationState
5
+ from intent import CREATE_TABLE, EDIT_TABLE, SMALLTALK, SQL_QUERY, UNKNOWN, classify_intent
 
 
 
 
 
 
 
6
 
7
 
8
  def test_conversation_state_roundtrip_dict():
9
+ state = ConversationState(
10
+ active_schema="CREATE TABLE zoologico (id INTEGER)",
11
+ last_intent=SQL_QUERY,
12
+ debug={"intent": SQL_QUERY, "confidence": 0.86, "reason": "sql_query_terms"},
13
  )
 
14
 
15
  restored = ConversationState.from_value(state.to_dict())
16
 
17
+ assert restored.active_schema == "CREATE TABLE zoologico (id INTEGER)"
18
+ assert restored.last_intent == SQL_QUERY
19
+ assert restored.debug["reason"] == "sql_query_terms"
 
 
20
 
21
 
22
  def test_intent_smalltalk_with_active_schema_is_not_sql():
 
27
  assert result.intent == SMALLTALK
28
 
29
 
30
+ def test_schema_request_without_columns_is_unknown_not_model_schema_task():
31
+ result = classify_intent("preciso de uma tabela sobre zoologico", ConversationState())
32
 
33
+ assert result.intent == UNKNOWN
 
 
 
 
 
 
 
34
 
35
 
36
+ def test_intent_create_edit_and_sql_query():
37
+ empty_state = ConversationState()
38
+ schema_state = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")
39
 
40
+ create = classify_intent("crie tabela zoologico com id nome cidade", empty_state)
41
+ edit = classify_intent("troca cidade por municipio", schema_state)
42
+ query = classify_intent("liste zoologicos por municipio", schema_state)
43
 
44
+ assert create.intent == CREATE_TABLE
45
  assert edit.intent == EDIT_TABLE
46
  assert query.intent == SQL_QUERY
47
 
48
 
49
+ def test_zoologico_transcript_with_mocked_sql_model(monkeypatch):
50
  app._model = types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0))
51
  app._tokenizer = types.SimpleNamespace(eos_token_id=0, pad_token_id=0)
52
  app._current_model_id = app.FINE_TUNED_MODEL_ID
53
 
54
  def fake_generate(prompt, generation_kind):
55
+ assert generation_kind == app.model_core.SQL_GENERATION
56
+ assert "CREATE TABLE zoologico" in prompt
 
 
 
 
 
 
 
 
 
 
57
  return "SELECT * FROM zoologico WHERE cidade = 'Sao Paulo';", 1
58
 
59
  monkeypatch.setattr(app, "_generate_model_text", fake_generate)
60
 
61
+ r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY)
 
62
  assert r1[4] == ""
63
+ assert app.FALLBACK_RESPONSE in r1[0][-1]["content"]
64
+
65
+ r2 = app.generate_response(
66
+ "crie tabela zoologico com id nome cidade capacidade",
67
+ r1[0],
68
+ r1[2],
69
+ app.FINE_TUNED_MODEL_KEY,
70
+ r1[7],
71
+ )
72
+ assert "CREATE TABLE zoologico" in r2[4]
73
+ assert "CREATE TABLE zoologico" in r2[2]
74
+
75
+ r3 = app.generate_response(
76
+ "troca capacidade por numero_animais",
77
+ r2[0],
78
+ r2[2],
79
+ app.FINE_TUNED_MODEL_KEY,
80
+ r2[7],
81
+ )
82
+ assert "numero_animais TEXT" in r3[4]
83
+ assert "capacidade" not in r3[4]
84
+
85
+ r4 = app.generate_response(
86
+ "liste zoologicos de Sao Paulo",
87
+ r3[0],
88
+ r3[2],
89
+ app.FINE_TUNED_MODEL_KEY,
90
+ r3[7],
91
+ )
92
+ assert "SELECT * FROM zoologico" in r4[4]