Shizu0n committed on
Commit bc39556 · 1 Parent(s): 03cc0b0

chore: double call fix and build_generation_prompt history injection bug fix

Files changed (6)
  1. .gitignore +5 -0
  2. README.md +12 -16
  3. app.py +1094 -122
  4. requirements.txt +1 -1
  5. tests/e2e_flow_test.py +250 -0
  6. tests/test_chatbot_behavior.py +672 -0
.gitignore CHANGED
@@ -48,3 +48,8 @@ logs/
  # AI-generated code artifacts
  *.gen.py
  .claude
+
+ # Local agent/workspace notes
+ /AGENTS.md
+ /CLAUDE.md
+ /PROGRESS.md
README.md CHANGED
@@ -13,34 +13,30 @@ short_description: "SQL generator powered by Phi-3 Mini fine-tuning"

  # Phi-3 Mini SQL Generator

- Generates SQL queries from a table schema and a natural-language question, comparing the base Phi-3 Mini model with a fine-tuned text-to-SQL version.
+ Generates SQL queries from a table schema and a natural-language question using a QLoRA fine-tuned Phi-3 Mini model.

  ## What the App Does

- Transforms simple table descriptions and questions into SQL using Phi-3 Mini, with a choice between the base model and a QLoRA fine-tuned model.
+ Transforms simple table descriptions and questions into SQL using the fine-tuned Phi-3 Mini model. The base model is shown as offline evaluation evidence instead of a second live CPU-loaded model.

  ## How to Use

- 1. Select a model by clicking the card or the selection button:
- - **Base Phi-3 Mini**: the non-fine-tuned baseline.
- - **Fine-tuned QLoRA model**: the main model, selected by default.
- 2. Click **Load selected model**.
+ 1. Click **Load fine-tuned model**.
  - Loading is lazy: the model is only downloaded and loaded when you request it.
  - On CPU, the first load can take a few minutes.
- 3. Enter or edit the **SQL table schema**.
+ 2. Enter or edit the **SQL table schema**.
  - You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
  - You can also write your own schema manually.
- 4. Enter the question in the **Question** field.
- 5. Click **Generate SQL**.
- 6. Review the result in `gr.Code(language="sql")`.
+ 3. Enter the question in the chat input.
+ 4. Click **Send**.
+ 5. Review the result in `gr.Code(language="sql")`.
  - The app shows a validation badge powered by `sqlparse`.
- 7. Optional: click **Save for comparison** to compare the saved query with the current query.

  ## Models

  - Fine-tuned adapter: [Shizu0n/phi3-mini-sql-generator](https://huggingface.co/Shizu0n/phi3-mini-sql-generator)
  - Fine-tuned merged model used in the app: [Shizu0n/phi3-mini-sql-generator-merged](https://huggingface.co/Shizu0n/phi3-mini-sql-generator-merged)
- - Base comparison model: [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+ - Offline baseline model used for evaluation: [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)

  ## Metrics

@@ -53,10 +49,10 @@ Reported gain: **+71.5 percentage points** over the base model.

  ## Current Features

- - Gradio UI with a step-by-step flow: select model, load, enter schema/question, generate SQL, and compare outputs.
- - Clickable model cards in addition to the selection buttons.
- - Lazy loading with unloading of the previous model to reduce memory use.
- - Preserved Phi-3 patches: `rope_scaling`, `use_cache=False`, and `trust_remote_code`.
+ - Gradio UI with a step-by-step flow: load the fine-tuned model, enter schema/question, and generate SQL.
+ - Offline baseline metrics shown in the UI without loading a second 3.8B model on the CPU Space.
+ - Lazy loading to reduce startup cost.
+ - Preserved Phi-3 patches for local/Spaces compatibility.
  - Schema presets without blocking manual input.
  - SQL output separated from errors/status so booleans, integers, and error messages do not appear inside the SQL block.
  - Centered loading overlay to make the loading state obvious.
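The README's "lazy loading" point (the model is only fetched when the user asks for it) can be illustrated with a small stand-alone sketch. It is not the app's code: `LazyResource` and its `factory` argument are hypothetical names, and the factory stands in for the expensive model load.

```python
import threading


class LazyResource:
    """Hypothetical wrapper: build an expensive object on first use only."""

    def __init__(self, factory):
        self._factory = factory      # callable that performs the expensive load
        self._value = None
        self._loaded = False
        self._lock = threading.Lock()

    def get(self):
        # The lock keeps two concurrent callers from loading twice.
        with self._lock:
            if not self._loaded:
                self._value = self._factory()
                self._loaded = True
            return self._value
```

Nothing is loaded at construction time, so startup stays cheap; the first `get()` pays the full cost and later calls reuse the cached object.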
app.py CHANGED
@@ -1,9 +1,13 @@
 
1
  import gc
2
  import html
3
  import inspect
 
4
  import re
5
  import threading
6
  import time
 
 
7
 
8
  import gradio as gr
9
  import sqlparse
@@ -24,7 +28,7 @@ MODEL_CATALOG = {
24
  "title": "Phi-3 Mini base",
25
  "model_id": BASE_MODEL_ID,
26
  "exact_match": "2.0%",
27
- "trust_remote_code": True,
28
  "ready_text": "Base model ready",
29
  "metadata": (
30
  "Model: microsoft/Phi-3-mini-4k-instruct\n"
@@ -50,17 +54,12 @@ MODEL_CATALOG = {
50
  },
51
  }
52
 
53
- MODEL_OPTIONS = {
54
- MODEL_CATALOG[FINE_TUNED_MODEL_KEY]["label"]: FINE_TUNED_MODEL_ID,
55
- MODEL_CATALOG[BASE_MODEL_KEY]["label"]: BASE_MODEL_ID,
56
- }
57
-
58
  PRESETS = {
59
- "employees": "employees (id, name, department, salary)",
60
- "orders": "orders (id, customer_id, product, amount, date)",
61
- "students": "students (id, name, course, grade, year)",
62
- "products": "products (id, name, category, price, stock)",
63
- "sales": "sales (id, product_id, quantity, total, date)",
64
  }
65
 
66
  PROMPT_TEMPLATE = (
@@ -73,15 +72,18 @@ PROMPT_TEMPLATE = (
73
 
74
  GENERAL_PROMPT_TEMPLATE = (
75
  "<|user|>\n"
76
- "You are Phi-3 Mini in a SQL generator demo. Reply naturally and briefly. "
77
- "If the user asks for SQL, provide only the SQL query.\n\n"
78
- "User: {message}<|end|>\n"
79
  "<|assistant|>"
80
  )
81
 
82
  EMPTY_VALIDATOR = '<span class="validator-badge validator-empty">No SQL yet</span>'
83
  CHAT_VALIDATOR = '<span class="validator-badge validator-empty">Chat response</span>'
84
  EMPTY_CHAT_OUTPUT = ""
 
 
 
 
85
  LOAD_SCROLL_JS = """
86
  (selectedKey) => {
87
  setTimeout(() => {
@@ -98,6 +100,7 @@ _current_model_id = None
98
  _model = None
99
  _tokenizer = None
100
  _model_lock = threading.RLock()
 
101
 
102
 
103
  def import_model_runtime():
@@ -115,13 +118,144 @@ def import_model_runtime():
115
  return torch, AutoConfig, AutoModelForCausalLM, AutoTokenizer
116
 
117
 
118
  def patch_phi3_config(config):
119
  if hasattr(config, "rope_scaling") and config.rope_scaling:
120
  rope_type = config.rope_scaling.get("rope_type", "longrope")
121
- if rope_type == "default":
122
- config.rope_scaling = None
123
- elif "type" not in config.rope_scaling:
124
  config.rope_scaling["type"] = rope_type
 
 
125
  return config
126
 
127
 
@@ -147,36 +281,77 @@ def unload_model():
147
 
148
  def load_model(model_id):
149
  global _current_model_id, _model, _tokenizer
150
- with _model_lock:
 
 
 
 
151
  if _current_model_id == model_id and _model is not None and _tokenizer is not None:
 
152
  return _model, _tokenizer
153
 
154
- _, AutoConfig, AutoModelForCausalLM, AutoTokenizer = import_model_runtime()
155
- unload_model()
156
- model_def = model_by_id(model_id)
157
- config = AutoConfig.from_pretrained(
158
- model_id,
159
- trust_remote_code=model_def["trust_remote_code"],
160
- )
161
- config = patch_phi3_config(config)
162
- tokenizer = AutoTokenizer.from_pretrained(
163
- model_id,
164
- trust_remote_code=model_def["trust_remote_code"],
165
- )
166
- model = AutoModelForCausalLM.from_pretrained(
167
- model_id,
168
- config=config,
169
- trust_remote_code=model_def["trust_remote_code"],
170
- device_map={"": "cpu"},
171
- torch_dtype="auto",
172
- low_cpu_mem_usage=True,
173
- )
174
- model.eval()
175
-
176
- _model = model
177
- _tokenizer = tokenizer
178
- _current_model_id = model_id
179
- return model, tokenizer
180
 
181
 
182
  def model_by_key(model_key):
@@ -197,8 +372,39 @@ def model_key_by_id(model_id):
197
  return None
198
 
199
 
200
  def clean_generation(text):
201
- cleaned = (text or "").strip()
202
  if cleaned.startswith("```"):
203
  lines = cleaned.splitlines()
204
  if lines and lines[0].strip().lower() in {"```", "```sql"}:
@@ -209,9 +415,19 @@ def clean_generation(text):
209
  for marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
210
  if marker in cleaned:
211
  cleaned = cleaned.split(marker, 1)[0].strip()
 
 
212
  return cleaned
213
 
214
 
 
 
 
 
 
 
 
 
215
  def is_sql_like(text):
216
  text = (text or "").strip()
217
  if not text:
@@ -232,43 +448,121 @@ def is_sql_like(text):
232
 
233
 
234
  def is_sql_intent(message, schema):
235
- message = (message or "").strip().lower()
236
  schema = (schema or "").strip()
237
- if schema:
238
- return True
239
  if not message:
240
  return False
241
  sql_terms = {
242
- "sql",
243
- "query",
244
- "select",
245
- "table",
246
- "schema",
247
  "database",
248
- "join",
 
249
  "group by",
 
 
250
  "order by",
251
- "where",
252
- "average",
253
- "count",
254
- "sum",
255
  "rows",
256
- "columns",
257
  }
258
- return any(term in message for term in sql_terms)
 
 
 
259
 
260
 
261
- def build_generation_prompt(schema, message):
262
  schema = (schema or "").strip()
263
  message = (message or "").strip()
264
  if is_sql_intent(message, schema):
265
- table_schema = schema or "No explicit schema provided. Infer the table and columns only if the request includes them."
266
- return PROMPT_TEMPLATE.format(schema=table_schema, question=message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  return GENERAL_PROMPT_TEMPLATE.format(message=message)
268
 
269
 
270
  def format_generation_result(text):
271
- cleaned = clean_generation(text)
272
  if is_sql_like(cleaned):
273
  return str(cleaned), EMPTY_CHAT_OUTPUT, validate_sql(cleaned)
274
  return "", str(cleaned), CHAT_VALIDATOR
@@ -332,7 +626,7 @@ def render_model_card(model_key, selected_key):
332
  selected = model_key == selected_key
333
  state_class = " selected" if selected else ""
334
  return f"""
335
- <article class="model-card{state_class}" role="button" tabindex="0">
336
  <div class="model-tag">{model_def["tag"]}</div>
337
  <h3>{model_def["title"]}</h3>
338
  <code>{model_def["model_id"]}</code>
@@ -391,6 +685,34 @@ def model_metadata(model_key=None):
391
  """
392
 
393
 
394
  def schema_name_by_value(schema):
395
  schema = (schema or "").strip()
396
  for name, value in PRESETS.items():
@@ -399,6 +721,376 @@ def schema_name_by_value(schema):
399
  return "custom"
400
 
401
 
402
  def render_schema_context(schema=""):
403
  schema = (schema or "").strip()
404
  if not schema:
@@ -416,7 +1108,8 @@ def render_schema_context(schema=""):
416
 
417
  def query_control_updates(can_generate):
418
  context_updates = [gr.update(interactive=True) for _ in range(6)]
419
- return [*context_updates, gr.update(interactive=True), gr.update(interactive=can_generate)]
 
420
 
421
 
422
  def render_message(message="", kind="error"):
@@ -441,9 +1134,13 @@ def select_model(model_key, loaded_key):
441
  )
442
 
443
 
444
- def load_selected_model(selected_key):
445
- selected_key = selected_key if selected_key in MODEL_CATALOG else DEFAULT_MODEL_KEY
446
  model_def = model_by_key(selected_key)
 
 
 
 
447
  yield (
448
  None,
449
  render_status(selected_key, None, state="loading"),
@@ -459,15 +1156,28 @@ def load_selected_model(selected_key):
459
  )
460
  started = time.time()
461
  try:
462
- load_model(model_def["model_id"])
 
 
 
 
 
 
 
 
 
 
 
463
  except Exception as exc:
464
  error = f"Load failed for {model_def['model_id']}: {type(exc).__name__}: {exc}"
 
 
465
  yield (
466
  None,
467
  render_status(selected_key, None),
468
  render_loading_overlay(visible=False),
469
  model_metadata(selected_key),
470
- gr.update(interactive=True),
471
  *query_control_updates(False),
472
  "",
473
  EMPTY_VALIDATOR,
@@ -483,7 +1193,7 @@ def load_selected_model(selected_key):
483
  render_status(selected_key, selected_key),
484
  render_loading_overlay(visible=False),
485
  model_metadata(selected_key),
486
- gr.update(interactive=True),
487
  *query_control_updates(True),
488
  "",
489
  EMPTY_VALIDATOR,
@@ -529,11 +1239,91 @@ def render_compare_label(prefix, model_label, metric):
529
  )
530
 
531
 
532
  def generate_response(message, chat_history, active_schema, loaded_key, saved_state):
533
  message = (message or "").strip()
534
  active_schema = (active_schema or "").strip()
535
  chat_history = list(chat_history or [])
536
- if not loaded_key or _model is None or _tokenizer is None:
537
  compare = comparison_updates(saved_state, "", loaded_key)
538
  return (
539
  chat_history,
@@ -543,20 +1333,62 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
543
  "",
544
  EMPTY_VALIDATOR,
545
  gr.update(interactive=False, visible=False),
546
- render_message("Load a model before generating SQL."),
547
  *compare,
548
  )
549
- if not message:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  compare = comparison_updates(saved_state, "", loaded_key)
551
  return (
552
  chat_history,
 
 
 
553
  "",
554
  active_schema,
555
  "",
556
  "",
557
  EMPTY_VALIDATOR,
558
  gr.update(interactive=False, visible=False),
559
- render_message("Type a message before sending."),
560
  *compare,
561
  )
562
 
@@ -577,20 +1409,34 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
577
 
578
  started = time.time()
579
  try:
580
- torch, _, _, _ = import_model_runtime()
581
  with _model_lock:
582
- prompt = build_generation_prompt(active_schema, message)
583
  inputs = _tokenizer(prompt, return_tensors="pt")
584
  input_length = inputs["input_ids"].shape[-1]
585
- with torch.no_grad():
586
- output_ids = _model.generate(
587
- **inputs,
588
- max_new_tokens=80,
589
- do_sample=False,
590
- use_cache=False,
591
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  generated_ids = output_ids[0][input_length:]
593
- generated_text = _tokenizer.decode(generated_ids, skip_special_tokens=False)
594
  except Exception as exc:
595
  compare = comparison_updates(saved_state, "", loaded_key)
596
  return (
@@ -624,7 +1470,7 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
624
  message,
625
  str(sql_text),
626
  validator,
627
- gr.update(interactive=bool(sql_text.strip()), visible=bool(sql_text.strip())),
628
  render_message(f"Generated {response_kind} with {model_def['model_id']} in {elapsed}s.", kind="ok"),
629
  *compare,
630
  )
@@ -664,9 +1510,50 @@ def save_for_comparison(sql_text, loaded_key, active_schema, last_message):
664
  )
665
 
666
 
667
  CSS = """
668
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;500;700&display=swap');
669
 
 
 
 
 
 
 
 
 
 
670
  :root {
671
  --bg-base: #0c0c0b;
672
  --bg-surface: #1a1a18;
@@ -757,19 +1644,19 @@ CSS = """
757
  .badge-green,
758
  .validator-ok {
759
  background: var(--teal-soft);
760
- color: var(--teal-text);
761
  }
762
 
763
  .badge-cream,
764
  .validator-warn {
765
  background: var(--amber-soft);
766
- color: var(--amber-text);
767
  }
768
 
769
  .badge-light,
770
  .validator-empty {
771
  background: var(--bg-raised);
772
- color: var(--text-secondary);
773
  border: 0.5px solid var(--border);
774
  }
775
 
@@ -805,29 +1692,24 @@ CSS = """
805
  background: var(--bg-surface);
806
  border: 0.5px solid var(--border);
807
  border-radius: 6px;
808
- cursor: pointer;
809
  min-height: 176px;
810
  padding: 16px;
811
  transition: border-color 160ms ease, background 160ms ease;
812
  }
813
 
814
- .model-card:hover {
815
- border-color: var(--border-hi);
816
- }
817
-
818
  .model-card.selected {
819
  border: 1.5px solid var(--teal);
820
  }
821
 
822
  .model-tag {
823
  background: var(--amber-soft);
824
- color: var(--amber-text);
825
  margin-bottom: 18px;
826
  }
827
 
828
  .model-card.selected .model-tag {
829
  background: var(--teal-soft);
830
- color: var(--teal-text);
831
  }
832
 
833
  .model-card h3 {
@@ -882,6 +1764,64 @@ CSS = """
882
  display: flex;
883
  }
884
 
885
  #load-button,
886
  #generate-button,
887
  #save-button {
@@ -906,6 +1846,11 @@ CSS = """
906
  width: 100% !important;
907
  }
908
 
 
 
 
 
 
909
  #load-button button:hover,
910
  #generate-button button:hover {
911
  background: var(--text-primary) !important;
@@ -969,7 +1914,7 @@ CSS = """
969
  }
970
 
971
  .stat-card strong {
972
- color: var(--text-primary);
973
  display: block;
974
  font-size: 15px;
975
  font-weight: 500;
@@ -978,7 +1923,7 @@ CSS = """
978
  }
979
 
980
  .stat-card span {
981
- color: var(--text-secondary);
982
  display: block;
983
  font-size: 11px;
984
  font-weight: 400;
@@ -1063,16 +2008,32 @@ CSS = """
1063
  }
1064
 
1065
  .composer-row {
1066
- align-items: stretch;
 
1067
  gap: 8px !important;
1068
  }
1069
 
 
 
 
 
 
 
1070
  #message-input {
1071
  flex: 1 1 auto;
1072
  }
1073
 
1074
  #message-input textarea {
1075
  min-height: 42px !important;
 
 
 
 
 
 
 
 
 
1076
  }
1077
 
1078
  #clear-schema-button button {
@@ -1178,7 +2139,7 @@ textarea {
1178
  }
1179
 
1180
  .validator-detail {
1181
- color: var(--text-secondary);
1182
  font-size: 11px;
1183
  margin-left: 8px;
1184
  }
@@ -1228,7 +2189,7 @@ textarea {
1228
  .compare-head {
1229
  align-items: center;
1230
  background: var(--amber-soft);
1231
- color: var(--amber-text);
1232
  display: flex;
1233
  font-size: 11px;
1234
  font-weight: 500;
@@ -1241,7 +2202,7 @@ textarea {
1241
  .compare-card.current .compare-head,
1242
  .current-compare-head .compare-head {
1243
  background: var(--teal-soft);
1244
- color: var(--teal-text);
1245
  }
1246
 
1247
  .compare-head strong {
@@ -1316,7 +2277,8 @@ textarea {
1316
  @media (max-width: 860px) {
1317
  .top-panel,
1318
  .model-grid,
1319
- .compare-grid {
 
1320
  grid-template-columns: 1fr;
1321
  }
1322
 
@@ -1334,8 +2296,7 @@ textarea {
1334
  }
1335
  """
1336
 
1337
- with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1338
- selected_model_key = gr.State(value=DEFAULT_MODEL_KEY)
1339
  loaded_key_state = gr.State(value=None)
1340
  saved_output = gr.State(value=None)
1341
  active_schema = gr.State(value="")
@@ -1347,11 +2308,11 @@ with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1347
 
1348
  gr.HTML(render_step("01", "Model"))
1349
  with gr.Row(elem_classes=["model-grid"]):
1350
- base_model_card = gr.HTML(render_model_card(BASE_MODEL_KEY, DEFAULT_MODEL_KEY))
1351
  fine_tuned_model_card = gr.HTML(render_model_card(FINE_TUNED_MODEL_KEY, DEFAULT_MODEL_KEY))
1352
- load_button = gr.Button("Load selected model", variant="primary", elem_id="load-button")
1353
  model_status = gr.HTML(render_status(DEFAULT_MODEL_KEY, None))
1354
  model_info = gr.HTML(model_metadata(DEFAULT_MODEL_KEY))
 
1355
 
1356
  with gr.Column(elem_id="query-section", elem_classes=["query-section"]):
1357
  gr.HTML(render_step("02", "Chat"))
@@ -1406,7 +2367,7 @@ with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1406
  show_label=False,
1407
  )
1408
  save_button = gr.Button(
1409
- "Save for comparison",
1410
  interactive=False,
1411
  visible=False,
1412
  elem_id="save-button",
@@ -1423,8 +2384,6 @@ with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1423
  current_sql = gr.Code(label="", language="sql", lines=6, show_label=False)
1424
 
1425
  model_state_outputs = [
1426
- selected_model_key,
1427
- base_model_card,
1428
  fine_tuned_model_card,
1429
  model_status,
1430
  model_info,
@@ -1439,20 +2398,10 @@ with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1439
  save_button,
1440
  error_output,
1441
  ]
1442
- base_model_card.click(
1443
- select_model,
1444
- inputs=[gr.State(BASE_MODEL_KEY), loaded_key_state],
1445
- outputs=model_state_outputs,
1446
- )
1447
- fine_tuned_model_card.click(
1448
- select_model,
1449
- inputs=[gr.State(FINE_TUNED_MODEL_KEY), loaded_key_state],
1450
- outputs=model_state_outputs,
1451
- )
1452
 
1453
  load_button.click(
1454
  load_selected_model,
1455
- inputs=selected_model_key,
1456
  outputs=[
1457
  loaded_key_state,
1458
  model_status,
@@ -1523,6 +2472,29 @@ with gr.Blocks(css=CSS, title="Phi-3 Mini SQL Generator") as demo:
1523
  error_output,
1524
  ],
1525
  )
1526
 
1527
  queue_kwargs = {}
1528
  if "default_concurrency_limit" in inspect.signature(demo.queue).parameters:
@@ -1531,4 +2503,4 @@ demo.queue(**queue_kwargs)
1531
 
1532
 
1533
  if __name__ == "__main__":
1534
- demo.launch()
 
1
+ import concurrent.futures
2
  import gc
3
  import html
4
  import inspect
5
+ import os
6
  import re
7
  import threading
8
  import time
9
+ import traceback
10
+ import unicodedata
11
 
12
  import gradio as gr
13
  import sqlparse
 
28
  "title": "Phi-3 Mini base",
29
  "model_id": BASE_MODEL_ID,
30
  "exact_match": "2.0%",
31
+ "trust_remote_code": False,
32
  "ready_text": "Base model ready",
33
  "metadata": (
34
  "Model: microsoft/Phi-3-mini-4k-instruct\n"
 
54
  },
55
  }
56
 
 
 
 
 
 
57
  PRESETS = {
58
+ "employees": "CREATE TABLE employees (id INTEGER, name TEXT, department TEXT, salary NUMERIC)",
59
+ "orders": "CREATE TABLE orders (id INTEGER, customer_id INTEGER, product TEXT, amount NUMERIC, date DATE)",
60
+ "students": "CREATE TABLE students (id INTEGER, name TEXT, course TEXT, grade NUMERIC, year INTEGER)",
61
+ "products": "CREATE TABLE products (id INTEGER, name TEXT, category TEXT, price NUMERIC, stock INTEGER)",
62
+ "sales": "CREATE TABLE sales (id INTEGER, product_id INTEGER, quantity INTEGER, total NUMERIC, date DATE)",
63
  }
64
 
65
  PROMPT_TEMPLATE = (
 
72
 
73
  GENERAL_PROMPT_TEMPLATE = (
74
  "<|user|>\n"
75
+ "You are a SQL assistant. Answer the user's question.\n\n"
76
+ "Question: {message}<|end|>\n"
 
77
  "<|assistant|>"
78
  )
79
 
80
  EMPTY_VALIDATOR = '<span class="validator-badge validator-empty">No SQL yet</span>'
81
  CHAT_VALIDATOR = '<span class="validator-badge validator-empty">Chat response</span>'
82
  EMPTY_CHAT_OUTPUT = ""
83
+ LOAD_TIMEOUT_SECONDS = 900
84
+ GENERATION_MAX_TIME_SECONDS = 285
85
+ GENERATION_TIMEOUT_SECONDS = 320
86
+ LOCAL_FILES_ONLY_ENV = "PHI3_SQL_LOCAL_FILES_ONLY"
87
  LOAD_SCROLL_JS = """
88
  (selectedKey) => {
89
  setTimeout(() => {
 
100
  _model = None
101
  _tokenizer = None
102
  _model_lock = threading.RLock()
103
+ _model_activity_lock = threading.Lock()
104
 
105
 
106
  def import_model_runtime():
 
118
  return torch, AutoConfig, AutoModelForCausalLM, AutoTokenizer
119
 
120
 
121
+ def log_load_step(model_id, step, started=None):
122
+ elapsed = "" if started is None else f" elapsed={time.time() - started:.1f}s"
123
+ print(f"[LOAD_STEP] model={model_id} step={step}{elapsed}", flush=True)
124
+
125
+
126
+ def cached_model_weights_available(model_id):
127
+ try:
128
+ from huggingface_hub import try_to_load_from_cache
129
+ except ModuleNotFoundError:
130
+ return False
131
+
132
+ weight_files = (
133
+ "model.safetensors",
134
+ "model.safetensors.index.json",
135
+ "pytorch_model.bin",
136
+ "pytorch_model.bin.index.json",
137
+ )
138
+ for filename in weight_files:
139
+ try:
140
+ cached_path = try_to_load_from_cache(model_id, filename)
141
+ except Exception:
142
+ cached_path = None
143
+ if isinstance(cached_path, str) and os.path.exists(cached_path):
144
+ return True
145
+ return False
146
+
147
+
148
+ def cached_file_path(model_id, filename):
149
+ try:
150
+ from huggingface_hub import try_to_load_from_cache
151
+ except ModuleNotFoundError:
152
+ return None
153
+
154
+ try:
155
+ cached_path = try_to_load_from_cache(model_id, filename)
156
+ except Exception:
157
+ return None
158
+ if isinstance(cached_path, str) and os.path.exists(cached_path):
159
+ return cached_path
160
+ return None
161
+
162
+
163
+ def cached_snapshot_path(model_id):
164
+ config_path = cached_file_path(model_id, "config.json")
165
+ if not config_path or not cached_model_weights_available(model_id):
166
+ return None
167
+ return os.path.dirname(config_path)
168
+
169
+
170
+ def local_files_only_for(model_id):
171
+ explicit_local = os.getenv(LOCAL_FILES_ONLY_ENV, "").strip().lower() in {"1", "true", "yes", "on"}
172
+ offline_mode = bool(os.getenv("HF_HUB_OFFLINE") or os.getenv("TRANSFORMERS_OFFLINE"))
173
+ return explicit_local or offline_mode
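The flag handling in `local_files_only_for` can be exercised without touching the real environment. The sketch below is an assumption-laden restatement, not the app's code: `env_flag_enabled` and the `environ` parameter are hypothetical additions so that a plain dict can be passed in.

```python
import os

TRUTHY = {"1", "true", "yes", "on"}


def env_flag_enabled(name, environ=None):
    """Hypothetical helper: True when the variable holds a truthy spelling."""
    environ = os.environ if environ is None else environ
    return environ.get(name, "").strip().lower() in TRUTHY


def local_files_only(environ=None):
    environ = os.environ if environ is None else environ
    explicit_local = env_flag_enabled("PHI3_SQL_LOCAL_FILES_ONLY", environ)
    # As in the diff, any non-empty value of the offline variables counts,
    # so HF_HUB_OFFLINE="0" still switches offline mode on here.
    offline_mode = bool(environ.get("HF_HUB_OFFLINE") or environ.get("TRANSFORMERS_OFFLINE"))
    return explicit_local or offline_mode
```

Note the asymmetry: the app-specific flag is parsed for truthy spellings, while the two Hugging Face variables are only checked for being non-empty.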
174
+
175
+
176
+ def running_on_spaces():
177
+ return bool(os.getenv("SPACE_ID"))
178
+
179
+
180
+ def resolve_model_source(model_id):
181
+ if local_files_only_for(model_id):
182
+ return cached_snapshot_path(model_id) or model_id
183
+ return model_id
184
+
185
+
186
+ def dtype_from_name(torch, dtype_name):
187
+ if not dtype_name:
188
+ return None
189
+ normalized = str(dtype_name).replace("torch.", "")
190
+ return {
191
+ "float16": torch.float16,
192
+ "bfloat16": torch.bfloat16,
193
+ "float32": torch.float32,
194
+ }.get(normalized)
195
+
196
+
197
+ def dtype_from_safetensors(torch, source):
198
+ safetensors_path = os.path.join(source, "model.safetensors")
199
+ if not os.path.exists(safetensors_path):
200
+ return None
201
+ try:
202
+ from safetensors import safe_open
203
+
204
+ with safe_open(safetensors_path, framework="pt", device="cpu") as handle:
205
+ keys = list(handle.keys())
206
+ if not keys:
207
+ return None
208
+ return handle.get_tensor(keys[0]).dtype
209
+ except Exception:
210
+ return None
211
+
212
+
213
+ def cpu_model_dtype(torch):
214
+ return torch.bfloat16
215
+
216
+
217
+ def model_load_kwargs(torch, config, source):
218
+ return {
219
+ "attn_implementation": "eager",
220
+ "device_map": {"": "cpu"},
221
+ "low_cpu_mem_usage": True,
222
+ "torch_dtype": "auto",
223
+ }
224
+
225
+
226
+ def force_eager_attention(config):
227
+ for attr in ("attn_implementation", "_attn_implementation"):
228
+ try:
229
+ setattr(config, attr, "eager")
230
+ except Exception:
231
+ pass
232
+ return config
233
+
234
+
235
+ def _run_generation(model, inputs, kwargs):
236
+ if not _model_activity_lock.acquire(blocking=False):
237
+ raise RuntimeError(
238
+ "Another model operation is still running. Wait for it to finish before starting another request."
239
+ )
240
+ torch, _, _, _ = import_model_runtime()
241
+ try:
242
+ with torch.no_grad():
243
+ return model.generate(**inputs, **kwargs)
244
+ finally:
245
+ _model_activity_lock.release()
246
+
247
+
248
+ def _run_model_load(model_id):
249
+ return load_model(model_id)
250
+
251
+
252
  def patch_phi3_config(config):
253
  if hasattr(config, "rope_scaling") and config.rope_scaling:
254
  rope_type = config.rope_scaling.get("rope_type", "longrope")
255
+ if "type" not in config.rope_scaling:
 
 
256
  config.rope_scaling["type"] = rope_type
257
+ if hasattr(config, "rope_parameters") and config.rope_parameters is None:
258
+ config.rope_parameters = dict(config.rope_scaling)
259
  return config
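The effect of the revised `patch_phi3_config` can be checked on a stand-in object without loading any model. This sketch assumes the `rope_parameters` branch sits inside the `rope_scaling` check (the diff view has lost indentation); `SimpleNamespace` stands in for the real transformers config, and `patch_rope_config` is a hypothetical name.

```python
from types import SimpleNamespace


def patch_rope_config(config):
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_type = rope_scaling.get("rope_type", "longrope")
        # Backfill the legacy "type" key when only "rope_type" is present.
        if "type" not in rope_scaling:
            rope_scaling["type"] = rope_type
        # Copy into rope_parameters only if the attribute exists and is unset.
        if hasattr(config, "rope_parameters") and config.rope_parameters is None:
            config.rope_parameters = dict(rope_scaling)
    return config


cfg = SimpleNamespace(rope_scaling={"rope_type": "longrope"}, rope_parameters=None)
patch_rope_config(cfg)
```

After the call, `cfg.rope_scaling` carries both keys and `cfg.rope_parameters` holds a copy of it; a config with `rope_scaling=None` passes through unchanged.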
260
 
261
 
 
281
 
282
  def load_model(model_id):
283
  global _current_model_id, _model, _tokenizer
284
+ started = time.time()
285
+ log_load_step(model_id, "requested", started)
286
+ if not _model_lock.acquire(blocking=False):
287
+ raise RuntimeError("Another model load is still running. Wait for it to finish before retrying.")
288
+ try:
289
  if _current_model_id == model_id and _model is not None and _tokenizer is not None:
290
+ log_load_step(model_id, "already_loaded", started)
291
  return _model, _tokenizer
292
 
293
+ if not _model_activity_lock.acquire(blocking=False):
294
+ raise RuntimeError(
295
+ "Another model operation is still running. Wait for it to finish before switching models."
296
+ )
297
+ try:
298
+ log_load_step(model_id, "runtime_import_start", started)
299
+ torch, AutoConfig, AutoModelForCausalLM, AutoTokenizer = import_model_runtime()
300
+ log_load_step(model_id, "runtime_import_done", started)
301
+ local_files_only = local_files_only_for(model_id)
302
+ model_source = resolve_model_source(model_id)
303
+ log_load_step(model_id, f"cache_mode local_files_only={local_files_only}", started)
304
+ log_load_step(model_id, f"model_source {model_source}", started)
305
+ log_load_step(model_id, "unload_previous_start", started)
306
+ unload_model()
307
+ log_load_step(model_id, "unload_previous_done", started)
308
+ model_def = model_by_id(model_id)
309
+ common_kwargs = {
310
+ "trust_remote_code": model_def["trust_remote_code"],
311
+ "local_files_only": local_files_only,
312
+ }
313
+ log_load_step(model_id, "config_start", started)
314
+ config = AutoConfig.from_pretrained(
315
+ model_source,
316
+ **common_kwargs,
317
+ )
318
+ if model_def["trust_remote_code"]:
319
+ config = patch_phi3_config(config)
320
+ config = force_eager_attention(config)
321
+ log_load_step(model_id, "config_done", started)
322
+ load_kwargs = model_load_kwargs(torch, config, model_source)
323
+ log_load_step(model_id, f"model_kwargs {load_kwargs}", started)
324
+ log_load_step(model_id, "tokenizer_start", started)
325
+ tokenizer = AutoTokenizer.from_pretrained(
326
+ model_source,
327
+ **common_kwargs,
328
+ )
329
+ if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
330
+ tokenizer.pad_token = tokenizer.eos_token
331
+ log_load_step(model_id, "tokenizer_done", started)
332
+ log_load_step(model_id, "weights_start", started)
333
+ model = AutoModelForCausalLM.from_pretrained(
334
+ model_source,
335
+ config=config,
336
+ **common_kwargs,
337
+ **load_kwargs,
338
+ )
339
+ log_load_step(model_id, "weights_done", started)
340
+ log_load_step(model_id, f"loaded_dtype {getattr(model, 'dtype', 'unknown')}", started)
341
+ log_load_step(model_id, "eval_start", started)
342
+ model.config.use_cache = False
343
+ model.eval()
344
+ log_load_step(model_id, "eval_done", started)
345
+
346
+ _model = model
347
+ _tokenizer = tokenizer
348
+ _current_model_id = model_id
349
+ log_load_step(model_id, "state_set_done", started)
350
+ return model, tokenizer
351
+ finally:
352
+ _model_activity_lock.release()
353
+ finally:
354
+ _model_lock.release()
355
 
356
 
357
  def model_by_key(model_key):
 
372
  return None
373
 
374
 
375
+ def content_to_text(value):
376
+ if value is None:
377
+ return ""
378
+ if isinstance(value, str):
379
+ return value
380
+ if isinstance(value, dict):
381
+ for key in ("text", "content", "value"):
382
+ if key in value:
383
+ return content_to_text(value[key])
384
+ return " ".join(content_to_text(item) for item in value.values())
385
+ if isinstance(value, (list, tuple)):
386
+ return "\n".join(content_to_text(item) for item in value)
387
+ return str(value)
388
+
389
+
390
+ def normalize_text(value):
391
+ text = content_to_text(value).lower()
392
+ text = unicodedata.normalize("NFKD", text)
393
+ text = "".join(char for char in text if not unicodedata.combining(char))
394
+ return re.sub(r"\s+", " ", text).strip()
395
+
396
+
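Reviewer note: the accent folding in `normalize_text` can be checked in isolation. The sketch below uses a hypothetical helper name (`fold_text`) and takes a plain string, whereas the diff first passes the value through `content_to_text`. It is an illustration of the NFKD approach, not the app's exact code.

```python
import re
import unicodedata


def fold_text(value: str) -> str:
    """Lowercase, strip diacritics, and collapse whitespace (illustrative helper)."""
    # NFKD splits accented letters into a base letter plus combining marks.
    text = unicodedata.normalize("NFKD", value.lower())
    # Dropping the combining marks leaves only the base letters.
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    return re.sub(r"\s+", " ", text).strip()


folded = fold_text("  Qual a  MÉDIA de preço? ")
```

Because `is_sql_intent` compares the already-normalized message against its term sets, an accented entry such as `"olá"` appears unable to match on its own. Only its unaccented twin (`"ola"`) can.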
397
+ def safe_chat_fallback(_message=""):
398
+ return (
399
+ "Selecione um schema e faça uma pergunta SQL, "
400
+ "ou peça para criar ou editar uma tabela. "
401
+ "Exemplo: 'crie tabela produtos com id nome preco' "
402
+ "ou 'qual o produto mais caro?'."
403
+ )
404
+
405
+
406
  def clean_generation(text):
407
+ cleaned = content_to_text(text).strip()
408
  if cleaned.startswith("```"):
409
  lines = cleaned.splitlines()
410
  if lines and lines[0].strip().lower() in {"```", "```sql"}:
 
415
  for marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
416
  if marker in cleaned:
417
  cleaned = cleaned.split(marker, 1)[0].strip()
418
+ if cleaned.upper().startswith("SQL:"):
419
+ cleaned = cleaned[4:].strip()
420
  return cleaned
421
 
422
 
423
+ def extract_sql_candidate(text):
424
+ cleaned = clean_generation(text)
425
+ match = re.search(r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b", cleaned, flags=re.IGNORECASE)
426
+ if not match:
427
+ return cleaned
428
+ return cleaned[match.start() :].strip()
429
+
430
+
431
  def is_sql_like(text):
432
  text = (text or "").strip()
433
  if not text:
 
448
 
449
 
450
  def is_sql_intent(message, schema):
451
+ message = normalize_text(message)
452
  schema = (schema or "").strip()
 
 
453
  if not message:
454
  return False
455
+ # P1 fix: if schema exists and message has substance, treat as SQL intent
456
+ # (user is likely asking a question about the known schema)
457
+ # Exclude short greetings/acknowledgments that could accompany a schema setup
458
+ short_greetings = {
459
+ "oi", "olá", "ola", "hi", "hello", "hey", "bom", "boa",
460
+ "obrigado", "thanks", "ok", "sim", "claro", "de nada",
461
+ }
462
+ # Extended exclusions for FAQ/off-topic with schema active
463
+ off_topic_patterns = {
464
+ "obrigado", "thanks", "thank you", "muito obrigado", "obrigada",
465
+ "como você funciona", "como voce funciona", "como funciona",
466
+ "o que você faz", "o que voce faz", "o que faz",
467
+ "como foi treinado", "como voce foi treinado", "treinado",
468
+ "quais habilidades", "o que consegue", "o que pode fazer",
469
+ "me ajude", "help me", "ajuda", "help",
470
+ # Edit/table manipulation terms — prevent blanket-catch from routing to model
471
+ "troca", "trocar", "renomeia", "renomear", "renomeie",
472
+ "muda", "mudar", "altera", "alterar", "edita", "editar",
473
+ "adiciona", "adicionar", "adicione", "remove", "remover",
474
+ "apaga", "apagar", "delete column", "drop column",
475
+ "coluna nova", "nova coluna", "novo campo", "campo novo",
476
+ "trocando", "mudando", "alterando", "editando",
477
+ }
478
+ words = message.split()
479
+ # Check if message is off-topic even with 2+ words
480
+ if schema and len(words) >= 2:
481
+ # Check exact matches and patterns
482
+ if message in short_greetings or message in off_topic_patterns:
483
+ return False
484
+ # Check partial matches for common off-topic phrases
485
+ for pattern in off_topic_patterns:
486
+ if pattern in message:
487
+ return False
488
+ if schema and len(words) >= 2 and message not in short_greetings:
489
+ return True
490
  sql_terms = {
491
+ "all",
492
+ "average",
493
+ "count",
494
+ "columns",
 
495
  "database",
496
+ "find",
497
+ "get",
498
  "group by",
499
+ "join",
500
+ "list",
501
  "order by",
502
+ "query",
 
 
 
503
  "rows",
504
+ "schema",
505
+ "select",
506
+ "show",
507
+ "sql",
508
+ "sum",
509
+ "table",
510
+ "where",
511
+ "consulta",
512
+ "consultar",
513
+ "contar",
514
+ "colunas",
515
+ "linhas",
516
+ "liste",
517
+ "listar",
518
+ "maior",
519
+ "mais caro",
520
+ "menor",
521
+ "media",
522
+ "média",
523
+ "mostre",
524
+ "mostrar",
525
+ "ordene",
526
+ "por departamento",
527
+ "selecione",
528
+ "sql",
529
+ "some",
530
+ "soma",
531
+ "tabela",
532
  }
533
+ return any(
534
+ re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", message)
535
+ for term in sql_terms
536
+ )
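Reviewer note: the final check in `is_sql_intent` relies on lookarounds rather than `\b`, so multi-word terms such as `group by` match only as whole phrases. A minimal sketch of that matching rule, assuming a hypothetical helper `contains_term` and an already-normalized message:

```python
import re


def contains_term(message: str, terms: set[str]) -> bool:
    """Return True if any term occurs in message without adjacent word characters."""
    return any(
        # (?<!\w) and (?!\w) reject matches embedded in a longer word.
        re.search(rf"(?<!\w){re.escape(term)}(?!\w)", message)
        for term in terms
    )


whole_word = contains_term("sum of prices", {"sum"})
embedded = contains_term("show summary", {"sum"})
phrase = contains_term("sort rows, group by dept", {"group by"})
```

The lookarounds keep `sum` from firing inside `summary`, which a plain substring test would not.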
537
 
538
 
539
+ def build_generation_prompt(schema, message, chat_history=None):
540
  schema = (schema or "").strip()
541
  message = (message or "").strip()
542
  if is_sql_intent(message, schema):
543
+ table_schema = schema or "CREATE TABLE unknown (id INTEGER)"
544
+ # Inject last 3 conversation exchanges for multi-turn context
545
+ history_context = ""
546
+ if chat_history:
547
+ trimmed = trim_chat_history(chat_history, max_exchanges=3)
548
+ if trimmed:
549
+ lines = []
550
+ for i in range(0, len(trimmed), 2):
551
+ entry1 = trimmed[i]
552
+ entry2 = trimmed[i + 1] if i + 1 < len(trimmed) else None
553
+ user_msg = entry1.get("content", "") if isinstance(entry1, dict) else (entry1[1] if isinstance(entry1, tuple) else str(entry1))
554
+ asst_msg = entry2.get("content", "") if isinstance(entry2, dict) else (entry2[1] if isinstance(entry2, tuple) else str(entry2)) if entry2 else ""
555
+ lines.append(f"User: {user_msg}")
556
+ if asst_msg:
557
+ lines.append(f"Assistant: {asst_msg}")
558
+ if lines:
559
+ history_context = "\n\nPrevious conversation:\n" + "\n".join(lines) + "\n"
560
+ return PROMPT_TEMPLATE.format(schema=table_schema, question=message) + history_context
561
  return GENERAL_PROMPT_TEMPLATE.format(message=message)
562
 
563
 
564
  def format_generation_result(text):
565
+ cleaned = extract_sql_candidate(text)
566
  if is_sql_like(cleaned):
567
  return str(cleaned), EMPTY_CHAT_OUTPUT, validate_sql(cleaned)
568
  return "", str(cleaned), CHAT_VALIDATOR
 
626
  selected = model_key == selected_key
627
  state_class = " selected" if selected else ""
628
  return f"""
629
+ <article class="model-card{state_class}">
630
  <div class="model-tag">{model_def["tag"]}</div>
631
  <h3>{model_def["title"]}</h3>
632
  <code>{model_def["model_id"]}</code>
 
685
  """
686
 
687
 
688
+ def render_baseline_evidence():
689
+ return """
690
+ <section class="evidence-panel">
691
+ <div class="evidence-copy">
692
+ <h2>Offline baseline comparison</h2>
693
+ <p>The live Space loads only the fine-tuned model to keep the CPU demo testable. The base model comparison is kept as evaluation evidence instead of a second live 3.8B CPU load.</p>
694
+ </div>
695
+ <div class="evidence-grid">
696
+ <div class="evidence-card">
697
+ <span>Base Phi-3 Mini</span>
698
+ <strong>2.0%</strong>
699
+ <small>exact match</small>
700
+ </div>
701
+ <div class="evidence-card highlighted">
702
+ <span>Fine-tuned QLoRA</span>
703
+ <strong>73.5%</strong>
704
+ <small>exact match</small>
705
+ </div>
706
+ <div class="evidence-card">
707
+ <span>Gain</span>
708
+ <strong>+71.5pp</strong>
709
+ <small>same comparison setup</small>
710
+ </div>
711
+ </div>
712
+ </section>
713
+ """
714
+
715
+
716
  def schema_name_by_value(schema):
717
  schema = (schema or "").strip()
718
  for name, value in PRESETS.items():
 
721
  return "custom"
722
 
723
 
724
+ def is_create_table_intent(message):
725
+ message = (message or "").strip().lower()
726
+ return bool(
727
+ re.search(r"\b(create|make|build|generate|criar|crie|cria|gerar|gere|faz|faça)\b", message)
728
+ and re.search(r"\b(table|schema|tabela)\b", message)
729
+ )
730
+
731
+
732
+ def is_table_edit_intent(message):
733
+ message = (message or "").strip().lower()
734
+ edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|delete|deletar|exclua|excluir|novo|nova)\b"
735
+ direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente)\b"
736
+ direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
737
+ target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
738
+ # SQL aggregation keywords that indicate query, not table edit
739
+ sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}
740
+ words = message.split()
741
+ # For add: treat the message as a table edit unless the word right after the verb is an aggregation term
742
+ # "add up the total" is SQL query; "add email and phone" is table edit
743
+ add_match = re.search(direct_add_terms, message)
744
+ has_target = re.search(target_terms, message)
745
+ if add_match:
746
+ # Find position after "add" keyword
747
+ match_pos = add_match.start()
748
+ after_add = message[match_pos + len(add_match.group()):].strip()
749
+ first_word_after = after_add.split()[0] if after_add.split() else ""
750
+ # If first word after "add" is aggregation term, it's SQL query, not edit
751
+ is_sql_query = first_word_after in sql_aggregation_terms
752
+ is_add_intent = not is_sql_query
753
+ else:
754
+ is_add_intent = False
755
+ return bool(
756
+ is_add_intent
757
+ or re.search(direct_remove_terms, message)
758
+ or is_rename_intent(message)
759
+ or re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", message)
760
+ or (re.search(edit_terms, message) and (re.search(target_terms, message) or ":" in message))
761
+ )
762
+
763
+
764
+ def infer_column_type(column_name):
765
+ name = column_name.strip().lower()
766
+ if name == "id" or name.endswith("_id") or name in {"quantity", "quantidade", "stock", "estoque", "year"}:
767
+ return "INTEGER"
768
+ if name in {
769
+ "salary",
770
+ "price",
771
+ "preco",
772
+ "amount",
773
+ "total",
774
+ "grade",
775
+ "peso",
776
+ "weight",
777
+ "idade",
778
+ "age",
779
+ "altura",
780
+ "height",
781
+ "largura",
782
+ "width",
783
+ "comprimento",
784
+ "length",
785
+ "desconto",
786
+ "discount",
787
+ }:
788
+ return "NUMERIC"
789
+ if name in {"date", "created_at", "updated_at"} or name.endswith("_date"):
790
+ return "DATE"
791
+ return "TEXT"
792
+
793
+
794
+ def normalize_identifier(value):
795
+ identifier = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
796
+ if not identifier:
797
+ return ""
798
+ if identifier[0].isdigit():
799
+ identifier = f"col_{identifier}"
800
+ return identifier
801
+
802
+
803
+ def parse_column_definition(raw_column):
804
+ raw_column = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
805
+ raw_column = raw_column.strip(" .;:")
806
+ if not raw_column:
807
+ return None
808
+
809
+ # P2 fix: look for the type as the FINAL token, not the first match
810
+ # "date DATE" must be parsed as name="date", type="DATE", not name="", type="date"
811
+ type_matches = list(
812
+ re.finditer(
813
+ r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b",
814
+ raw_column,
815
+ flags=re.IGNORECASE,
816
+ )
817
+ )
818
+ explicit_type = type_matches[-1] if type_matches else None
819
+ if explicit_type:
820
+ name_part = raw_column[: explicit_type.start()].strip()
821
+ column_type = explicit_type.group(1).upper()
822
+ if column_type == "INT":
823
+ column_type = "INTEGER"
824
+ elif column_type == "BOOL":
825
+ column_type = "BOOLEAN"
826
+ elif column_type == "DECIMAL":
827
+ column_type = "NUMERIC"
828
+ elif column_type in {"FLOAT", "DOUBLE"}:
829
+ column_type = "REAL"
830
+ if not name_part.strip():
831
+ column_type = None
832
+ name_part = raw_column
833
+ else:
834
+ name_part = raw_column
835
+ column_type = None
836
+
837
+ name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
838
+ column_name = normalize_identifier(name_part)
839
+ if not column_name:
840
+ return None
841
+ return column_name, column_type or infer_column_type(column_name)
842
+
843
+
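Reviewer note: the "last type token wins" rule in `parse_column_definition` is the core of the P2 fix. The sketch below isolates it with a shortened type list and a hypothetical function name (`split_name_and_type`). The alias mapping (`INT` to `INTEGER`, and so on) and the identifier normalization from the diff are left out.

```python
import re

# Shortened type list for illustration; the diff accepts more type names.
TYPE_RE = re.compile(r"\b(integer|int|numeric|text|date)\b", re.IGNORECASE)


def split_name_and_type(raw: str):
    """Split 'name TYPE' using the last type token; return (name, type or None)."""
    matches = list(TYPE_RE.finditer(raw))
    if not matches:
        return raw.strip(), None
    last = matches[-1]
    name = raw[: last.start()].strip()
    if not name:
        # The only type-like token is the column name itself, e.g. "date".
        return raw.strip(), None
    return name, last.group(1).upper()


typed = split_name_and_type("date DATE")
bare = split_name_and_type("date")
```

When no explicit type is found, the diff falls back to `infer_column_type`. In this sketch that case is represented by `None`.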
844
+ def split_column_list(columns_text):
845
+ columns_text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
846
+ parts = []
847
+ type_pattern = (
848
+ r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
849
+ )
850
+ type_tokens = {
851
+ "integer",
852
+ "int",
853
+ "numeric",
854
+ "decimal",
855
+ "real",
856
+ "float",
857
+ "double",
858
+ "text",
859
+ "varchar",
860
+ "char",
861
+ "date",
862
+ "datetime",
863
+ "timestamp",
864
+ "boolean",
865
+ "bool",
866
+ }
867
+ STOPWORDS = {
868
+ "to", "from", "into", "as", "for",
869
+ "o", "a", "os", "de", "do", "da", "dos", "das",
870
+ }
871
+ for part in (item.strip() for item in columns_text.split(",") if item.strip()):
872
+ tokens = [token.strip() for token in re.split(r"\s+", part) if token.strip()]
873
+ tokens = [t for t in tokens if t.lower() not in STOPWORDS]
874
+ if not tokens:
875
+ continue
876
+ if re.search(type_pattern, part, flags=re.IGNORECASE) and len(tokens) > 2:
877
+ index = 0
878
+ # Column names that could be confused with SQL types when followed by date/datetime/timestamp
879
+ # These should be treated as column names, not as part of type specification
880
+ inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
881
+ while index < len(tokens):
882
+ current = tokens[index]
883
+ next_token = tokens[index + 1].lower() if index + 1 < len(tokens) else ""
884
+ # If current could be inferred as a different type, don't pair with date/datetime/timestamp
885
+ # This preserves "total date" → "total" (inferred NUMERIC) + "date" (type)
886
+ if next_token in type_tokens and not (current.lower() in inferrable_names and next_token in {"date", "datetime", "timestamp"}):
887
+ parts.append(f"{current} {tokens[index + 1]}")
888
+ index += 2
889
+ else:
890
+ parts.append(current)
891
+ index += 1
892
+ continue
893
+ if re.search(type_pattern, part, flags=re.IGNORECASE):
894
+ parts.append(part)
895
+ continue
896
+ if len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wàáâãçèéêíóôõúÀÁÂÃÇÈÉÊÍÓÔÕÚ]*$", token) for token in tokens):
897
+ parts.extend(tokens)
898
+ else:
899
+ parts.append(part)
900
+ return parts
901
+
902
+
903
+ def format_create_table(table_name, columns):
904
+ if not table_name or not columns:
905
+ return ""
906
+ seen = set()
907
+ column_lines = []
908
+ for column_name, column_type in columns:
909
+ if column_name in seen:
910
+ continue
911
+ seen.add(column_name)
912
+ column_lines.append(f" {column_name} {column_type}")
913
+ if not column_lines:
914
+ return ""
915
+ return f"CREATE TABLE {table_name} (\n" + ",\n".join(column_lines) + "\n);"
916
+
917
+
918
+ def create_table_from_message(message):
919
+ message = (message or "").strip()
920
+ patterns = (
921
+ r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
922
+ r"\b(?:create|make|build|generate|criar|crie|gerar|gere)\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
923
+ )
924
+ for pattern in patterns:
925
+ match = re.search(pattern, message, flags=re.IGNORECASE)
926
+ if not match:
927
+ continue
928
+ table_name = normalize_identifier(match.group(1))
929
+ columns = [
930
+ parsed
931
+ for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2)))
932
+ if parsed
933
+ ]
934
+ return format_create_table(table_name, columns)
935
+ return ""
936
+
937
+
938
+ def parse_create_table_schema(schema):
939
+ schema = (schema or "").strip()
940
+ match = re.match(
941
+ r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
942
+ schema,
943
+ flags=re.IGNORECASE | re.DOTALL,
944
+ )
945
+ if not match:
946
+ return "", []
947
+ table_name = normalize_identifier(match.group(1))
948
+ columns = [
949
+ parsed
950
+ for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2)))
951
+ if parsed
952
+ ]
953
+ return table_name, columns
954
+
955
+
956
+ def create_table_from_schema(schema):
957
+ table_name, columns = parse_create_table_schema(schema)
958
+ return format_create_table(table_name, columns)
959
+
960
+
961
+ def extract_create_table_statement(text):
962
+ cleaned = extract_sql_candidate(text)
963
+ match = re.search(
964
+ r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
965
+ cleaned,
966
+ flags=re.IGNORECASE | re.DOTALL,
967
+ )
968
+ return clean_generation(match.group(0)) if match else ""
969
+
970
+
971
+ def last_create_table_from_history(chat_history):
972
+ for item in reversed(list(chat_history or [])):
973
+ if not isinstance(item, dict) or item.get("role") != "assistant":
974
+ continue
975
+ statement = extract_create_table_statement(item.get("content", ""))
976
+ if statement:
977
+ return statement
978
+ return ""
979
+
980
+
981
+ def extract_added_columns(message):
982
+ message = (message or "").strip()
983
+ patterns = (
984
+ r":\s*(.+)$",
985
+ r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
986
+ )
987
+ for pattern in patterns:
988
+ match = re.search(pattern, message, flags=re.IGNORECASE)
989
+ if not match:
990
+ continue
991
+ columns = [
992
+ parsed
993
+ for parsed in (parse_column_definition(column) for column in split_column_list(match.group(1)))
994
+ if parsed
995
+ ]
996
+ if columns:
997
+ return columns
998
+ return []
999
+
1000
+
1001
+ def extract_removed_columns(message):
1002
+ message = (message or "").strip()
1003
+ patterns = (
1004
+ r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
1005
+ )
1006
+ for pattern in patterns:
1007
+ match = re.search(pattern, message, flags=re.IGNORECASE)
1008
+ if not match:
1009
+ continue
1010
+ columns = [normalize_identifier(column) for column in split_column_list(match.group(1))]
1011
+ columns = [column for column in columns if column]
1012
+ if columns:
1013
+ return columns
1014
+ return []
1015
+
1016
+
1017
+ def is_rename_intent(message):
1018
+ message = (message or "").strip().lower()
1019
+ return bool(
1020
+ re.search(
1021
+ r"\b(rename|edit|change|renomeie|renomear|altere|mude)\s+\w+\s+(to|para|as|como)\s+\w+",
1022
+ message,
1023
+ flags=re.IGNORECASE,
1024
+ )
1025
+ )
1026
+
1027
+
1028
+ def extract_renamed_columns(message):
1029
+ pattern = (
1030
+ r"\b(?:rename|edit|change|renomeie|renomear|altere|mude)\s+"
1031
+ r"(\w+)\s+(?:to|para|as|como)\s+(\w+)"
1032
+ )
1033
+ matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
1034
+ return [
1035
+ (normalize_identifier(old), normalize_identifier(new))
1036
+ for old, new in matches
1037
+ if normalize_identifier(old) and normalize_identifier(new)
1038
+ ]
1039
+
1040
+
1041
+ def parse_compound_edit(message):
1042
+ """Split a compound prompt into segments and extract add/remove/rename edits."""
1043
+ segment_pattern = (
1044
+ r"\s+(?:and|e)\s+"
1045
+ r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
1046
+ r"adicione|adicionar|inclua|acrescente|remova|remover|deletar|"
1047
+ r"exclua|renomeie|renomear|altere|mude)\b)"
1048
+ )
1049
+ segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
1050
+
1051
+ added, removed, renamed = [], [], []
1052
+ for seg in segments:
1053
+ seg = seg.strip()
1054
+ if not seg:
1055
+ continue
1056
+ if is_rename_intent(seg):
1057
+ renamed.extend(extract_renamed_columns(seg))
1058
+ elif re.search(
1059
+ r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b",
1060
+ seg,
1061
+ flags=re.IGNORECASE,
1062
+ ):
1063
+ removed.extend(extract_removed_columns(seg))
1064
+ else:
1065
+ cols = extract_added_columns(seg)
1066
+ if cols:
1067
+ added.extend(cols)
1068
+ return added, removed, renamed
1069
+
1070
+
1071
+ def edit_create_table_from_message(message, chat_history, active_schema):
1072
+ if not is_table_edit_intent(message) and not is_rename_intent(message):
1073
+ return ""
1074
+ base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
1075
+ table_name, existing_columns = parse_create_table_schema(base_sql)
1076
+ if not table_name:
1077
+ return ""
1078
+
1079
+ added_columns, removed_columns_list, renamed_columns = parse_compound_edit(message)
1080
+ removed_set = set(extract_removed_columns(message)) | set(removed_columns_list)
1081
+
1082
+ if not added_columns and not removed_set and not renamed_columns:
1083
+ return ""
1084
+
1085
+ rename_map = dict(renamed_columns)
1086
+ kept_columns = [
1087
+ (rename_map.get(col_name, col_name), col_type)
1088
+ for col_name, col_type in existing_columns
1089
+ if col_name not in removed_set
1090
+ ]
1091
+ return format_create_table(table_name, [*kept_columns, *added_columns])
1092
+
1093
+
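Reviewer note: the last step of `edit_create_table_from_message` drops removed columns, applies the rename map to the survivors, and appends the additions. A minimal sketch of that composition, assuming a hypothetical helper `apply_edits` that works on plain `(name, type)` tuples:

```python
def apply_edits(existing, added, removed, renamed):
    """Drop removed columns, rename the rest, then append the added columns."""
    rename_map = dict(renamed)
    removed_set = set(removed)
    kept = [
        (rename_map.get(name, name), col_type)
        for name, col_type in existing
        # Removal is checked against the original name, before renaming.
        if name not in removed_set
    ]
    return [*kept, *added]


result = apply_edits(
    existing=[("id", "INTEGER"), ("name", "TEXT"), ("price", "NUMERIC")],
    added=[("email", "TEXT")],
    removed=["price"],
    renamed=[("name", "title")],
)
```

Since removal is tested on the pre-rename name, a message that both renames and removes the same column ends up removing it.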
1094
  def render_schema_context(schema=""):
1095
  schema = (schema or "").strip()
1096
  if not schema:
 
1108
 
1109
  def query_control_updates(can_generate):
1110
  context_updates = [gr.update(interactive=True) for _ in range(6)]
1111
+ # Keep submit button enabled - model requirement is checked in generate_response
1112
+ return [*context_updates, gr.update(interactive=True), gr.update(interactive=True)]
1113
 
1114
 
1115
  def render_message(message="", kind="error"):
 
1134
  )
1135
 
1136
 
1137
+ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
1138
+ selected_key = FINE_TUNED_MODEL_KEY
1139
  model_def = model_by_key(selected_key)
1140
+ print(
1141
+ f"[LOAD_REQUEST] selected_key={selected_key} model_id={model_def['model_id']}",
1142
+ flush=True,
1143
+ )
1144
  yield (
1145
  None,
1146
  render_status(selected_key, None, state="loading"),
 
1156
  )
1157
  started = time.time()
1158
  try:
1159
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
1160
+ future = executor.submit(_run_model_load, model_def["model_id"])
1161
+ try:
1162
+ result = future.result(timeout=LOAD_TIMEOUT_SECONDS)
1163
+ except concurrent.futures.TimeoutError:
1164
+ # Timeout reached but cannot truly cancel a running thread.
1165
+ # Wait for the operation to complete naturally to avoid race conditions.
1166
+ # The UI stays in loading state until the operation finishes.
1167
+ result = future.result()
1168
+ print(f"[LOAD] Completed after timeout warning ({int(time.time() - started)}s)", flush=True)
1169
+ finally:
1170
+ executor.shutdown(wait=False, cancel_futures=True)
1171
  except Exception as exc:
1172
  error = f"Load failed for {model_def['model_id']}: {type(exc).__name__}: {exc}"
1173
+ print(f"[LOAD_ERROR] {error}", flush=True)
1174
+ traceback.print_exc()
1175
  yield (
1176
  None,
1177
  render_status(selected_key, None),
1178
  render_loading_overlay(visible=False),
1179
  model_metadata(selected_key),
1180
+ gr.update(interactive=True, visible=True),
1181
  *query_control_updates(False),
1182
  "",
1183
  EMPTY_VALIDATOR,
 
1193
  render_status(selected_key, selected_key),
1194
  render_loading_overlay(visible=False),
1195
  model_metadata(selected_key),
1196
+ gr.update(interactive=True, visible=True, value="Load fine-tuned model"),
1197
  *query_control_updates(True),
1198
  "",
1199
  EMPTY_VALIDATOR,
 
1239
  )
1240
 
1241
 
1242
+ def deterministic_response(
1243
+ chat_history,
1244
+ message,
1245
+ active_schema,
1246
+ loaded_key,
1247
+ saved_state,
1248
+ assistant_content,
1249
+ status_message,
1250
+ *,
1251
+ sql_text="",
1252
+ validator=CHAT_VALIDATOR,
1253
+ status_kind="ok",
1254
+ ):
1255
+ new_history = trim_chat_history(
1256
+ [
1257
+ *list(chat_history or []),
1258
+ {"role": "user", "content": message},
1259
+ {"role": "assistant", "content": assistant_content},
1260
+ ]
1261
+ )
1262
+ # If sql_text is a CREATE TABLE, promote it to active_schema for subsequent queries
1263
+ new_schema = active_schema
1264
+ if sql_text and "CREATE TABLE" in sql_text.upper():
1265
+ new_schema = sql_text
1266
+ compare = comparison_updates(saved_state, sql_text, loaded_key)
1267
+ return (
1268
+ new_history,
1269
+ "",
1270
+ new_schema,
1271
+ message,
1272
+ sql_text,
1273
+ validator,
1274
+ gr.update(interactive=False, visible=False),
1275
+ render_message(status_message, kind=status_kind),
1276
+ *compare,
1277
+ )
1278
+
1279
+
1280
  def generate_response(message, chat_history, active_schema, loaded_key, saved_state):
1281
  message = (message or "").strip()
1282
  active_schema = (active_schema or "").strip()
1283
  chat_history = list(chat_history or [])
1284
+ if not message:
1285
+ compare = comparison_updates(saved_state, "", loaded_key)
1286
+ return (
1287
+ chat_history,
1288
+ "",
1289
+ active_schema,
1290
+ "",
1291
+ "",
1292
+ EMPTY_VALIDATOR,
1293
+ gr.update(interactive=False, visible=False),
1294
+ render_message("Type a message before sending."),
1295
+ *compare,
1296
+ )
1297
+
1298
+ # Routing debug log — shows which intent matched
1299
+ _routing = []
1300
+ edited_table = edit_create_table_from_message(message, chat_history, active_schema)
1301
+ if edited_table:
1302
+ _routing.append("edit_create_table")
1303
+ elif is_table_edit_intent(message):
1304
+ _routing.append("is_table_edit_intent")
1305
+ elif is_create_table_intent(message):
1306
+ _routing.append("is_create_table_intent")
1307
+ elif is_sql_intent(message, active_schema):
1308
+ _routing.append("is_sql_intent")
1309
+ else:
1310
+ _routing.append("no_match")
1311
+ print(f"[ROUTING] \"{message[:60]}\" → {_routing}")
1312
+
1313
+ if edited_table:
1314
+ display_response = f"```sql\n{edited_table}\n```"
1315
+ return deterministic_response(
1316
+ chat_history,
1317
+ message,
1318
+ active_schema,
1319
+ loaded_key,
1320
+ saved_state,
1321
+ display_response,
1322
+ "Edited CREATE TABLE without calling the model.",
1323
+ sql_text=edited_table,
1324
+ validator=validate_sql(edited_table),
1325
+ )
1326
+ if is_table_edit_intent(message):
1327
  compare = comparison_updates(saved_state, "", loaded_key)
1328
  return (
1329
  chat_history,
 
1333
  "",
1334
  EMPTY_VALIDATOR,
1335
  gr.update(interactive=False, visible=False),
1336
+ render_message("I need an existing CREATE TABLE in the chat or an active schema before editing columns."),
1337
  *compare,
1338
  )
1339
+
1340
+ if is_create_table_intent(message):
1341
+ sql_text = create_table_from_message(message) or create_table_from_schema(active_schema)
1342
+ if sql_text:
1343
+ display_response = f"```sql\n{sql_text}\n```"
1344
+ return deterministic_response(
1345
+ chat_history,
1346
+ message,
1347
+ active_schema,
1348
+ loaded_key,
1349
+ saved_state,
1350
+ display_response,
1351
+ "Generated CREATE TABLE without calling the model.",
1352
+ sql_text=sql_text,
1353
+ validator=validate_sql(sql_text),
1354
+ )
1355
  compare = comparison_updates(saved_state, "", loaded_key)
1356
  return (
1357
  chat_history,
1358
+ message,
1359
+ active_schema,
1360
+ "",
1361
  "",
1362
+ EMPTY_VALIDATOR,
1363
+ gr.update(interactive=False, visible=False),
1364
+ render_message("CREATE TABLE needs a table name and columns, or an active schema context."),
1365
+ *compare,
1366
+ )
1367
+
1368
+
1369
+ if not is_sql_intent(message, active_schema):
1370
+ fallback = safe_chat_fallback()
1371
+ return deterministic_response(
1372
+ chat_history,
1373
+ message,
1374
+ active_schema,
1375
+ loaded_key,
1376
+ saved_state,
1377
+ fallback,
1378
+ "No SQL intent or active schema detected.",
1379
+ )
1380
+
1381
+ if not loaded_key or _model is None or _tokenizer is None:
1382
+ compare = comparison_updates(saved_state, "", loaded_key)
1383
+ return (
1384
+ chat_history,
1385
+ message,
1386
  active_schema,
1387
  "",
1388
  "",
1389
  EMPTY_VALIDATOR,
1390
  gr.update(interactive=False, visible=False),
1391
+ render_message("Load a model before generating SQL."),
1392
  *compare,
1393
  )
1394
 
 
1409
 
1410
  started = time.time()
1411
  try:
1412
+ import_model_runtime()
1413
  with _model_lock:
1414
+ prompt = build_generation_prompt(active_schema, message, chat_history)
1415
  inputs = _tokenizer(prompt, return_tensors="pt")
1416
  input_length = inputs["input_ids"].shape[-1]
1417
+ gen_kwargs = {
1418
+ "max_new_tokens": 80,
1419
+ "max_time": GENERATION_MAX_TIME_SECONDS,
1420
+ "do_sample": False,
1421
+ "use_cache": False,
1422
+ "repetition_penalty": 1.1,
1423
+ "eos_token_id": getattr(_model.generation_config, "eos_token_id", _tokenizer.eos_token_id),
1424
+ "pad_token_id": _tokenizer.pad_token_id or _tokenizer.eos_token_id,
1425
+ }
1426
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
1427
+ future = executor.submit(_run_generation, _model, inputs, gen_kwargs)
1428
+ try:
1429
+ output_ids = future.result(timeout=GENERATION_TIMEOUT_SECONDS)
1430
+ except concurrent.futures.TimeoutError:
1431
+ # Timeout reached - do NOT call future.result() without timeout as it can block indefinitely.
1432
+ # The thread may continue in background but we won't wait for it.
1433
+ # Return error to user and release the slot.
1434
+ executor.shutdown(wait=False, cancel_futures=False)
1435
+ raise TimeoutError(f"Generation timed out after {GENERATION_TIMEOUT_SECONDS}s")
1436
+ finally:
1437
+ executor.shutdown(wait=False, cancel_futures=True)
1438
  generated_ids = output_ids[0][input_length:]
1439
+ generated_text = _tokenizer.decode(generated_ids, skip_special_tokens=True)
1440
  except Exception as exc:
1441
  compare = comparison_updates(saved_state, "", loaded_key)
1442
  return (
 
1470
  message,
1471
  str(sql_text),
1472
  validator,
1473
+ gr.update(interactive=False, visible=False),
1474
  render_message(f"Generated {response_kind} with {model_def['model_id']} in {elapsed}s.", kind="ok"),
1475
  *compare,
1476
  )
 
1510
  )
1511
 
1512
 
1513
+ def sync_on_load():
1514
+ if _model is not None and _current_model_id is not None:
1515
+ loaded_key = model_key_by_id(_current_model_id)
1516
+ if loaded_key:
1517
+ return (
1518
+ loaded_key,
1519
+ render_status(loaded_key, loaded_key),
1520
+ render_loading_overlay(visible=False),
1521
+ model_metadata(loaded_key),
1522
+ gr.update(interactive=True, visible=True, value="Load fine-tuned model"),
1523
+ *query_control_updates(True),
1524
+ "",
1525
+ EMPTY_VALIDATOR,
1526
+ gr.update(interactive=False, visible=False),
1527
+ render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
1528
+ gr.update(visible=False),
1529
+ )
1530
+ return (
1531
+ None,
1532
+ render_status(DEFAULT_MODEL_KEY, None),
1533
+ render_loading_overlay(visible=False),
1534
+ model_metadata(DEFAULT_MODEL_KEY),
1535
+ gr.update(interactive=True, visible=True),
1536
+ *query_control_updates(False),
1537
+ "",
1538
+ EMPTY_VALIDATOR,
1539
+ gr.update(interactive=False, visible=False),
1540
+ render_message(),
1541
+ gr.update(visible=False),
1542
+ )
1543
+
1544
+
1545
  CSS = """
1546
  @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;500;700&display=swap');
1547
 
1548
+ /* Prevent Gradio dark theme from overriding text in light-bg components */
1549
+ [class*="badge"],
1550
+ [class*="validator-"],
1551
+ [class*="compare-head"],
1552
+ [class*="model-tag"],
1553
+ [class*="stat-card"] {
1554
+ color: inherit !important;
1555
+ }
1556
+
1557
  :root {
1558
  --bg-base: #0c0c0b;
1559
  --bg-surface: #1a1a18;
 
1644
  .badge-green,
1645
  .validator-ok {
1646
  background: var(--teal-soft);
1647
+ color: var(--teal-text) !important;
1648
  }
1649
 
1650
  .badge-cream,
1651
  .validator-warn {
1652
  background: var(--amber-soft);
1653
+ color: var(--amber-text) !important;
1654
  }
1655
 
1656
  .badge-light,
1657
  .validator-empty {
1658
  background: var(--bg-raised);
1659
+ color: var(--text-secondary) !important;
1660
  border: 0.5px solid var(--border);
1661
  }
1662
 
 
1692
  background: var(--bg-surface);
1693
  border: 0.5px solid var(--border);
1694
  border-radius: 6px;
 
1695
  min-height: 176px;
1696
  padding: 16px;
1697
  transition: border-color 160ms ease, background 160ms ease;
1698
  }
1699
 
 
 
 
 
1700
  .model-card.selected {
1701
  border: 1.5px solid var(--teal);
1702
  }
1703
 
1704
  .model-tag {
1705
  background: var(--amber-soft);
1706
+ color: var(--amber-text) !important;
1707
  margin-bottom: 18px;
1708
  }
1709
 
1710
  .model-card.selected .model-tag {
1711
  background: var(--teal-soft);
1712
+ color: var(--teal-text) !important;
1713
  }
1714
 
1715
  .model-card h3 {
 
1764
  display: flex;
1765
  }
1766
 
1767
+ .evidence-panel {
1768
+ background: var(--bg-surface);
1769
+ border: 0.5px solid var(--border);
1770
+ border-radius: 6px;
1771
+ margin-top: 12px;
1772
+ padding: 16px;
1773
+ }
1774
+
1775
+ .evidence-copy h2 {
1776
+ color: var(--text-primary);
1777
+ font-size: 13px;
1778
+ font-weight: 500;
1779
+ line-height: 1.3;
1780
+ margin: 0 0 6px;
1781
+ }
1782
+
1783
+ .evidence-copy p {
1784
+ color: var(--text-secondary);
1785
+ font-size: 12px;
1786
+ line-height: 1.45;
1787
+ margin: 0;
1788
+ }
1789
+
1790
+ .evidence-grid {
1791
+ display: grid;
1792
+ gap: 8px;
1793
+ grid-template-columns: repeat(3, minmax(0, 1fr));
1794
+ margin-top: 14px;
1795
+ }
1796
+
1797
+ .evidence-card {
1798
+ background: var(--bg-raised);
1799
+ border: 0.5px solid var(--border);
1800
+ border-radius: 6px;
1801
+ padding: 10px;
1802
+ }
1803
+
1804
+ .evidence-card.highlighted {
1805
+ border-color: rgba(29, 158, 117, 0.5);
1806
+ }
1807
+
1808
+ .evidence-card span,
1809
+ .evidence-card small {
1810
+ color: var(--text-secondary);
1811
+ display: block;
1812
+ font-size: 10px;
1813
+ line-height: 1.25;
1814
+ }
1815
+
1816
+ .evidence-card strong {
1817
+ color: var(--text-primary);
1818
+ display: block;
1819
+ font-size: 20px;
1820
+ font-weight: 500;
1821
+ line-height: 1.1;
1822
+ margin: 5px 0;
1823
+ }
1824
+
1825
  #load-button,
1826
  #generate-button,
1827
  #save-button {
 
1846
  width: 100% !important;
1847
  }
1848
 
1849
+ #generate-button button {
1850
+ height: 42px !important;
1851
+ min-height: 42px !important;
1852
+ }
1853
+
1854
  #load-button button:hover,
1855
  #generate-button button:hover {
1856
  background: var(--text-primary) !important;
 
1914
  }
1915
 
1916
  .stat-card strong {
1917
+ color: var(--text-primary) !important;
1918
  display: block;
1919
  font-size: 15px;
1920
  font-weight: 500;
 
1923
  }
1924
 
1925
  .stat-card span {
1926
+ color: var(--text-secondary) !important;
1927
  display: block;
1928
  font-size: 11px;
1929
  font-weight: 400;
 
2008
  }
2009
 
2010
  .composer-row {
2011
+ align-items: flex-end !important;
2012
+ display: flex !important;
2013
  gap: 8px !important;
2014
  }
2015
 
2016
+ .composer-row > div {
2017
+ display: flex !important;
2018
+ flex-direction: column !important;
2019
+ justify-content: flex-end !important;
2020
+ }
2021
+
2022
  #message-input {
2023
  flex: 1 1 auto;
2024
  }
2025
 
2026
  #message-input textarea {
2027
  min-height: 42px !important;
2028
+ max-height: 120px !important;
2029
+ height: 42px !important;
2030
+ resize: none !important;
2031
+ overflow-y: auto !important;
2032
+ }
2033
+
2034
+ #generate-button {
2035
+ align-self: flex-end !important;
2036
+ margin-bottom: 0 !important;
2037
  }
2038
 
2039
  #clear-schema-button button {
 
2139
  }
2140
 
2141
  .validator-detail {
2142
+ color: var(--text-secondary) !important;
2143
  font-size: 11px;
2144
  margin-left: 8px;
2145
  }
 
2189
  .compare-head {
2190
  align-items: center;
2191
  background: var(--amber-soft);
2192
+ color: var(--amber-text) !important;
2193
  display: flex;
2194
  font-size: 11px;
2195
  font-weight: 500;
 
2202
  .compare-card.current .compare-head,
2203
  .current-compare-head .compare-head {
2204
  background: var(--teal-soft);
2205
+ color: var(--teal-text) !important;
2206
  }
2207
 
2208
  .compare-head strong {
 
2277
  @media (max-width: 860px) {
2278
  .top-panel,
2279
  .model-grid,
2280
+ .compare-grid,
2281
+ .evidence-grid {
2282
  grid-template-columns: 1fr;
2283
  }
2284
 
 
2296
  }
2297
  """
2298
 
2299
+ with gr.Blocks(title="Phi-3 Mini SQL Generator") as demo:
 
2300
  loaded_key_state = gr.State(value=None)
2301
  saved_output = gr.State(value=None)
2302
  active_schema = gr.State(value="")
 
2308
 
2309
  gr.HTML(render_step("01", "Model"))
2310
  with gr.Row(elem_classes=["model-grid"]):
 
2311
  fine_tuned_model_card = gr.HTML(render_model_card(FINE_TUNED_MODEL_KEY, DEFAULT_MODEL_KEY))
2312
+ load_button = gr.Button("Load fine-tuned model", variant="primary", elem_id="load-button")
2313
  model_status = gr.HTML(render_status(DEFAULT_MODEL_KEY, None))
2314
  model_info = gr.HTML(model_metadata(DEFAULT_MODEL_KEY))
2315
+ gr.HTML(render_baseline_evidence())
2316
 
2317
  with gr.Column(elem_id="query-section", elem_classes=["query-section"]):
2318
  gr.HTML(render_step("02", "Chat"))
 
2367
  show_label=False,
2368
  )
2369
  save_button = gr.Button(
2370
+ "Save output",
2371
  interactive=False,
2372
  visible=False,
2373
  elem_id="save-button",
 
2384
  current_sql = gr.Code(label="", language="sql", lines=6, show_label=False)
2385
 
2386
  model_state_outputs = [
 
 
2387
  fine_tuned_model_card,
2388
  model_status,
2389
  model_info,
 
2398
  save_button,
2399
  error_output,
2400
  ]
 
 
 
 
 
 
 
 
 
 
2401
 
2402
  load_button.click(
2403
  load_selected_model,
2404
+ inputs=None,
2405
  outputs=[
2406
  loaded_key_state,
2407
  model_status,
 
2472
  error_output,
2473
  ],
2474
  )
2475
+ demo.load(
2476
+ sync_on_load,
2477
+ outputs=[
2478
+ loaded_key_state,
2479
+ model_status,
2480
+ loading_overlay,
2481
+ model_info,
2482
+ load_button,
2483
+ employees_preset,
2484
+ orders_preset,
2485
+ students_preset,
2486
+ products_preset,
2487
+ sales_preset,
2488
+ clear_schema_button,
2489
+ message_input,
2490
+ send_button,
2491
+ sql_output,
2492
+ validator_output,
2493
+ save_button,
2494
+ error_output,
2495
+ comparison_panel,
2496
+ ],
2497
+ )
2498
 
2499
  queue_kwargs = {}
2500
  if "default_concurrency_limit" in inspect.signature(demo.queue).parameters:
 
2503
 
2504
 
2505
  if __name__ == "__main__":
2506
+ demo.launch(css=CSS)
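The timeout handling added in `app.py` raises `TimeoutError` and then calls `executor.shutdown(wait=False, cancel_futures=True)`. The surrounding lines are collapsed in this diff, so the following is only a minimal sketch of that pattern under stated assumptions: `run_with_timeout` and `slow_task` are hypothetical names, and a single-worker `ThreadPoolExecutor` is assumed rather than confirmed from the source.

```python
import concurrent.futures
import time


def run_with_timeout(fn, timeout_seconds, *args, **kwargs):
    """Run fn in a worker thread; raise TimeoutError if it exceeds the budget.

    Hypothetical helper, not the function used in app.py.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fn, *args, **kwargs)
        try:
            return future.result(timeout=timeout_seconds)
        except concurrent.futures.TimeoutError:
            # Re-raise as the built-in TimeoutError with a readable message.
            raise TimeoutError(f"Generation timed out after {timeout_seconds}s") from None
    finally:
        # Do not block on the worker; drop any tasks that have not started yet.
        executor.shutdown(wait=False, cancel_futures=True)


def slow_task(seconds):
    """Stand-in for a long generation call (assumption for illustration)."""
    time.sleep(seconds)
    return "done"
```

A worker thread cannot be killed from outside, so a call that times out keeps running in the background until it finishes. `cancel_futures=True` (Python 3.9+) only drops tasks that have not started.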
requirements.txt CHANGED
@@ -2,6 +2,6 @@ transformers>=4.44.0
  peft>=0.11.0
  accelerate>=0.30.0
  torch
- gradio>=4.0.0
+ gradio>=6.0.0
  sqlparse
  huggingface_hub
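Because the Gradio pin changes here, version-dependent keyword arguments are worth guarding. `app.py` already does this for `demo.queue` by checking `inspect.signature` for `default_concurrency_limit`. A minimal sketch of that feature-detection idea follows; `supported_kwargs`, `queue_old`, and `queue_new` are hypothetical stand-ins for illustration, not Gradio APIs.

```python
import inspect


def supported_kwargs(func, **candidates):
    """Return only the candidate kwargs that func's signature accepts."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter means every keyword is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    return {name: value for name, value in candidates.items() if name in params}


# Hypothetical stand-ins for two versions of a queue() API.
def queue_old(concurrency_count=1):
    return {"concurrency_count": concurrency_count}


def queue_new(default_concurrency_limit=None):
    return {"default_concurrency_limit": default_concurrency_limit}


# The same call site works against either version.
for queue in (queue_old, queue_new):
    queue(**supported_kwargs(queue, default_concurrency_limit=1))
```

Filtering by signature keeps one call site working across library versions without parsing version strings.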
tests/e2e_flow_test.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ End-to-end flow tests for phi3-mini-sql-generator demo.
3
+ Run with: python tests/e2e_flow_test.py
4
+
5
+ Model must be loaded first. Call app.load_model(app.FINE_TUNED_MODEL_ID)
6
+ before running these tests.
7
+ """
8
+
9
+ import app
10
+ import types
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Helpers
14
+ # ---------------------------------------------------------------------------
15
+
16
+ def sql_out(result):
17
+ return result[4]
18
+
19
+ def status(result):
20
+ return result[7]
21
+
22
+ def reset_model_state():
23
+ app._model = None
24
+ app._tokenizer = None
25
+ app._current_model_id = None
26
+
27
+
28
+ def check_sql(result, expected_fragments, description):
29
+ """Print and assert SQL output checks."""
30
+ sql = sql_out(result)
31
+ status_msg = status(result)
32
+ ok = True
33
+ for frag in expected_fragments:
34
+ if frag not in sql:
35
+ print(f" FAIL: missing '{frag}' in output")
36
+ ok = False
37
+ if ok:
38
+ print(f" OK: {description}")
39
+ print(f" SQL: {sql[:200]}")
40
+ return ok
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Scenario 1: Parser still works (no model call)
45
+ # ---------------------------------------------------------------------------
46
+
47
+ def test_scenario1_parser_keeps_working():
48
+ print("\n=== Scenario 1: Parser — accented columns ===")
49
+ result = app.generate_response(
50
+ "criar tabela animal com nome nome cientifico e especie",
51
+ [], "", None, None
52
+ )
53
+ fragments = ["CREATE TABLE animal", "nome TEXT", "cientifico TEXT", "especie TEXT"]
54
+ return check_sql(result, fragments, "3 columns from Portuguese message")
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Scenario 2: SELECT all
59
+ # ---------------------------------------------------------------------------
60
+
61
+ def test_scenario2_select_all():
62
+ print("\n=== Scenario 2: SELECT all rows ===")
63
+ schema = app.PRESETS["employees"]
64
+ result = app.generate_response(
65
+ "liste todos os funcionarios",
66
+ [], schema, app.FINE_TUNED_MODEL_KEY, None
67
+ )
68
+ sql = sql_out(result)
69
+ status_msg = status(result)
70
+ ok = True
71
+ if "SELECT" not in sql.upper():
72
+ print(f" FAIL: no SELECT in output")
73
+ ok = False
74
+ if "FROM" not in sql.upper():
75
+ print(f" FAIL: no FROM in output")
76
+ ok = False
77
+ if ok:
78
+ print(f" OK: generated SELECT")
79
+ print(f" SQL: {sql}")
80
+ return ok
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Scenario 3: SELECT with WHERE filter
85
+ # ---------------------------------------------------------------------------
86
+
87
+ def test_scenario3_select_with_filter():
88
+ print("\n=== Scenario 3: SELECT with WHERE ===")
89
+ schema = app.PRESETS["employees"]
90
+ result = app.generate_response(
91
+ "mostre os funcionarios do departamento de vendas",
92
+ [], schema, app.FINE_TUNED_MODEL_KEY, None
93
+ )
94
+ sql = sql_out(result)
95
+ ok = True
96
+ if "SELECT" not in sql.upper():
97
+ print(f" FAIL: no SELECT")
98
+ ok = False
99
+ if "WHERE" not in sql.upper():
100
+ print(f" FAIL: no WHERE")
101
+ ok = False
102
+ if "department" in sql.lower() or "vendas" in sql.lower():
103
+ print(f" OK: WHERE clause present")
104
+ print(f" SQL: {sql}")
105
+ else:
106
+ print(f" FAIL: filter condition missing")
107
+ ok = False
108
+ return ok
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Scenario 4: Aggregate (COUNT, AVG, GROUP BY)
113
+ # ---------------------------------------------------------------------------
114
+
115
+ def test_scenario4_aggregates():
116
+ print("\n=== Scenario 4: Aggregate query ===")
117
+ schema = app.PRESETS["employees"]
118
+ result = app.generate_response(
119
+ "qual a media de salarios por departamento",
120
+ [], schema, app.FINE_TUNED_MODEL_KEY, None
121
+ )
122
+ sql = sql_out(result)
123
+ ok = True
124
+ checks = ["SELECT", "AVG", "GROUP BY"]
125
+ for c in checks:
126
+ if c not in sql.upper():
127
+ print(f" FAIL: missing '{c}'")
128
+ ok = False
129
+ if ok:
130
+ print(f" OK: aggregate query generated")
131
+ print(f" SQL: {sql}")
132
+ return ok
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # Scenario 5: Natural language SQL (Issue 3)
137
+ # ---------------------------------------------------------------------------
138
+
139
+ def test_scenario5_natural_language():
140
+ print("\n=== Scenario 5: Natural language SQL (Issue 3) ===")
141
+ schema = app.PRESETS["products"]
142
+ result = app.generate_response(
143
+ "me diz qual o produto mais caro",
144
+ [], schema, app.FINE_TUNED_MODEL_KEY, None
145
+ )
146
+ sql = sql_out(result)
147
+ status_msg = status(result)
148
+ ok = True
149
+ if not sql.strip():
150
+ print(f" FAIL: no SQL generated — model returned: {status_msg[:100]}")
151
+ ok = False
152
+ elif "SELECT" not in sql.upper():
153
+ print(f" FAIL: output is not SQL: {sql[:100]}")
154
+ ok = False
155
+ else:
156
+ print(f" OK: natural language produced SQL")
157
+ print(f" SQL: {sql}")
158
+ return ok
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Scenario 6: Multi-turn flow (create → add → remove → query)
163
+ # ---------------------------------------------------------------------------
164
+
165
+ def test_scenario6_multiturn_flow():
166
+ print("\n=== Scenario 6: Multi-turn schema build + query ===")
167
+ ok = True
168
+
169
+ # Step 1: Create table
170
+ r1 = app.generate_response(
171
+ "crie tabela vendas com id produto quantidade total",
172
+ [], "", None, None
173
+ )
174
+ if not check_sql(r1, ["CREATE TABLE vendas", "id INTEGER", "produto TEXT", "quantidade INTEGER", "total NUMERIC"], "Step 1: CREATE TABLE"):
175
+ ok = False
176
+
177
+ # Step 2: Add column
178
+ r2 = app.generate_response("adicione desconto", r1[0], "", None, None)
179
+ if not check_sql(r2, ["desconto NUMERIC", "CREATE TABLE vendas"], "Step 2: ADD COLUMN"):
180
+ ok = False
181
+
182
+ # Step 3: Remove column
183
+ r3 = app.generate_response("remova quantidade", r2[0], "", None, None)
184
+ sql3 = sql_out(r3)
185
+ # CORRECT: quantidade should NOT be in SQL (it was removed)
186
+ if "quantidade" in sql3:
187
+ print(f" FAIL: 'quantidade' still in table after remove (regression)")
188
+ ok = False
189
+ else:
190
+ print(f" OK: Step 3: REMOVE COLUMN - 'quantidade' removed")
191
+ # Verify remaining columns still exist
192
+ for col in ["id", "produto", "desconto", "total"]:
193
+ if col not in sql3:
194
+ print(f" FAIL: column '{col}' missing after remove")
195
+ ok = False
196
+
197
+ # Step 4: Query (model call)
198
+ final_schema = sql_out(r3)
199
+ r4 = app.generate_response(
200
+ "quanto vendemos no total",
201
+ r3[0], final_schema, app.FINE_TUNED_MODEL_KEY, None
202
+ )
203
+ sql4 = sql_out(r4)
204
+ if "SELECT" not in sql4.upper():
205
+ print(f" FAIL: Step 4 no SELECT generated. Status: {status(r4)[:100]}")
206
+ ok = False
207
+ else:
208
+ print(f" OK: Step 4: model generated SQL from multi-turn context")
209
+ print(f" SQL: {sql4}")
210
+
211
+ return ok
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # Run all
216
+ # ---------------------------------------------------------------------------
217
+
218
+ def run_all():
219
+ if app._model is None:
220
+ print("ERROR: model not loaded. Run app.load_model(app.FINE_TUNED_MODEL_ID) first.")
221
+ return
222
+
223
+ results = {}
224
+ results["s1_parser"] = test_scenario1_parser_keeps_working()
225
+ results["s2_select_all"] = test_scenario2_select_all()
226
+ results["s3_where"] = test_scenario3_select_with_filter()
227
+ results["s4_aggregates"] = test_scenario4_aggregates()
228
+ results["s5_natlang"] = test_scenario5_natural_language()
229
+ results["s6_multiturn"] = test_scenario6_multiturn_flow()
230
+
231
+ print("\n" + "=" * 50)
232
+ print("SUMMARY")
233
+ print("=" * 50)
234
+ passed = sum(1 for v in results.values() if v)
235
+ total = len(results)
236
+ for name, result in results.items():
237
+ mark = "PASS" if result else "FAIL"
238
+ print(f" {mark} {name}")
239
+ print(f"\n Total: {passed}/{total} passed")
240
+
241
+ return passed == total
242
+
243
+
244
+ if __name__ == "__main__":
245
+ # Check model loaded
246
+ if app._model is None:
247
+ print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
248
+ print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
249
+ else:
250
+ run_all()
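The scenario functions above repeat the same "is this fragment in the SQL" loop, sometimes case-sensitively (`check_sql`) and sometimes against `sql.upper()`. As a hedged sketch only, that check could be factored into a pure helper that returns the missing fragments; `missing_fragments` is a hypothetical name and is not part of this commit.

```python
def missing_fragments(sql, expected_fragments, case_sensitive=True):
    """Return the expected fragments that do not appear in sql, in order."""
    haystack = sql if case_sensitive else sql.upper()
    missing = []
    for fragment in expected_fragments:
        needle = fragment if case_sensitive else fragment.upper()
        if needle not in haystack:
            missing.append(fragment)
    return missing


# Example: one fragment present, one absent.
example = missing_fragments(
    "CREATE TABLE vendas (id INTEGER)",
    ["CREATE TABLE vendas", "total NUMERIC"],
)
```

Returning the list, rather than printing inside the loop, lets a caller decide whether to print, assert, or aggregate results.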
tests/test_chatbot_behavior.py ADDED
@@ -0,0 +1,672 @@
1
+ import types
2
+
3
+ import pytest
4
+
5
+ import app
6
+
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Helpers
10
+ # ---------------------------------------------------------------------------
11
+
12
+ def reset_model_state():
13
+ app._model = None
14
+ app._tokenizer = None
15
+ app._current_model_id = None
16
+
17
+
18
+ def assistant_text(result):
19
+ return result[0][-1]["content"]
20
+
21
+
22
+ def sql_output(result):
23
+ return result[4]
24
+
25
+
26
+ def status_html(result):
27
+ return result[7]
28
+
29
+
30
+ @pytest.fixture(autouse=True)
31
+ def clean_model_state():
32
+ reset_model_state()
33
+ yield
34
+ reset_model_state()
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # CREATE TABLE: PT/EN verb forms
39
+ # ---------------------------------------------------------------------------
40
+
41
+ @pytest.mark.parametrize(
42
+ ("message", "expected"),
43
+ [
44
+ (
45
+ "crie tabela pesquisadores com id nome artigo e curriculo",
46
+ ["CREATE TABLE pesquisadores", "id INTEGER", "nome TEXT", "artigo TEXT", "curriculo TEXT"],
47
+ ),
48
+ (
49
+ "cria tabela animal com nome tamanho peso especie",
50
+ ["CREATE TABLE animal", "nome TEXT", "tamanho TEXT", "peso NUMERIC", "especie TEXT"],
51
+ ),
52
+ (
53
+ "faça tabela clientes com id nome email",
54
+ ["CREATE TABLE clientes", "id INTEGER", "nome TEXT", "email TEXT"],
55
+ ),
56
+ (
57
+ "create table researchers with id, name, articles and cv",
58
+ ["CREATE TABLE researchers", "id INTEGER", "name TEXT", "articles TEXT", "cv TEXT"],
59
+ ),
60
+ (
61
+ "crie tabela alunos com id int nome text nota numeric",
62
+ ["CREATE TABLE alunos", "id INTEGER", "nome TEXT", "nota NUMERIC"],
63
+ ),
64
+ (
65
+ "crie tabela pesquisadores com id nome artigo curriculo",
66
+ ["CREATE TABLE pesquisadores", "id INTEGER", "nome TEXT", "artigo TEXT", "curriculo TEXT"],
67
+ ),
68
+ ],
69
+ )
70
+ def test_create_table_without_model(message, expected, monkeypatch):
71
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
72
+
73
+ result = app.generate_response(message, [], "", None, None)
74
+
75
+ for fragment in expected:
76
+ assert fragment in sql_output(result), f"missing: {fragment!r}"
77
+ assert "validator-ok" in result[5]
78
+ assert "without calling the model" in status_html(result)
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # CREATE TABLE: active preset as schema fallback
83
+ # ---------------------------------------------------------------------------
84
+
85
+ def test_create_table_from_active_preset(monkeypatch):
86
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
87
+
88
+ result = app.generate_response(
89
+ "gere esta tabela", [], app.PRESETS["employees"], None, None
90
+ )
91
+
92
+ assert "CREATE TABLE employees" in sql_output(result)
93
+ assert "without calling the model" in status_html(result)
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # EDIT: add column (verb forms and patterns)
98
+ # ---------------------------------------------------------------------------
99
+
100
+ @pytest.mark.parametrize(
101
+ ("message", "expected_col"),
102
+ [
103
+ ("adicione cpf", "cpf TEXT"),
104
+ ("add email", "email TEXT"),
105
+ ("inclua telefone", "telefone TEXT"),
106
+ ("acrescente campo bonus numeric", "bonus NUMERIC"),
107
+ ("adicione: matricula", "matricula TEXT"),
108
+ ],
109
+ )
110
+ def test_add_column_variants(message, expected_col, monkeypatch):
111
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
112
+ base = app.generate_response(
113
+ "crie tabela funcionarios com id nome salario", [], "", None, None
114
+ )
115
+
116
+ result = app.generate_response(message, base[0], "", None, None)
117
+
118
+ assert expected_col in sql_output(result)
119
+ assert "CREATE TABLE funcionarios" in sql_output(result)
120
+ assert "without calling the model" in status_html(result)
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # EDIT: remove column
125
+ # ---------------------------------------------------------------------------
126
+
127
+ @pytest.mark.parametrize(
128
+ ("message", "removed_col"),
129
+ [
130
+ ("remova salario", "salario"),
131
+ ("remove nome", "nome"),
132
+ ("delete salary", "salary"),
133
+ ("drop coluna id", "id"),
134
+ ],
135
+ )
136
+ def test_remove_column_variants(message, removed_col, monkeypatch):
137
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
138
+ base = app.generate_response(
139
+ "crie tabela funcionarios com id nome salario", [], "", None, None
140
+ )
141
+
142
+ result = app.generate_response(message, base[0], "", None, None)
143
+
144
+ assert "CREATE TABLE funcionarios" in sql_output(result)
145
+ assert removed_col not in sql_output(result)
146
+ assert "validator-ok" in result[5]
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # EDIT: "altere" and "mude" (regression fix for is_table_edit_intent)
151
+ # ---------------------------------------------------------------------------
152
+
153
+ @pytest.mark.parametrize(
154
+ "edit_message",
155
+ [
156
+ "altere para ter também email",
157
+ "mude adicionando telefone",
158
+ ],
159
+ )
160
+ def test_edit_intent_recognizes_pt_conjugations(edit_message, monkeypatch):
161
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
162
+ base = app.generate_response(
163
+ "crie tabela x com id nome", [], "", None, None
164
+ )
165
+
166
+ result = app.generate_response(edit_message, base[0], "", None, None)
167
+
168
+ assert "CREATE TABLE x" in sql_output(result)
169
+ assert "without calling the model" in status_html(result)
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # EDIT: multiple add/remove in the same turn
174
+ # ---------------------------------------------------------------------------
175
+
176
+ def test_add_multiple_columns(monkeypatch):
177
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
178
+ base = app.generate_response(
179
+ "crie tabela pesquisadores com id nome artigo e curriculo", [], "", None, None
180
+ )
181
+
182
+ result = app.generate_response("add email and phone", base[0], "", None, None)
183
+
184
+ assert "email TEXT" in sql_output(result)
185
+ assert "phone TEXT" in sql_output(result)
186
+
187
+
188
+ def test_remove_column(monkeypatch):
189
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
190
+ base = app.generate_response(
191
+ "crie tabela pesquisadores com id nome artigo e curriculo", [], "", None, None
192
+ )
193
+ added = app.generate_response("adicione cpf", base[0], "", None, None)
194
+
195
+ result = app.generate_response("remover curriculo", added[0], "", None, None)
196
+
197
+ assert "curriculo TEXT" not in sql_output(result)
198
+ assert "cpf TEXT" in sql_output(result)
199
+ assert "id INTEGER" in sql_output(result)
200
+ assert "validator-ok" in result[5]
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # EDIT: history with different content formats
205
+ # ---------------------------------------------------------------------------
206
+
207
+ @pytest.mark.parametrize(
208
+ "history",
209
+ [
210
+ [{"role": "assistant", "content": "```sql\nCREATE TABLE pesquisadores (\n id INTEGER,\n nome TEXT\n);\n```"}],
211
+ [{"role": "assistant", "content": [{"text": "```sql\nCREATE TABLE pesquisadores (\n id INTEGER,\n nome TEXT\n);\n```"}]}],
212
+ [{"role": "assistant", "content": "CREATE TABLE pesquisadores (\n id INTEGER,\n nome TEXT\n);"}],
213
+ ],
214
+ )
215
+ def test_edit_from_history_content_shapes(history, monkeypatch):
216
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
217
+
218
+ result = app.generate_response(
219
+ "edita ela para ter um novo elemento: cpf", history, "", None, None
220
+ )
221
+
222
+ assert "CREATE TABLE pesquisadores" in sql_output(result)
223
+ assert "cpf TEXT" in sql_output(result)
224
+ assert "id INTEGER" in sql_output(result)
225
+ assert "validator-ok" in result[5]
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # EDIT: with active_schema and empty history
230
+ # ---------------------------------------------------------------------------
231
+
232
+ def test_edit_from_active_schema_no_history(monkeypatch):
233
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
234
+
235
+ result = app.generate_response(
236
+ "adicione bonus", [], app.PRESETS["employees"], None, None
237
+ )
238
+
239
+ assert "CREATE TABLE employees" in sql_output(result)
240
+ assert "bonus" in sql_output(result)
241
+ assert "without calling the model" in status_html(result)
242
+
243
+
244
+ # ---------------------------------------------------------------------------
245
+ # EDIT: last_create_table_from_history returns the most recent
246
+ # ---------------------------------------------------------------------------
247
+
248
+ def test_last_create_table_returns_most_recent():
249
+ history = [
250
+ {"role": "assistant", "content": "```sql\nCREATE TABLE old (x TEXT);\n```"},
251
+ {"role": "user", "content": "adicione id"},
252
+ {"role": "assistant", "content": "```sql\nCREATE TABLE new (id INTEGER);\n```"},
253
+ ]
254
+ result = app.last_create_table_from_history(history)
255
+ assert "CREATE TABLE new" in result
256
+ assert "CREATE TABLE old" not in result
257
+
258
+
259
+ # ---------------------------------------------------------------------------
260
+ # FULL multi-turn FLOW: create → add → add → remove → SQL intent
261
+ # ---------------------------------------------------------------------------
262
+
263
+ def test_full_schema_build_flow(monkeypatch):
264
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
265
+
266
+ r1 = app.generate_response(
267
+ "crie tabela produtos com id nome preco", [], "", None, None
268
+ )
269
+ assert "CREATE TABLE produtos" in sql_output(r1)
270
+ assert "preco NUMERIC" in sql_output(r1)
271
+
272
+ r2 = app.generate_response("adicione categoria e estoque", r1[0], "", None, None)
273
+ assert "categoria TEXT" in sql_output(r2)
274
+ assert "estoque INTEGER" in sql_output(r2)
275
+ assert "id INTEGER" in sql_output(r2)
276
+
277
+ r3 = app.generate_response("remova preco", r2[0], "", None, None)
278
+ assert "preco" not in sql_output(r3)
279
+ assert "categoria TEXT" in sql_output(r3)
280
+
281
+ r4 = app.generate_response(
282
+ "qual o produto mais caro?", r3[0], sql_output(r3), None, None
283
+ )
284
+ assert "Load a model" in status_html(r4)
285
+
286
+
287
+ # ---------------------------------------------------------------------------
288
+ # SQL intent routing
289
+ # ---------------------------------------------------------------------------
290
+
291
+ def test_sql_prompt_uses_schema_template():
292
+ prompt = app.build_generation_prompt(
293
+ app.PRESETS["employees"],
294
+ "What is the average salary per department?",
295
+ )
296
+ assert "CREATE TABLE employees" in prompt
297
+ assert "<|user|>" in prompt
298
+ assert "<|assistant|>" in prompt
299
+
300
+
301
+ def test_sql_prompt_fallback_schema_when_empty():
302
+ prompt = app.build_generation_prompt("", "select all rows")
303
+ assert "CREATE TABLE unknown (id INTEGER)" in prompt
304
+
305
+
306
+ def test_sql_intent_detected():
307
+ assert app.is_sql_intent("What is the average salary per department?", app.PRESETS["employees"])
308
+ assert app.is_sql_intent("liste todos os funcionários", app.PRESETS["employees"])
309
+ assert app.is_sql_intent("mostre os alunos com nota maior que 8", app.PRESETS["students"])
310
+
311
+
312
+ def test_greeting_not_sql_intent():
313
+ assert not app.is_sql_intent("oi", app.PRESETS["employees"])
314
+ assert not app.is_sql_intent("hello", "")
315
+
316
+
317
+ # ---------------------------------------------------------------------------
318
+ # Output parsing: clean_generation and format_generation_result
319
+ # ---------------------------------------------------------------------------
320
+
321
+ @pytest.mark.parametrize(("raw", "expected"), [
322
+ ("```sql\nSELECT * FROM x\n```", "SELECT * FROM x"),
323
+ ("SELECT id FROM t<|end|>", "SELECT id FROM t"),
324
+ ("SQL: SELECT name FROM t", "SELECT name FROM t"),
325
+ ("```\nSELECT 1\n```", "SELECT 1"),
326
+ ])
327
+ def test_clean_generation_strips_artifacts(raw, expected):
328
+ assert app.clean_generation(raw) == expected
329
+
330
+
331
+ def test_format_generation_result_sql_path():
332
+ sql, chat, validator = app.format_generation_result("SELECT * FROM employees")
333
+ assert sql == "SELECT * FROM employees"
334
+ assert chat == ""
335
+ assert "validator-ok" in validator
336
+
337
+
338
+ def test_format_generation_result_chat_path():
339
+ sql, chat, validator = app.format_generation_result("I don't know, try again.")
340
+ assert sql == ""
341
+ assert "I don't know" in chat
342
+ assert validator == app.CHAT_VALIDATOR
343
+
344
+
345
+ # ---------------------------------------------------------------------------
346
+ # validate_sql: starters beyond SELECT
347
+ # ---------------------------------------------------------------------------
348
+
349
+ @pytest.mark.parametrize("stmt", [
350
+ "SELECT * FROM employees",
351
+ "CREATE TABLE t (id INTEGER)",
352
+ "INSERT INTO t VALUES (1)",
353
+ "WITH cte AS (SELECT 1) SELECT * FROM cte",
354
+ "DROP TABLE t",
355
+ "UPDATE t SET x = 1 WHERE id = 1",
356
+ ])
357
+ def test_validate_sql_valid_starters(stmt):
358
+ assert "validator-ok" in app.validate_sql(stmt)
359
+
360
+
361
+ def test_validate_sql_garbage_returns_warn():
362
+ assert "validator-warn" in app.validate_sql("isto nao e sql %$#")
363
+
364
+
365
+ def test_validate_sql_empty_returns_empty_badge():
366
+ assert app.validate_sql("") == app.EMPTY_VALIDATOR
367
+
368
+
369
+ # ---------------------------------------------------------------------------
370
+ # Explicit type normalization in the column parser
371
+ # ---------------------------------------------------------------------------
372
+
373
+ @pytest.mark.parametrize(("raw", "expected_type"), [
374
+ ("price DECIMAL", "NUMERIC"),
375
+ ("active BOOL", "BOOLEAN"),
376
+ ("qty INT", "INTEGER"),
377
+ ("score REAL", "REAL"),
378
+ # P2 fix: column name matches SQL type keyword (date DATE, int INTEGER)
379
+ # The parser now takes the last match as the type, not the first
380
+ ("date DATE", "DATE"),
381
+ ("int INTEGER", "INTEGER"),
382
+ ("name TEXT", "TEXT"),
383
+ ])
384
+ def test_parse_column_explicit_type_normalization(raw, expected_type):
385
+ parsed = app.parse_column_definition(raw)
386
+ assert parsed is not None
387
+ assert parsed[1] == expected_type
388
+ _, col_type = parsed
389
+ assert col_type == expected_type
390
+
391
+
392
+ # ---------------------------------------------------------------------------
393
+ # trim_chat_history
394
+ # ---------------------------------------------------------------------------
395
+
396
+ def test_trim_chat_history_caps_at_max_exchanges():
397
+ history = [
398
+ {"role": "user" if i % 2 == 0 else "assistant", "content": str(i)}
399
+ for i in range(30)
400
+ ]
401
+ trimmed = app.trim_chat_history(history)
402
+ assert len(trimmed) == 20
403
+
404
+
405
+ # ---------------------------------------------------------------------------
406
+ # Errors and model state
407
+ # ---------------------------------------------------------------------------
408
+
409
+ def test_empty_input_returns_error():
410
+ result = app.generate_response("", [], "", None, None)
411
+ assert result[0] == []
412
+ assert "Type a message" in status_html(result)
413
+
414
+
415
+ def test_malformed_create_table_returns_error():
416
+ result = app.generate_response("crie tabela", [], "", None, None)
417
+ assert sql_output(result) == ""
418
+ assert "CREATE TABLE needs" in status_html(result)
419
+
420
+
421
+ def test_edit_without_existing_table_returns_error():
422
+ result = app.generate_response("adicione cpf", [], "", None, None)
423
+ assert sql_output(result) == ""
424
+ assert "existing CREATE TABLE" in status_html(result)
425
+
426
+
427
+ def test_sql_intent_without_model_returns_load_error():
428
+ result = app.generate_response(
429
+ "What is the average salary?", [], app.PRESETS["employees"], None, None
430
+ )
431
+ assert "Load a model" in status_html(result)
432
+
433
+
434
+ def test_model_id_mismatch_returns_inconsistency_error():
435
+ app._model = types.SimpleNamespace(
436
+ generation_config=types.SimpleNamespace(eos_token_id=0)
437
+ )
438
+ app._tokenizer = object()
439
+ app._current_model_id = app.BASE_MODEL_ID
440
+
441
+ try:
442
+ result = app.generate_response(
443
+ "select all", [], app.PRESETS["employees"], app.FINE_TUNED_MODEL_KEY, None
444
+ )
445
+ assert "inconsistent" in status_html(result)
446
+ finally:
447
+ reset_model_state()
448
+
449
+
450
+ def test_busy_generation_lock_raises():
451
+ assert app._model_activity_lock.acquire(blocking=False)
452
+ try:
453
+ with pytest.raises(RuntimeError, match="Another model operation"):
454
+ app._run_generation(object(), {}, {})
455
+ finally:
456
+ app._model_activity_lock.release()
457
+
458
+
459
+ def test_generation_exception_is_rendered_not_raised(monkeypatch):
460
+ class DummyTokenizer:
461
+ eos_token_id = 0
462
+ pad_token_id = 0
463
+
464
+ def __call__(self, prompt, return_tensors):
465
+ return {"input_ids": types.SimpleNamespace(shape=(1, 1))}
466
+
467
+ monkeypatch.setattr(app, "import_model_runtime", lambda: (object(), None, None, None))
468
+ monkeypatch.setattr(
469
+ app, "_run_generation",
470
+ lambda *a, **k: (_ for _ in ()).throw(RuntimeError("timeout"))
471
+ )
472
+ app._model = types.SimpleNamespace(
473
+ generation_config=types.SimpleNamespace(eos_token_id=0)
474
+ )
475
+ app._tokenizer = DummyTokenizer()
476
+ app._current_model_id = app.FINE_TUNED_MODEL_ID
477
+
478
+ result = app.generate_response(
479
+ "select all rows", [], "", app.FINE_TUNED_MODEL_KEY, None
480
+ )
481
+
482
+ assert sql_output(result) == ""
483
+ assert "Generation failed: RuntimeError: timeout" in status_html(result)
484
+
485
+
486
+ # ---------------------------------------------------------------------------
487
+ # Fallback for messages outside the SQL context
488
+ # ---------------------------------------------------------------------------
489
+
490
+ def test_off_topic_message_returns_fallback(monkeypatch):
491
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
492
+
493
+ result = app.generate_response("me conte uma piada", [], "", None, None)
494
+
495
+ assert sql_output(result) == ""
496
+ assert "schema" in assistant_text(result).lower() or "tabela" in assistant_text(result).lower()
497
+
498
+
499
+ def test_greeting_returns_fallback(monkeypatch):
500
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
501
+
502
+ result = app.generate_response("oi", [], "", None, None)
503
+
504
+ assert sql_output(result) == ""
505
+
506
+
507
+ # ---------------------------------------------------------------------------
508
+ # Stopwords do not become columns
509
+ # ---------------------------------------------------------------------------
510
+
511
+ def test_stopwords_not_treated_as_columns(monkeypatch):
512
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
513
+ base = app.generate_response(
514
+ "crie tabela animal com nome especie", [], "", None, None
515
+ )
516
+
517
+ result = app.generate_response("add peso", base[0], "", None, None)
518
+
519
+ schema = sql_output(result)
520
+ assert "peso NUMERIC" in schema
521
+ assert " to TEXT" not in schema
522
+ assert " as TEXT" not in schema
523
+ assert " from TEXT" not in schema
524
+
525
+
526
+ # ---------------------------------------------------------------------------
527
+ # Column rename
528
+ # ---------------------------------------------------------------------------
529
+
530
+ def test_rename_column_basic(monkeypatch):
531
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
532
+ base = app.generate_response(
533
+ "crie tabela animal com nome cientifico especie", [], "", None, None
534
+ )
535
+
536
+ result = app.generate_response(
537
+ "rename cientifico to nome_cientifico", base[0], "", None, None
538
+ )
539
+
540
+ schema = sql_output(result)
541
+ assert "nome_cientifico TEXT" in schema
542
+ assert "\n cientifico TEXT" not in schema
543
+ assert "nome TEXT" in schema
544
+ assert "especie TEXT" in schema
545
+ assert "validator-ok" in result[5]
546
+
547
+
548
+ def test_rename_column_pt(monkeypatch):
549
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
550
+ base = app.generate_response(
551
+ "crie tabela produto com id nome preco", [], "", None, None
552
+ )
553
+
554
+ result = app.generate_response(
555
+ "renomeie preco para valor", base[0], "", None, None
556
+ )
557
+
558
+ schema = sql_output(result)
559
+ assert "valor NUMERIC" in schema
560
+ assert "preco" not in schema
561
+
562
+
563
+ # ---------------------------------------------------------------------------
564
+ # Compound operation: add + rename in the same prompt
565
+ # ---------------------------------------------------------------------------
566
+
567
+ def test_compound_add_and_rename(monkeypatch):
568
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
569
+ base = app.generate_response(
570
+ "crie tabela animal com nome cientifico especie", [], "", None, None
571
+ )
572
+
573
+ result = app.generate_response(
574
+ "add peso and rename cientifico to nome_cientifico", base[0], "", None, None
575
+ )
576
+
577
+ schema = sql_output(result)
578
+ assert "peso" in schema
579
+ assert "nome_cientifico TEXT" in schema
580
+ assert "\n cientifico TEXT" not in schema
581
+ assert "edit TEXT" not in schema
582
+ assert " to TEXT" not in schema
583
+ assert "validator-ok" in result[5]
584
+
585
+
586
+ def test_compound_add_and_remove(monkeypatch):
587
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
588
+ base = app.generate_response(
589
+ "crie tabela funcionarios com id nome salario departamento", [], "", None, None
590
+ )
591
+
592
+ result = app.generate_response(
593
+ "add email and remove salario", base[0], "", None, None
594
+ )
595
+
596
+ schema = sql_output(result)
597
+ assert "email TEXT" in schema
598
+ assert "salario" not in schema
599
+ assert "id INTEGER" in schema
600
+ assert "nome TEXT" in schema
601
+
602
+
603
+ # ---------------------------------------------------------------------------
604
+ # Rename preserves the column's original type
605
+ # ---------------------------------------------------------------------------
606
+
607
+ def test_rename_preserves_column_type(monkeypatch):
608
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
609
+ base = app.generate_response(
610
+ "crie tabela vendas com id total date", [], "", None, None
611
+ )
612
+
613
+ result = app.generate_response(
614
+ "rename total to valor_total", base[0], "", None, None
615
+ )
616
+
617
+ schema = sql_output(result)
618
+ assert "valor_total NUMERIC" in schema
619
+ assert "\n total NUMERIC" not in schema
620
+
621
+
622
+ # ---------------------------------------------------------------------------
623
+ # Edit terms → off-topic, not SQL intent (Fix 1: off_topic_patterns blocklist)
624
+ # ---------------------------------------------------------------------------
625
+
626
+ @pytest.mark.parametrize(
627
+ ("message", "schema"),
628
+ [
629
+ ("troca tipo por medida", "CREATE TABLE comida (id INTEGER)"),
630
+ ("renomeia nome para titulo", "CREATE TABLE livro (id INTEGER, nome TEXT)"),
631
+ ("muda preco para numeric", "CREATE TABLE produto (id INTEGER, preco TEXT)"),
632
+ ("altera coluna idade para integer", "CREATE TABLE pessoa (id INTEGER, idade TEXT)"),
633
+ ],
634
+ )
635
+ def test_edit_terms_routed_to_off_topic(message, schema, monkeypatch):
636
+ monkeypatch.setattr(app, "_run_generation", lambda *a, **k: pytest.fail("model should not run"))
637
+ # Result must NOT ask to load model — edit terms are off-topic, not SQL intent
638
+ result = app.generate_response(message, [], schema, None, None)
639
+ status = status_html(result)
640
+ assert "Load a model" not in status
641
+ # Should be either edit-without-table error or safe fallback — not model path
642
+
643
+
644
+ # ---------------------------------------------------------------------------
645
+ # build_generation_prompt injects last 3 conversation exchanges (Fix 2)
646
+ # ---------------------------------------------------------------------------
647
+
648
+ def test_build_generation_prompt_injects_history():
649
+ schema = "CREATE TABLE comida (id INTEGER, nome TEXT, sabor TEXT)"
650
+ message = "liste tudo ordenado por nome"
651
+ chat_history = [
652
+ {"role": "user", "content": "crie tabela comida com nome sabor"},
653
+ {"role": "assistant", "content": "```sql\nCREATE TABLE comida (id INTEGER, nome TEXT, sabor TEXT)\n```"},
654
+ {"role": "user", "content": "adiciona coluna peso"},
655
+ {"role": "assistant", "content": "```sql\nALTER TABLE comida ADD COLUMN peso NUMERIC\n```"},
656
+ ]
657
+ prompt = app.build_generation_prompt(schema, message, chat_history)
658
+ assert "Previous conversation:" in prompt
659
+ assert "crie tabela comida" in prompt
660
+ assert "adiciona coluna peso" in prompt
661
+
662
+
663
+ def test_build_generation_prompt_no_history_no_context():
664
+ schema = "CREATE TABLE comida (id INTEGER)"
665
+ message = "liste todos"
666
+ prompt = app.build_generation_prompt(schema, message, None)
667
+ # Should not include conversation context header
668
+ assert "Previous conversation:" not in prompt
669
+ # But should still include schema and question
670
+ assert "comida" in prompt
671
+ assert "liste" in prompt
672
+