refactor: model scope for sql query only
Browse files- .gitignore +8 -1
- README.md +7 -6
- app.py +30 -289
- chat_state.py +1 -73
- intent.py +9 -51
- model_io.py +0 -62
- scripts/model_probe.py +11 -19
- sql_tools.py +34 -21
- tests/e2e_flow_test.py +2 -2
- tests/test_chatbot_behavior.py +4 -3
- tests/test_chatbot_core.py +53 -66
.gitignore
CHANGED
|
@@ -37,6 +37,10 @@ Thumbs.db
|
|
| 37 |
logs/
|
| 38 |
*.log
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Large local model artifacts should stay on the Hub, not in this Space repo
|
| 41 |
*.safetensors
|
| 42 |
*.bin
|
|
@@ -48,7 +52,10 @@ logs/
|
|
| 48 |
# Local agent/workspace notes
|
| 49 |
/AGENTS.md
|
| 50 |
/CLAUDE.md
|
| 51 |
-
/
|
|
|
|
|
|
|
|
|
|
| 52 |
.claude
|
| 53 |
.gstack/
|
| 54 |
docs
|
|
|
|
| 37 |
logs/
|
| 38 |
*.log
|
| 39 |
|
| 40 |
+
# Graphify local artifacts
|
| 41 |
+
graphify-out/
|
| 42 |
+
.graphifyignore
|
| 43 |
+
|
| 44 |
# Large local model artifacts should stay on the Hub, not in this Space repo
|
| 45 |
*.safetensors
|
| 46 |
*.bin
|
|
|
|
| 52 |
# Local agent/workspace notes
|
| 53 |
/AGENTS.md
|
| 54 |
/CLAUDE.md
|
| 55 |
+
/HANDOFF.md
|
| 56 |
+
/FLOWS.md
|
| 57 |
+
/ERRORS.md
|
| 58 |
+
/ROADMAP.md
|
| 59 |
.claude
|
| 60 |
.gstack/
|
| 61 |
docs
|
README.md
CHANGED
|
@@ -17,16 +17,16 @@ Generates SQL queries from a table schema and a natural-language question using
|
|
| 17 |
|
| 18 |
## What the App Does
|
| 19 |
|
| 20 |
-
Transforms
|
| 21 |
|
| 22 |
## How to Use
|
| 23 |
|
| 24 |
1. Click **Load fine-tuned model**.
|
| 25 |
- Loading is lazy: the model is only downloaded and loaded when you request it.
|
| 26 |
- On CPU, the first load can take a few minutes.
|
| 27 |
-
2.
|
| 28 |
- You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
|
| 29 |
-
- You can also
|
| 30 |
3. Enter the question in the chat input.
|
| 31 |
4. Click **Send**.
|
| 32 |
5. Review the result in `gr.Code(language="sql")`.
|
|
@@ -50,8 +50,9 @@ Reported gain: **+71.5 percentage points** over the base model.
|
|
| 50 |
|
| 51 |
## Current Features
|
| 52 |
|
| 53 |
-
- Gradio UI with a step-by-step flow: load the fine-tuned model,
|
| 54 |
-
- Intent routing
|
|
|
|
| 55 |
- Lazy loading to reduce startup cost.
|
| 56 |
- Preserved Phi-3 patches for local/Spaces compatibility.
|
| 57 |
- Schema presets without blocking manual input.
|
|
@@ -66,7 +67,7 @@ The normal pytest suite does not load the 3.8B model. To manually verify the rea
|
|
| 66 |
python scripts/model_probe.py
|
| 67 |
```
|
| 68 |
|
| 69 |
-
The probe prints JSON with pass/fail checks for
|
| 70 |
|
| 71 |
## Run Locally
|
| 72 |
|
|
|
|
| 17 |
|
| 18 |
## What the App Does
|
| 19 |
|
| 20 |
+
Transforms explicit table schemas and schema-edit requests deterministically, then uses the fine-tuned Phi-3 Mini model only for SQL query generation. The base model is shown as offline evaluation evidence instead of a second live CPU-loaded model.
|
| 21 |
|
| 22 |
## How to Use
|
| 23 |
|
| 24 |
1. Click **Load fine-tuned model**.
|
| 25 |
- Loading is lazy: the model is only downloaded and loaded when you request it.
|
| 26 |
- On CPU, the first load can take a few minutes.
|
| 27 |
+
2. Select, create, or edit the **SQL table schema**.
|
| 28 |
- You can use the presets: `employees`, `orders`, `students`, `products`, `sales`.
|
| 29 |
+
- You can also ask for explicit schema operations, such as `create table products with id name price` or `add stock`.
|
| 30 |
3. Enter the question in the chat input.
|
| 31 |
4. Click **Send**.
|
| 32 |
5. Review the result in `gr.Code(language="sql")`.
|
|
|
|
| 50 |
|
| 51 |
## Current Features
|
| 52 |
|
| 53 |
+
- Gradio UI with a step-by-step flow: load the fine-tuned model, define schema context, and inspect SQL artifacts.
|
| 54 |
+
- Intent routing with 5 supported routes: `CREATE_TABLE`, `EDIT_TABLE`, `SQL_QUERY`, `SMALLTALK`, `UNKNOWN`.
|
| 55 |
+
- Model calls only for `SQL_QUERY`; smalltalk and unknown messages use a static fallback.
|
| 56 |
- Lazy loading to reduce startup cost.
|
| 57 |
- Preserved Phi-3 patches for local/Spaces compatibility.
|
| 58 |
- Schema presets without blocking manual input.
|
|
|
|
| 67 |
python scripts/model_probe.py
|
| 68 |
```
|
| 69 |
|
| 70 |
+
The probe prints JSON with pass/fail checks for static fallback, deterministic CREATE TABLE, deterministic schema edit, SQL query generation, and smalltalk while a schema is active.
|
| 71 |
|
| 72 |
## Run Locally
|
| 73 |
|
app.py
CHANGED
|
@@ -18,29 +18,18 @@ import model_io as model_core
|
|
| 18 |
import sql_tools as sql_core
|
| 19 |
|
| 20 |
|
| 21 |
-
BASE_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
|
| 22 |
FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
|
| 23 |
|
| 24 |
-
BASE_MODEL_KEY = "base"
|
| 25 |
FINE_TUNED_MODEL_KEY = "fine_tuned"
|
| 26 |
DEFAULT_MODEL_KEY = FINE_TUNED_MODEL_KEY
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
MODEL_CATALOG = {
|
| 29 |
-
BASE_MODEL_KEY: {
|
| 30 |
-
"label": "Base Phi-3 Mini",
|
| 31 |
-
"short_label": "Base",
|
| 32 |
-
"tag": "Base",
|
| 33 |
-
"title": "Phi-3 Mini base",
|
| 34 |
-
"model_id": BASE_MODEL_ID,
|
| 35 |
-
"exact_match": "2.0%",
|
| 36 |
-
"trust_remote_code": False,
|
| 37 |
-
"ready_text": "Base model ready",
|
| 38 |
-
"metadata": (
|
| 39 |
-
"Model: microsoft/Phi-3-mini-4k-instruct\n"
|
| 40 |
-
"Role: unfine-tuned baseline\n"
|
| 41 |
-
"Metric: 2.0% exact match on the comparison setup"
|
| 42 |
-
),
|
| 43 |
-
},
|
| 44 |
FINE_TUNED_MODEL_KEY: {
|
| 45 |
"label": "Fine-tuned QLoRA model",
|
| 46 |
"short_label": "Fine-tuned",
|
|
@@ -400,12 +389,7 @@ def normalize_text(value):
|
|
| 400 |
|
| 401 |
|
| 402 |
def safe_chat_fallback(_message=""):
|
| 403 |
-
return
|
| 404 |
-
"Selecione um schema e faça uma pergunta SQL, "
|
| 405 |
-
"ou peça para criar ou editar uma tabela. "
|
| 406 |
-
"Exemplo: 'crie tabela produtos com id nome preco' "
|
| 407 |
-
"ou 'qual o produto mais caro?'."
|
| 408 |
-
)
|
| 409 |
|
| 410 |
|
| 411 |
def clean_generation(text):
|
|
@@ -459,11 +443,11 @@ def render_header():
|
|
| 459 |
<section class="top-panel">
|
| 460 |
<div>
|
| 461 |
<h1>Phi-3 Mini SQL Chatbot</h1>
|
| 462 |
-
<p>
|
| 463 |
</div>
|
| 464 |
<div class="top-badges">
|
| 465 |
-
<span class="badge badge-green">
|
| 466 |
-
<span class="badge badge-cream">
|
| 467 |
<span class="badge badge-light">CPU lazy load</span>
|
| 468 |
</div>
|
| 469 |
</section>
|
|
@@ -534,10 +518,10 @@ def render_loading_overlay(model_key=None, visible=False):
|
|
| 534 |
def model_metadata(model_key=None):
|
| 535 |
return """
|
| 536 |
<section class="stats-row">
|
| 537 |
-
<div class="stat-card"><strong>
|
| 538 |
-
<div class="stat-card"><strong>
|
| 539 |
-
<div class="stat-card"><strong>
|
| 540 |
-
<div class="stat-card"><strong>
|
| 541 |
</section>
|
| 542 |
"""
|
| 543 |
|
|
@@ -935,21 +919,6 @@ def render_message(message="", kind="error"):
|
|
| 935 |
return f'<div class="message-box {class_name}">{html.escape(str(message))}</div>'
|
| 936 |
|
| 937 |
|
| 938 |
-
def select_model(model_key, loaded_key):
|
| 939 |
-
selected_key = model_key if model_key in MODEL_CATALOG else DEFAULT_MODEL_KEY
|
| 940 |
-
query_is_active = loaded_key == selected_key
|
| 941 |
-
return (
|
| 942 |
-
selected_key,
|
| 943 |
-
render_model_card(BASE_MODEL_KEY, selected_key),
|
| 944 |
-
render_model_card(FINE_TUNED_MODEL_KEY, selected_key),
|
| 945 |
-
render_status(selected_key, loaded_key),
|
| 946 |
-
model_metadata(selected_key),
|
| 947 |
-
*query_control_updates(query_is_active),
|
| 948 |
-
gr.update(interactive=False),
|
| 949 |
-
render_message(),
|
| 950 |
-
)
|
| 951 |
-
|
| 952 |
-
|
| 953 |
def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
|
| 954 |
selected_key = FINE_TUNED_MODEL_KEY
|
| 955 |
model_def = model_by_key(selected_key)
|
|
@@ -966,7 +935,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
|
|
| 966 |
*query_control_updates(False),
|
| 967 |
"",
|
| 968 |
EMPTY_VALIDATOR,
|
| 969 |
-
gr.update(value=None),
|
| 970 |
render_message(),
|
| 971 |
)
|
| 972 |
started = time.time()
|
|
@@ -996,7 +964,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
|
|
| 996 |
*query_control_updates(False),
|
| 997 |
"",
|
| 998 |
EMPTY_VALIDATOR,
|
| 999 |
-
gr.update(value=None),
|
| 1000 |
render_message(error),
|
| 1001 |
)
|
| 1002 |
return
|
|
@@ -1011,7 +978,6 @@ def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
|
|
| 1011 |
*query_control_updates(True),
|
| 1012 |
"",
|
| 1013 |
EMPTY_VALIDATOR,
|
| 1014 |
-
gr.update(value=None),
|
| 1015 |
render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
|
| 1016 |
)
|
| 1017 |
|
|
@@ -1030,62 +996,6 @@ def trim_chat_history(chat_history, max_exchanges=10):
|
|
| 1030 |
return history[-max_exchanges * 2 :]
|
| 1031 |
|
| 1032 |
|
| 1033 |
-
def comparison_updates(saved_state, current_sql, loaded_key):
|
| 1034 |
-
if not saved_state or not (current_sql or "").strip():
|
| 1035 |
-
return gr.update(visible=False), "", "", "", ""
|
| 1036 |
-
|
| 1037 |
-
loaded_def = model_by_key(loaded_key) if loaded_key else model_by_key(DEFAULT_MODEL_KEY)
|
| 1038 |
-
return (
|
| 1039 |
-
gr.update(visible=True),
|
| 1040 |
-
render_compare_label("Saved", saved_state.get("model_label", "Unknown"), saved_state.get("match", "")),
|
| 1041 |
-
saved_state.get("sql", ""),
|
| 1042 |
-
render_compare_label("Current", loaded_def["short_label"], loaded_def["exact_match"]),
|
| 1043 |
-
current_sql or "",
|
| 1044 |
-
)
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
def render_compare_label(prefix, model_label, metric):
|
| 1048 |
-
metric_html = f"<strong>{html.escape(metric)} match</strong>" if metric else ""
|
| 1049 |
-
return (
|
| 1050 |
-
f'<div class="compare-head"><span>{html.escape(prefix)} - '
|
| 1051 |
-
f"{html.escape(model_label)}</span>{metric_html}</div>"
|
| 1052 |
-
)
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
def save_for_comparison(sql_text, loaded_key, active_schema, last_message):
|
| 1056 |
-
sql_text = (sql_text or "").strip()
|
| 1057 |
-
if not sql_text or not loaded_key:
|
| 1058 |
-
return (
|
| 1059 |
-
None,
|
| 1060 |
-
gr.update(visible=False),
|
| 1061 |
-
"",
|
| 1062 |
-
"",
|
| 1063 |
-
"",
|
| 1064 |
-
"",
|
| 1065 |
-
gr.update(interactive=False, visible=False),
|
| 1066 |
-
render_message("Generate SQL before saving a comparison."),
|
| 1067 |
-
)
|
| 1068 |
-
|
| 1069 |
-
model_def = model_by_key(loaded_key)
|
| 1070 |
-
saved = {
|
| 1071 |
-
"sql": sql_text,
|
| 1072 |
-
"model_label": model_def["short_label"],
|
| 1073 |
-
"match": model_def["exact_match"],
|
| 1074 |
-
"schema_context": active_schema or "",
|
| 1075 |
-
"user_message": last_message or "",
|
| 1076 |
-
}
|
| 1077 |
-
return (
|
| 1078 |
-
saved,
|
| 1079 |
-
gr.update(visible=True),
|
| 1080 |
-
render_compare_label("Saved", model_def["short_label"], model_def["exact_match"]),
|
| 1081 |
-
sql_text,
|
| 1082 |
-
render_compare_label("Current", model_def["short_label"], model_def["exact_match"]),
|
| 1083 |
-
sql_text,
|
| 1084 |
-
gr.update(interactive=True),
|
| 1085 |
-
render_message("Saved output for comparison.", kind="ok"),
|
| 1086 |
-
)
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
def _append_chat_turn(chat_history, message, assistant_content):
|
| 1090 |
return trim_chat_history(
|
| 1091 |
[
|
|
@@ -1109,7 +1019,7 @@ def _response_tuple(
|
|
| 1109 |
):
|
| 1110 |
state = chat_core.ConversationState.from_value(state)
|
| 1111 |
if sql_text and "CREATE TABLE" in sql_text.upper():
|
| 1112 |
-
state = state.with_active_schema(sql_text)
|
| 1113 |
new_history = _append_chat_turn(chat_history, message, assistant_content)
|
| 1114 |
return (
|
| 1115 |
new_history,
|
|
@@ -1118,7 +1028,6 @@ def _response_tuple(
|
|
| 1118 |
message,
|
| 1119 |
sql_text,
|
| 1120 |
validator,
|
| 1121 |
-
gr.update(value=None),
|
| 1122 |
render_message(status_message, kind=status_kind),
|
| 1123 |
state.to_dict(),
|
| 1124 |
)
|
|
@@ -1129,7 +1038,6 @@ def deterministic_response(
|
|
| 1129 |
message,
|
| 1130 |
active_schema,
|
| 1131 |
loaded_key,
|
| 1132 |
-
saved_state,
|
| 1133 |
assistant_content,
|
| 1134 |
status_message,
|
| 1135 |
*,
|
|
@@ -1153,7 +1061,7 @@ def deterministic_response(
|
|
| 1153 |
|
| 1154 |
def _model_ready(loaded_key):
|
| 1155 |
if not loaded_key or _model is None or _tokenizer is None:
|
| 1156 |
-
return False, "Load the fine-tuned model before
|
| 1157 |
model_def = model_by_key(loaded_key)
|
| 1158 |
if _current_model_id != model_def["model_id"]:
|
| 1159 |
return False, "Loaded model state is inconsistent. Reload the selected model."
|
|
@@ -1195,12 +1103,6 @@ def _generate_model_text(prompt, generation_kind=model_core.SQL_GENERATION):
|
|
| 1195 |
return generated_text, int(time.time() - started)
|
| 1196 |
|
| 1197 |
|
| 1198 |
-
def _schema_suggestion_message(suggestion):
|
| 1199 |
-
columns = ", ".join(f"{name} {column_type}" for name, column_type in suggestion.columns)
|
| 1200 |
-
rationale = f"\n\n{suggestion.rationale}" if suggestion.rationale else ""
|
| 1201 |
-
return f"Posso montar a tabela `{suggestion.table_name}` com: {columns}.{rationale}\n\nSe quiser, diga `gera`."
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
def _empty_generation_response(chat_history, message, state, status_message, *, status_kind="error"):
|
| 1205 |
return (
|
| 1206 |
chat_history,
|
|
@@ -1209,13 +1111,12 @@ def _empty_generation_response(chat_history, message, state, status_message, *,
|
|
| 1209 |
"",
|
| 1210 |
"",
|
| 1211 |
EMPTY_VALIDATOR,
|
| 1212 |
-
gr.update(value=None),
|
| 1213 |
render_message(status_message, kind=status_kind),
|
| 1214 |
state.to_dict(),
|
| 1215 |
)
|
| 1216 |
|
| 1217 |
|
| 1218 |
-
def generate_response(message, chat_history, active_schema, loaded_key,
|
| 1219 |
message = (message or "").strip()
|
| 1220 |
chat_history = list(chat_history or [])
|
| 1221 |
state = chat_core.ConversationState.from_value(conversation_state, active_schema=(active_schema or ""))
|
|
@@ -1227,7 +1128,6 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
|
|
| 1227 |
"",
|
| 1228 |
"",
|
| 1229 |
EMPTY_VALIDATOR,
|
| 1230 |
-
gr.update(value=None),
|
| 1231 |
render_message("Type a message before sending."),
|
| 1232 |
state.to_dict(),
|
| 1233 |
)
|
|
@@ -1241,20 +1141,6 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
|
|
| 1241 |
)
|
| 1242 |
|
| 1243 |
if intent_result.intent == intent_core.EDIT_TABLE:
|
| 1244 |
-
if state.pending_schema_suggestion and not state.active_schema:
|
| 1245 |
-
pending_sql = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
|
| 1246 |
-
edited_table = sql_core.edit_create_table_from_message(message, chat_history, pending_sql)
|
| 1247 |
-
table_name, columns = sql_core.parse_create_table_schema(edited_table)
|
| 1248 |
-
if edited_table and table_name and columns:
|
| 1249 |
-
suggestion = chat_core.SchemaSuggestion(table_name=table_name, columns=tuple(columns))
|
| 1250 |
-
state = state.with_pending_schema(suggestion)
|
| 1251 |
-
return _response_tuple(
|
| 1252 |
-
chat_history,
|
| 1253 |
-
message,
|
| 1254 |
-
state,
|
| 1255 |
-
_schema_suggestion_message(suggestion),
|
| 1256 |
-
"Updated pending table proposal without calling the model.",
|
| 1257 |
-
)
|
| 1258 |
edited_table = sql_core.edit_create_table_from_message(message, chat_history, state.active_schema)
|
| 1259 |
if edited_table:
|
| 1260 |
display_response = f"```sql\n{edited_table}\n```"
|
|
@@ -1274,32 +1160,8 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
|
|
| 1274 |
"I need an existing CREATE TABLE in the chat or an active schema before editing columns.",
|
| 1275 |
)
|
| 1276 |
|
| 1277 |
-
if intent_result.intent == intent_core.CREATE_TABLE_CONFIRM:
|
| 1278 |
-
sql_text = sql_core.create_table_from_suggestion(state.pending_schema_suggestion)
|
| 1279 |
-
if sql_text:
|
| 1280 |
-
display_response = f"```sql\n{sql_text}\n```"
|
| 1281 |
-
return _response_tuple(
|
| 1282 |
-
chat_history,
|
| 1283 |
-
message,
|
| 1284 |
-
state.clear_pending_schema(),
|
| 1285 |
-
display_response,
|
| 1286 |
-
"Generated CREATE TABLE from the pending proposal.",
|
| 1287 |
-
sql_text=sql_text,
|
| 1288 |
-
validator=sql_core.validate_sql(sql_text),
|
| 1289 |
-
)
|
| 1290 |
-
|
| 1291 |
if intent_result.intent == intent_core.CREATE_TABLE:
|
| 1292 |
sql_text = sql_core.create_table_from_message(message) or sql_core.create_table_from_schema(state.active_schema)
|
| 1293 |
-
if not sql_text:
|
| 1294 |
-
ready, _error = _model_ready(loaded_key)
|
| 1295 |
-
if ready:
|
| 1296 |
-
try:
|
| 1297 |
-
prompt = model_core.build_schema_suggestion_prompt(message, state, chat_history)
|
| 1298 |
-
generated_text, _elapsed = _generate_model_text(prompt, model_core.SCHEMA_GENERATION)
|
| 1299 |
-
suggestion = model_core.parse_schema_suggestion(generated_text)
|
| 1300 |
-
sql_text = sql_core.create_table_from_suggestion(suggestion)
|
| 1301 |
-
except Exception:
|
| 1302 |
-
sql_text = ""
|
| 1303 |
if sql_text:
|
| 1304 |
display_response = f"```sql\n{sql_text}\n```"
|
| 1305 |
return _response_tuple(
|
|
@@ -1315,72 +1177,17 @@ def generate_response(message, chat_history, active_schema, loaded_key, saved_st
|
|
| 1315 |
chat_history,
|
| 1316 |
message,
|
| 1317 |
state,
|
| 1318 |
-
"CREATE TABLE needs a table name and columns
|
| 1319 |
)
|
| 1320 |
|
| 1321 |
-
if intent_result.intent
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
if not suggestion:
|
| 1330 |
-
repair_prompt = (
|
| 1331 |
-
"Return valid JSON only for this SQL table proposal. "
|
| 1332 |
-
"Use table_name, columns, and rationale.\n\n"
|
| 1333 |
-
f"Previous output:\n{generated_text}"
|
| 1334 |
-
)
|
| 1335 |
-
repaired_text, elapsed = _generate_model_text(repair_prompt, model_core.SCHEMA_GENERATION)
|
| 1336 |
-
suggestion = model_core.parse_schema_suggestion(repaired_text)
|
| 1337 |
-
if not suggestion:
|
| 1338 |
-
return _response_tuple(
|
| 1339 |
-
chat_history,
|
| 1340 |
-
message,
|
| 1341 |
-
state,
|
| 1342 |
-
"Nao consegui estruturar essa proposta de tabela. Diga o nome da tabela e algumas colunas.",
|
| 1343 |
-
"Schema proposal was not valid JSON.",
|
| 1344 |
-
status_kind="error",
|
| 1345 |
-
)
|
| 1346 |
-
state = state.with_pending_schema(suggestion)
|
| 1347 |
-
return _response_tuple(
|
| 1348 |
-
chat_history,
|
| 1349 |
-
message,
|
| 1350 |
-
state,
|
| 1351 |
-
_schema_suggestion_message(suggestion),
|
| 1352 |
-
f"Generated schema proposal in {elapsed}s.",
|
| 1353 |
-
)
|
| 1354 |
-
except Exception as exc:
|
| 1355 |
-
return _empty_generation_response(
|
| 1356 |
-
chat_history,
|
| 1357 |
-
message,
|
| 1358 |
-
state,
|
| 1359 |
-
f"Generation failed: {type(exc).__name__}: {exc}",
|
| 1360 |
-
)
|
| 1361 |
-
|
| 1362 |
-
if intent_result.intent in {intent_core.SMALLTALK, intent_core.CLARIFICATION, intent_core.UNKNOWN}:
|
| 1363 |
-
ready, error = _model_ready(loaded_key)
|
| 1364 |
-
if not ready:
|
| 1365 |
-
return _response_tuple(chat_history, message, state, error, error, status_kind="error")
|
| 1366 |
-
try:
|
| 1367 |
-
prompt = model_core.build_chat_prompt(message, state, chat_history)
|
| 1368 |
-
generated_text, elapsed = _generate_model_text(prompt, model_core.CHAT_GENERATION)
|
| 1369 |
-
chat_text = model_core.clean_generation(generated_text)
|
| 1370 |
-
return _response_tuple(
|
| 1371 |
-
chat_history,
|
| 1372 |
-
message,
|
| 1373 |
-
state,
|
| 1374 |
-
chat_text,
|
| 1375 |
-
f"Generated chat response in {elapsed}s.",
|
| 1376 |
-
)
|
| 1377 |
-
except Exception as exc:
|
| 1378 |
-
return _empty_generation_response(
|
| 1379 |
-
chat_history,
|
| 1380 |
-
message,
|
| 1381 |
-
state,
|
| 1382 |
-
f"Generation failed: {type(exc).__name__}: {exc}",
|
| 1383 |
-
)
|
| 1384 |
|
| 1385 |
ready, error = _model_ready(loaded_key)
|
| 1386 |
if not ready:
|
|
@@ -1458,7 +1265,6 @@ def sync_on_load():
|
|
| 1458 |
*query_control_updates(True),
|
| 1459 |
"",
|
| 1460 |
EMPTY_VALIDATOR,
|
| 1461 |
-
gr.update(value=None),
|
| 1462 |
render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
|
| 1463 |
)
|
| 1464 |
return (
|
|
@@ -1470,7 +1276,6 @@ def sync_on_load():
|
|
| 1470 |
*query_control_updates(False),
|
| 1471 |
"",
|
| 1472 |
EMPTY_VALIDATOR,
|
| 1473 |
-
gr.update(value=None),
|
| 1474 |
render_message(),
|
| 1475 |
)
|
| 1476 |
|
|
@@ -1481,7 +1286,6 @@ CSS = """
|
|
| 1481 |
/* Prevent Gradio dark theme from overriding text in light-bg components */
|
| 1482 |
[class*="badge"],
|
| 1483 |
[class*="validator-"],
|
| 1484 |
-
[class*="compare-head"],
|
| 1485 |
[class*="model-tag"],
|
| 1486 |
[class*="stat-card"] {
|
| 1487 |
color: inherit !important;
|
|
@@ -1608,7 +1412,6 @@ CSS = """
|
|
| 1608 |
}
|
| 1609 |
|
| 1610 |
.model-grid,
|
| 1611 |
-
.compare-grid,
|
| 1612 |
.stats-row {
|
| 1613 |
display: grid;
|
| 1614 |
gap: 12px;
|
|
@@ -1616,7 +1419,6 @@ CSS = """
|
|
| 1616 |
}
|
| 1617 |
|
| 1618 |
.model-grid > div,
|
| 1619 |
-
.compare-grid > div,
|
| 1620 |
.stats-row > div {
|
| 1621 |
min-width: 0;
|
| 1622 |
}
|
|
@@ -1756,8 +1558,7 @@ CSS = """
|
|
| 1756 |
}
|
| 1757 |
|
| 1758 |
#load-button,
|
| 1759 |
-
#generate-button
|
| 1760 |
-
#save-button {
|
| 1761 |
width: 100% !important;
|
| 1762 |
}
|
| 1763 |
|
|
@@ -1790,19 +1591,6 @@ CSS = """
|
|
| 1790 |
color: var(--bg-base) !important;
|
| 1791 |
}
|
| 1792 |
|
| 1793 |
-
#save-button button {
|
| 1794 |
-
background: transparent !important;
|
| 1795 |
-
border: 0.5px solid var(--border-hi) !important;
|
| 1796 |
-
color: var(--text-primary) !important;
|
| 1797 |
-
min-height: 38px !important;
|
| 1798 |
-
width: 100% !important;
|
| 1799 |
-
}
|
| 1800 |
-
|
| 1801 |
-
#save-button button:hover {
|
| 1802 |
-
border-color: var(--text-primary) !important;
|
| 1803 |
-
}
|
| 1804 |
-
|
| 1805 |
-
#save-button button:disabled,
|
| 1806 |
#generate-button button:disabled {
|
| 1807 |
opacity: 0.4 !important;
|
| 1808 |
}
|
|
@@ -2079,10 +1867,7 @@ textarea {
|
|
| 2079 |
|
| 2080 |
.output-shell .cm-editor,
|
| 2081 |
.output-shell pre,
|
| 2082 |
-
.output-shell code
|
| 2083 |
-
.compare-card .cm-editor,
|
| 2084 |
-
.compare-card pre,
|
| 2085 |
-
.compare-card code {
|
| 2086 |
border: 0 !important;
|
| 2087 |
font-size: 12px !important;
|
| 2088 |
font-weight: 400 !important;
|
|
@@ -2104,45 +1889,6 @@ textarea {
|
|
| 2104 |
color: var(--teal);
|
| 2105 |
}
|
| 2106 |
|
| 2107 |
-
.comparison-panel {
|
| 2108 |
-
margin-top: 28px;
|
| 2109 |
-
}
|
| 2110 |
-
|
| 2111 |
-
.compare-card {
|
| 2112 |
-
background: var(--bg-surface);
|
| 2113 |
-
border: 0.5px solid var(--border);
|
| 2114 |
-
border-radius: 6px;
|
| 2115 |
-
overflow: hidden;
|
| 2116 |
-
}
|
| 2117 |
-
|
| 2118 |
-
.compare-card.current {
|
| 2119 |
-
border-color: rgba(29, 158, 117, 0.45);
|
| 2120 |
-
}
|
| 2121 |
-
|
| 2122 |
-
.compare-head {
|
| 2123 |
-
align-items: center;
|
| 2124 |
-
background: var(--amber-soft);
|
| 2125 |
-
color: var(--amber-text) !important;
|
| 2126 |
-
display: flex;
|
| 2127 |
-
font-size: 11px;
|
| 2128 |
-
font-weight: 500;
|
| 2129 |
-
gap: 16px;
|
| 2130 |
-
justify-content: space-between;
|
| 2131 |
-
min-height: 34px;
|
| 2132 |
-
padding: 0 12px;
|
| 2133 |
-
}
|
| 2134 |
-
|
| 2135 |
-
.compare-card.current .compare-head,
|
| 2136 |
-
.current-compare-head .compare-head {
|
| 2137 |
-
background: var(--teal-soft);
|
| 2138 |
-
color: var(--teal-text) !important;
|
| 2139 |
-
}
|
| 2140 |
-
|
| 2141 |
-
.compare-head strong {
|
| 2142 |
-
color: inherit;
|
| 2143 |
-
font-weight: 500;
|
| 2144 |
-
}
|
| 2145 |
-
|
| 2146 |
.loading-overlay {
|
| 2147 |
align-items: center;
|
| 2148 |
background: rgba(0, 0, 0, 0.6);
|
|
@@ -2210,7 +1956,6 @@ textarea {
|
|
| 2210 |
@media (max-width: 860px) {
|
| 2211 |
.top-panel,
|
| 2212 |
.model-grid,
|
| 2213 |
-
.compare-grid,
|
| 2214 |
.evidence-grid {
|
| 2215 |
grid-template-columns: 1fr;
|
| 2216 |
}
|
|
@@ -2233,7 +1978,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
|
|
| 2233 |
loaded_key_state = gr.State(value=None)
|
| 2234 |
active_schema = gr.State(value="")
|
| 2235 |
conversation_state = gr.State(value=chat_core.default_state())
|
| 2236 |
-
generation_meta_state = gr.State(value=None)
|
| 2237 |
last_user_message = gr.State(value="")
|
| 2238 |
|
| 2239 |
with gr.Column(elem_classes=["app-shell"]):
|
|
@@ -2335,7 +2079,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
|
|
| 2335 |
send_button,
|
| 2336 |
sql_output,
|
| 2337 |
validator_output,
|
| 2338 |
-
generation_meta_state,
|
| 2339 |
error_output,
|
| 2340 |
],
|
| 2341 |
js=LOAD_SCROLL_JS,
|
|
@@ -2356,18 +2099,17 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
|
|
| 2356 |
last_user_message,
|
| 2357 |
sql_output,
|
| 2358 |
validator_output,
|
| 2359 |
-
generation_meta_state,
|
| 2360 |
error_output,
|
| 2361 |
conversation_state,
|
| 2362 |
]
|
| 2363 |
send_button.click(
|
| 2364 |
generate_response,
|
| 2365 |
-
inputs=[message_input, chatbot, active_schema, loaded_key_state,
|
| 2366 |
outputs=chat_generation_outputs,
|
| 2367 |
)
|
| 2368 |
message_input.submit(
|
| 2369 |
generate_response,
|
| 2370 |
-
inputs=[message_input, chatbot, active_schema, loaded_key_state,
|
| 2371 |
outputs=chat_generation_outputs,
|
| 2372 |
)
|
| 2373 |
demo.load(
|
|
@@ -2388,7 +2130,6 @@ with gr.Blocks(title="Phi-3 Mini SQL Chatbot") as demo:
|
|
| 2388 |
send_button,
|
| 2389 |
sql_output,
|
| 2390 |
validator_output,
|
| 2391 |
-
generation_meta_state,
|
| 2392 |
error_output,
|
| 2393 |
],
|
| 2394 |
)
|
|
|
|
| 18 |
import sql_tools as sql_core
|
| 19 |
|
| 20 |
|
|
|
|
| 21 |
FINE_TUNED_MODEL_ID = "Shizu0n/phi3-mini-sql-generator-merged"
|
| 22 |
|
|
|
|
| 23 |
FINE_TUNED_MODEL_KEY = "fine_tuned"
|
| 24 |
DEFAULT_MODEL_KEY = FINE_TUNED_MODEL_KEY
|
| 25 |
+
FALLBACK_RESPONSE = (
|
| 26 |
+
"Select a schema and ask a SQL question, "
|
| 27 |
+
"or ask to create or edit a table. "
|
| 28 |
+
"Example: 'what is the most expensive product?' or "
|
| 29 |
+
"'create table products with id name price'."
|
| 30 |
+
)
|
| 31 |
|
| 32 |
MODEL_CATALOG = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
FINE_TUNED_MODEL_KEY: {
|
| 34 |
"label": "Fine-tuned QLoRA model",
|
| 35 |
"short_label": "Fine-tuned",
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
def safe_chat_fallback(_message=""):
|
| 392 |
+
return FALLBACK_RESPONSE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
|
| 395 |
def clean_generation(text):
|
|
|
|
| 443 |
<section class="top-panel">
|
| 444 |
<div>
|
| 445 |
<h1>Phi-3 Mini SQL Chatbot</h1>
|
| 446 |
+
<p>SQL generation demo powered by a fine-tuned Phi-3 Mini model</p>
|
| 447 |
</div>
|
| 448 |
<div class="top-badges">
|
| 449 |
+
<span class="badge badge-green">SQL_QUERY only</span>
|
| 450 |
+
<span class="badge badge-cream">Deterministic schema edits</span>
|
| 451 |
<span class="badge badge-light">CPU lazy load</span>
|
| 452 |
</div>
|
| 453 |
</section>
|
|
|
|
| 518 |
def model_metadata(model_key=None):
|
| 519 |
return """
|
| 520 |
<section class="stats-row">
|
| 521 |
+
<div class="stat-card"><strong>SQL</strong><span>model-generated SELECT queries</span></div>
|
| 522 |
+
<div class="stat-card"><strong>Create</strong><span>deterministic CREATE TABLE parser</span></div>
|
| 523 |
+
<div class="stat-card"><strong>Edit</strong><span>deterministic schema updates</span></div>
|
| 524 |
+
<div class="stat-card"><strong>Fallback</strong><span>static non-SQL response</span></div>
|
| 525 |
</section>
|
| 526 |
"""
|
| 527 |
|
|
|
|
| 919 |
return f'<div class="message-box {class_name}">{html.escape(str(message))}</div>'
|
| 920 |
|
| 921 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 922 |
def load_selected_model(selected_key=FINE_TUNED_MODEL_KEY):
|
| 923 |
selected_key = FINE_TUNED_MODEL_KEY
|
| 924 |
model_def = model_by_key(selected_key)
|
|
|
|
| 935 |
*query_control_updates(False),
|
| 936 |
"",
|
| 937 |
EMPTY_VALIDATOR,
|
|
|
|
| 938 |
render_message(),
|
| 939 |
)
|
| 940 |
started = time.time()
|
|
|
|
| 964 |
*query_control_updates(False),
|
| 965 |
"",
|
| 966 |
EMPTY_VALIDATOR,
|
|
|
|
| 967 |
render_message(error),
|
| 968 |
)
|
| 969 |
return
|
|
|
|
| 978 |
*query_control_updates(True),
|
| 979 |
"",
|
| 980 |
EMPTY_VALIDATOR,
|
|
|
|
| 981 |
render_message(f"Loaded {model_def['model_id']} in {elapsed}s.", kind="ok"),
|
| 982 |
)
|
| 983 |
|
|
|
|
| 996 |
return history[-max_exchanges * 2 :]
|
| 997 |
|
| 998 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
def _append_chat_turn(chat_history, message, assistant_content):
|
| 1000 |
return trim_chat_history(
|
| 1001 |
[
|
|
|
|
| 1019 |
):
|
| 1020 |
state = chat_core.ConversationState.from_value(state)
|
| 1021 |
if sql_text and "CREATE TABLE" in sql_text.upper():
|
| 1022 |
+
state = state.with_active_schema(sql_text)
|
| 1023 |
new_history = _append_chat_turn(chat_history, message, assistant_content)
|
| 1024 |
return (
|
| 1025 |
new_history,
|
|
|
|
| 1028 |
message,
|
| 1029 |
sql_text,
|
| 1030 |
validator,
|
|
|
|
| 1031 |
render_message(status_message, kind=status_kind),
|
| 1032 |
state.to_dict(),
|
| 1033 |
)
|
|
|
|
| 1038 |
message,
|
| 1039 |
active_schema,
|
| 1040 |
loaded_key,
|
|
|
|
| 1041 |
assistant_content,
|
| 1042 |
status_message,
|
| 1043 |
*,
|
|
|
|
| 1061 |
|
| 1062 |
def _model_ready(loaded_key):
|
| 1063 |
if not loaded_key or _model is None or _tokenizer is None:
|
| 1064 |
+
return False, "Load the fine-tuned model before generating SQL."
|
| 1065 |
model_def = model_by_key(loaded_key)
|
| 1066 |
if _current_model_id != model_def["model_id"]:
|
| 1067 |
return False, "Loaded model state is inconsistent. Reload the selected model."
|
|
|
|
| 1103 |
return generated_text, int(time.time() - started)
|
| 1104 |
|
| 1105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1106 |
def _empty_generation_response(chat_history, message, state, status_message, *, status_kind="error"):
|
| 1107 |
return (
|
| 1108 |
chat_history,
|
|
|
|
| 1111 |
"",
|
| 1112 |
"",
|
| 1113 |
EMPTY_VALIDATOR,
|
|
|
|
| 1114 |
render_message(status_message, kind=status_kind),
|
| 1115 |
state.to_dict(),
|
| 1116 |
)
|
| 1117 |
|
| 1118 |
|
| 1119 |
+
def generate_response(message, chat_history, active_schema, loaded_key, conversation_state=None):
|
| 1120 |
message = (message or "").strip()
|
| 1121 |
chat_history = list(chat_history or [])
|
| 1122 |
state = chat_core.ConversationState.from_value(conversation_state, active_schema=(active_schema or ""))
|
|
|
|
| 1128 |
"",
|
| 1129 |
"",
|
| 1130 |
EMPTY_VALIDATOR,
|
|
|
|
| 1131 |
render_message("Type a message before sending."),
|
| 1132 |
state.to_dict(),
|
| 1133 |
)
|
|
|
|
| 1141 |
)
|
| 1142 |
|
| 1143 |
if intent_result.intent == intent_core.EDIT_TABLE:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1144 |
edited_table = sql_core.edit_create_table_from_message(message, chat_history, state.active_schema)
|
| 1145 |
if edited_table:
|
| 1146 |
display_response = f"```sql\n{edited_table}\n```"
|
|
|
|
| 1160 |
"I need an existing CREATE TABLE in the chat or an active schema before editing columns.",
|
| 1161 |
)
|
| 1162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1163 |
if intent_result.intent == intent_core.CREATE_TABLE:
|
| 1164 |
sql_text = sql_core.create_table_from_message(message) or sql_core.create_table_from_schema(state.active_schema)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1165 |
if sql_text:
|
| 1166 |
display_response = f"```sql\n{sql_text}\n```"
|
| 1167 |
return _response_tuple(
|
|
|
|
| 1177 |
chat_history,
|
| 1178 |
message,
|
| 1179 |
state,
|
| 1180 |
+
"CREATE TABLE needs a table name and columns.",
|
| 1181 |
)
|
| 1182 |
|
| 1183 |
+
if intent_result.intent in {intent_core.SMALLTALK, intent_core.UNKNOWN}:
|
| 1184 |
+
return _response_tuple(
|
| 1185 |
+
chat_history,
|
| 1186 |
+
message,
|
| 1187 |
+
state,
|
| 1188 |
+
FALLBACK_RESPONSE,
|
| 1189 |
+
"Static fallback - no model call.",
|
| 1190 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1191 |
|
| 1192 |
ready, error = _model_ready(loaded_key)
|
| 1193 |
if not ready:
|
|
|
|
| 1265 |
*query_control_updates(True),
|
| 1266 |
"",
|
| 1267 |
EMPTY_VALIDATOR,
|
|
|
|
| 1268 |
render_message(f"Model already loaded: {_current_model_id}", kind="ok"),
|
| 1269 |
)
|
| 1270 |
return (
|
|
|
|
| 1276 |
*query_control_updates(False),
|
| 1277 |
"",
|
| 1278 |
EMPTY_VALIDATOR,
|
|
|
|
| 1279 |
render_message(),
|
| 1280 |
)
|
| 1281 |
|
|
|
|
| 1286 |
/* Prevent Gradio dark theme from overriding text in light-bg components */
|
| 1287 |
[class*="badge"],
|
| 1288 |
[class*="validator-"],
|
|
|
|
| 1289 |
[class*="model-tag"],
|
| 1290 |
[class*="stat-card"] {
|
| 1291 |
color: inherit !important;
|
|
|
|
| 1412 |
}
|
| 1413 |
|
| 1414 |
.model-grid,
|
|
|
|
| 1415 |
.stats-row {
|
| 1416 |
display: grid;
|
| 1417 |
gap: 12px;
|
|
|
|
| 1419 |
}
|
| 1420 |
|
| 1421 |
.model-grid > div,
|
|
|
|
| 1422 |
.stats-row > div {
|
| 1423 |
min-width: 0;
|
| 1424 |
}
|
|
|
|
| 1558 |
}
|
| 1559 |
|
| 1560 |
#load-button,
|
| 1561 |
+
#generate-button {
|
|
|
|
| 1562 |
width: 100% !important;
|
| 1563 |
}
|
| 1564 |
|
|
|
|
| 1591 |
color: var(--bg-base) !important;
|
| 1592 |
}
|
| 1593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1594 |
#generate-button button:disabled {
|
| 1595 |
opacity: 0.4 !important;
|
| 1596 |
}
|
|
|
|
| 1867 |
|
| 1868 |
.output-shell .cm-editor,
|
| 1869 |
.output-shell pre,
|
| 1870 |
+
.output-shell code {
|
|
|
|
|
|
|
|
|
|
| 1871 |
border: 0 !important;
|
| 1872 |
font-size: 12px !important;
|
| 1873 |
font-weight: 400 !important;
|
|
|
|
| 1889 |
color: var(--teal);
|
| 1890 |
}
|
| 1891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1892 |
.loading-overlay {
|
| 1893 |
align-items: center;
|
| 1894 |
background: rgba(0, 0, 0, 0.6);
|
|
|
|
| 1956 |
@media (max-width: 860px) {
|
| 1957 |
.top-panel,
|
| 1958 |
.model-grid,
|
|
|
|
| 1959 |
.evidence-grid {
|
| 1960 |
grid-template-columns: 1fr;
|
| 1961 |
}
|
|
|
|
| 1978 |
loaded_key_state = gr.State(value=None)
|
| 1979 |
active_schema = gr.State(value="")
|
| 1980 |
conversation_state = gr.State(value=chat_core.default_state())
|
|
|
|
| 1981 |
last_user_message = gr.State(value="")
|
| 1982 |
|
| 1983 |
with gr.Column(elem_classes=["app-shell"]):
|
|
|
|
| 2079 |
send_button,
|
| 2080 |
sql_output,
|
| 2081 |
validator_output,
|
|
|
|
| 2082 |
error_output,
|
| 2083 |
],
|
| 2084 |
js=LOAD_SCROLL_JS,
|
|
|
|
| 2099 |
last_user_message,
|
| 2100 |
sql_output,
|
| 2101 |
validator_output,
|
|
|
|
| 2102 |
error_output,
|
| 2103 |
conversation_state,
|
| 2104 |
]
|
| 2105 |
send_button.click(
|
| 2106 |
generate_response,
|
| 2107 |
+
inputs=[message_input, chatbot, active_schema, loaded_key_state, conversation_state],
|
| 2108 |
outputs=chat_generation_outputs,
|
| 2109 |
)
|
| 2110 |
message_input.submit(
|
| 2111 |
generate_response,
|
| 2112 |
+
inputs=[message_input, chatbot, active_schema, loaded_key_state, conversation_state],
|
| 2113 |
outputs=chat_generation_outputs,
|
| 2114 |
)
|
| 2115 |
demo.load(
|
|
|
|
| 2130 |
send_button,
|
| 2131 |
sql_output,
|
| 2132 |
validator_output,
|
|
|
|
| 2133 |
error_output,
|
| 2134 |
],
|
| 2135 |
)
|
chat_state.py
CHANGED
|
@@ -1,51 +1,10 @@
|
|
| 1 |
from dataclasses import dataclass, field
|
| 2 |
|
| 3 |
|
| 4 |
-
@dataclass(frozen=True)
|
| 5 |
-
class SchemaSuggestion:
|
| 6 |
-
table_name: str = ""
|
| 7 |
-
columns: tuple[tuple[str, str], ...] = ()
|
| 8 |
-
rationale: str = ""
|
| 9 |
-
|
| 10 |
-
@classmethod
|
| 11 |
-
def from_value(cls, value):
|
| 12 |
-
if isinstance(value, cls):
|
| 13 |
-
return value
|
| 14 |
-
if not isinstance(value, dict):
|
| 15 |
-
return None
|
| 16 |
-
raw_columns = value.get("columns") or ()
|
| 17 |
-
columns = []
|
| 18 |
-
for column in raw_columns:
|
| 19 |
-
if isinstance(column, dict):
|
| 20 |
-
name = str(column.get("name") or "").strip()
|
| 21 |
-
column_type = str(column.get("type") or "TEXT").strip().upper()
|
| 22 |
-
elif isinstance(column, (list, tuple)) and len(column) >= 2:
|
| 23 |
-
name = str(column[0] or "").strip()
|
| 24 |
-
column_type = str(column[1] or "TEXT").strip().upper()
|
| 25 |
-
else:
|
| 26 |
-
continue
|
| 27 |
-
if name:
|
| 28 |
-
columns.append((name, column_type or "TEXT"))
|
| 29 |
-
table_name = str(value.get("table_name") or "").strip()
|
| 30 |
-
rationale = str(value.get("rationale") or "").strip()
|
| 31 |
-
if not table_name or not columns:
|
| 32 |
-
return None
|
| 33 |
-
return cls(table_name=table_name, columns=tuple(columns), rationale=rationale)
|
| 34 |
-
|
| 35 |
-
def to_dict(self):
|
| 36 |
-
return {
|
| 37 |
-
"table_name": self.table_name,
|
| 38 |
-
"columns": [{"name": name, "type": column_type} for name, column_type in self.columns],
|
| 39 |
-
"rationale": self.rationale,
|
| 40 |
-
}
|
| 41 |
-
|
| 42 |
-
|
| 43 |
@dataclass(frozen=True)
|
| 44 |
class ConversationState:
|
| 45 |
active_schema: str = ""
|
| 46 |
-
pending_schema_suggestion: SchemaSuggestion | None = None
|
| 47 |
last_intent: str | None = None
|
| 48 |
-
last_table_topic: str | None = None
|
| 49 |
debug: dict = field(default_factory=dict)
|
| 50 |
|
| 51 |
@classmethod
|
|
@@ -56,52 +15,24 @@ class ConversationState:
|
|
| 56 |
return value
|
| 57 |
if not isinstance(value, dict):
|
| 58 |
return cls(active_schema=(active_schema or "").strip())
|
| 59 |
-
pending = SchemaSuggestion.from_value(value.get("pending_schema_suggestion"))
|
| 60 |
state_active_schema = (value.get("active_schema") or active_schema or "").strip()
|
| 61 |
return cls(
|
| 62 |
active_schema=state_active_schema,
|
| 63 |
-
pending_schema_suggestion=pending,
|
| 64 |
last_intent=value.get("last_intent"),
|
| 65 |
-
last_table_topic=value.get("last_table_topic"),
|
| 66 |
debug=dict(value.get("debug") or {}),
|
| 67 |
)
|
| 68 |
|
| 69 |
def to_dict(self):
|
| 70 |
return {
|
| 71 |
"active_schema": self.active_schema,
|
| 72 |
-
"pending_schema_suggestion": (
|
| 73 |
-
self.pending_schema_suggestion.to_dict() if self.pending_schema_suggestion else None
|
| 74 |
-
),
|
| 75 |
"last_intent": self.last_intent,
|
| 76 |
-
"last_table_topic": self.last_table_topic,
|
| 77 |
"debug": dict(self.debug or {}),
|
| 78 |
}
|
| 79 |
|
| 80 |
def with_active_schema(self, schema):
|
| 81 |
return ConversationState(
|
| 82 |
active_schema=(schema or "").strip(),
|
| 83 |
-
pending_schema_suggestion=self.pending_schema_suggestion,
|
| 84 |
last_intent=self.last_intent,
|
| 85 |
-
last_table_topic=self.last_table_topic,
|
| 86 |
-
debug=dict(self.debug or {}),
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
def with_pending_schema(self, suggestion):
|
| 90 |
-
suggestion = SchemaSuggestion.from_value(suggestion)
|
| 91 |
-
return ConversationState(
|
| 92 |
-
active_schema=self.active_schema,
|
| 93 |
-
pending_schema_suggestion=suggestion,
|
| 94 |
-
last_intent=self.last_intent,
|
| 95 |
-
last_table_topic=(suggestion.table_name if suggestion else self.last_table_topic),
|
| 96 |
-
debug=dict(self.debug or {}),
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
def clear_pending_schema(self):
|
| 100 |
-
return ConversationState(
|
| 101 |
-
active_schema=self.active_schema,
|
| 102 |
-
pending_schema_suggestion=None,
|
| 103 |
-
last_intent=self.last_intent,
|
| 104 |
-
last_table_topic=self.last_table_topic,
|
| 105 |
debug=dict(self.debug or {}),
|
| 106 |
)
|
| 107 |
|
|
@@ -112,13 +43,10 @@ class ConversationState:
|
|
| 112 |
debug["reason"] = getattr(intent_result, "reason", None)
|
| 113 |
return ConversationState(
|
| 114 |
active_schema=self.active_schema,
|
| 115 |
-
pending_schema_suggestion=self.pending_schema_suggestion,
|
| 116 |
last_intent=getattr(intent_result, "intent", None),
|
| 117 |
-
last_table_topic=self.last_table_topic,
|
| 118 |
debug=debug,
|
| 119 |
)
|
| 120 |
|
| 121 |
|
| 122 |
def default_state(active_schema=""):
|
| 123 |
-
return ConversationState(active_schema=(active_schema or "").strip()).to_dict()
|
| 124 |
-
|
|
|
|
| 1 |
from dataclasses import dataclass, field
|
| 2 |
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
@dataclass(frozen=True)
|
| 5 |
class ConversationState:
|
| 6 |
active_schema: str = ""
|
|
|
|
| 7 |
last_intent: str | None = None
|
|
|
|
| 8 |
debug: dict = field(default_factory=dict)
|
| 9 |
|
| 10 |
@classmethod
|
|
|
|
| 15 |
return value
|
| 16 |
if not isinstance(value, dict):
|
| 17 |
return cls(active_schema=(active_schema or "").strip())
|
|
|
|
| 18 |
state_active_schema = (value.get("active_schema") or active_schema or "").strip()
|
| 19 |
return cls(
|
| 20 |
active_schema=state_active_schema,
|
|
|
|
| 21 |
last_intent=value.get("last_intent"),
|
|
|
|
| 22 |
debug=dict(value.get("debug") or {}),
|
| 23 |
)
|
| 24 |
|
| 25 |
def to_dict(self):
|
| 26 |
return {
|
| 27 |
"active_schema": self.active_schema,
|
|
|
|
|
|
|
|
|
|
| 28 |
"last_intent": self.last_intent,
|
|
|
|
| 29 |
"debug": dict(self.debug or {}),
|
| 30 |
}
|
| 31 |
|
| 32 |
def with_active_schema(self, schema):
|
| 33 |
return ConversationState(
|
| 34 |
active_schema=(schema or "").strip(),
|
|
|
|
| 35 |
last_intent=self.last_intent,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
debug=dict(self.debug or {}),
|
| 37 |
)
|
| 38 |
|
|
|
|
| 43 |
debug["reason"] = getattr(intent_result, "reason", None)
|
| 44 |
return ConversationState(
|
| 45 |
active_schema=self.active_schema,
|
|
|
|
| 46 |
last_intent=getattr(intent_result, "intent", None),
|
|
|
|
| 47 |
debug=debug,
|
| 48 |
)
|
| 49 |
|
| 50 |
|
| 51 |
def default_state(active_schema=""):
|
| 52 |
+
return ConversationState(active_schema=(active_schema or "").strip()).to_dict()
|
|
|
intent.py
CHANGED
|
@@ -5,12 +5,9 @@ import sql_tools
|
|
| 5 |
|
| 6 |
|
| 7 |
SMALLTALK = "smalltalk"
|
| 8 |
-
SCHEMA_SUGGESTION = "schema_suggestion"
|
| 9 |
CREATE_TABLE = "create_table"
|
| 10 |
-
CREATE_TABLE_CONFIRM = "create_table_confirm"
|
| 11 |
EDIT_TABLE = "edit_table"
|
| 12 |
SQL_QUERY = "sql_query"
|
| 13 |
-
CLARIFICATION = "clarification"
|
| 14 |
UNKNOWN = "unknown"
|
| 15 |
|
| 16 |
|
|
@@ -21,29 +18,17 @@ class IntentResult:
|
|
| 21 |
reason: str
|
| 22 |
|
| 23 |
|
| 24 |
-
def _has_pending_schema(state):
|
| 25 |
-
return bool(getattr(state, "pending_schema_suggestion", None))
|
| 26 |
-
|
| 27 |
-
|
| 28 |
def _has_active_schema(state):
|
| 29 |
return bool((getattr(state, "active_schema", "") or "").strip())
|
| 30 |
|
| 31 |
|
| 32 |
-
def _is_confirmation(message):
|
| 33 |
-
normalized = sql_tools.normalize_text(message)
|
| 34 |
-
confirmations = {
|
| 35 |
-
"sim", "yes", "ok", "claro", "pode", "pode gerar", "gera", "gerar",
|
| 36 |
-
"gere", "faz", "faca", "cria", "crie", "manda", "confirmo", "isso",
|
| 37 |
-
"isso mesmo", "perfeito",
|
| 38 |
-
}
|
| 39 |
-
return normalized in confirmations or normalized.startswith(("gera ", "pode gerar", "faz "))
|
| 40 |
-
|
| 41 |
-
|
| 42 |
def _is_smalltalk(message):
|
| 43 |
normalized = sql_tools.normalize_text(message)
|
| 44 |
exact = {
|
| 45 |
"oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
|
| 46 |
"obrigado", "obrigada", "valeu", "thanks", "thank you",
|
|
|
|
|
|
|
| 47 |
"como voce esta", "como voce esta hoje", "qual seu nome",
|
| 48 |
"me conte uma piada", "conte uma piada", "vamos conversar",
|
| 49 |
"o que voce faz", "como voce funciona", "como funciona",
|
|
@@ -56,60 +41,33 @@ def _is_smalltalk(message):
|
|
| 56 |
"conte uma piada",
|
| 57 |
"vamos conversar",
|
| 58 |
"obrigado",
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
return any(fragment in normalized for fragment in smalltalk_fragments)
|
| 61 |
|
| 62 |
|
| 63 |
-
def _is_schema_suggestion(message):
|
| 64 |
-
normalized = sql_tools.normalize_text(message)
|
| 65 |
-
patterns = (
|
| 66 |
-
"preciso de uma tabela",
|
| 67 |
-
"preciso de um schema",
|
| 68 |
-
"quero uma tabela",
|
| 69 |
-
"quero um schema",
|
| 70 |
-
"sugira uma tabela",
|
| 71 |
-
"sugerir uma tabela",
|
| 72 |
-
"tabela sobre",
|
| 73 |
-
"tabela de",
|
| 74 |
-
"schema sobre",
|
| 75 |
-
"schema de",
|
| 76 |
-
"modelo de tabela",
|
| 77 |
-
"modelar",
|
| 78 |
-
)
|
| 79 |
-
if any(pattern in normalized for pattern in patterns) and not sql_tools.is_create_table_intent(message):
|
| 80 |
-
return True
|
| 81 |
-
if "tabela" in normalized and any(term in normalized for term in ("sobre", "para", "de")):
|
| 82 |
-
return not any(term in normalized for term in ("crie", "criar", "create", "generate", "gerar", "gere"))
|
| 83 |
-
return False
|
| 84 |
-
|
| 85 |
-
|
| 86 |
def classify_intent(message, state=None, chat_history=None):
|
| 87 |
state = ConversationState.from_value(state)
|
| 88 |
normalized = sql_tools.normalize_text(message)
|
| 89 |
if not normalized:
|
| 90 |
return IntentResult(UNKNOWN, 0.0, "empty_message")
|
| 91 |
|
| 92 |
-
if _has_pending_schema(state) and _is_confirmation(message):
|
| 93 |
-
return IntentResult(CREATE_TABLE_CONFIRM, 0.95, "confirmation_with_pending_schema")
|
| 94 |
-
|
| 95 |
if _is_smalltalk(message):
|
| 96 |
return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")
|
| 97 |
|
| 98 |
edited_table = sql_tools.edit_create_table_from_message(message, chat_history, state.active_schema)
|
| 99 |
if edited_table or sql_tools.is_table_edit_intent(message):
|
| 100 |
-
return IntentResult(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
if sql_tools.is_create_table_intent(message):
|
| 103 |
return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")
|
| 104 |
|
| 105 |
-
if _is_schema_suggestion(message):
|
| 106 |
-
return IntentResult(SCHEMA_SUGGESTION, 0.86, "schema_suggestion_phrase")
|
| 107 |
-
|
| 108 |
if sql_tools.is_sql_intent(message, state.active_schema):
|
| 109 |
return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")
|
| 110 |
|
| 111 |
-
if _has_pending_schema(state):
|
| 112 |
-
return IntentResult(CLARIFICATION, 0.55, "pending_schema_context")
|
| 113 |
-
|
| 114 |
return IntentResult(UNKNOWN, 0.25, "no_intent_match")
|
| 115 |
-
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
SMALLTALK = "smalltalk"
|
|
|
|
| 8 |
CREATE_TABLE = "create_table"
|
|
|
|
| 9 |
EDIT_TABLE = "edit_table"
|
| 10 |
SQL_QUERY = "sql_query"
|
|
|
|
| 11 |
UNKNOWN = "unknown"
|
| 12 |
|
| 13 |
|
|
|
|
| 18 |
reason: str
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def _has_active_schema(state):
|
| 22 |
return bool((getattr(state, "active_schema", "") or "").strip())
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def _is_smalltalk(message):
|
| 26 |
normalized = sql_tools.normalize_text(message)
|
| 27 |
exact = {
|
| 28 |
"oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
|
| 29 |
"obrigado", "obrigada", "valeu", "thanks", "thank you",
|
| 30 |
+
"tudo bem", "tudo bom", "tudo", "tchau", "xau", "ate mais", "ate logo",
|
| 31 |
+
"de nada", "por nada", "imagina",
|
| 32 |
"como voce esta", "como voce esta hoje", "qual seu nome",
|
| 33 |
"me conte uma piada", "conte uma piada", "vamos conversar",
|
| 34 |
"o que voce faz", "como voce funciona", "como funciona",
|
|
|
|
| 41 |
"conte uma piada",
|
| 42 |
"vamos conversar",
|
| 43 |
"obrigado",
|
| 44 |
+
"tudo bem",
|
| 45 |
+
"tudo bom",
|
| 46 |
)
|
| 47 |
return any(fragment in normalized for fragment in smalltalk_fragments)
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def classify_intent(message, state=None, chat_history=None):
|
| 51 |
state = ConversationState.from_value(state)
|
| 52 |
normalized = sql_tools.normalize_text(message)
|
| 53 |
if not normalized:
|
| 54 |
return IntentResult(UNKNOWN, 0.0, "empty_message")
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
if _is_smalltalk(message):
|
| 57 |
return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")
|
| 58 |
|
| 59 |
edited_table = sql_tools.edit_create_table_from_message(message, chat_history, state.active_schema)
|
| 60 |
if edited_table or sql_tools.is_table_edit_intent(message):
|
| 61 |
+
return IntentResult(
|
| 62 |
+
EDIT_TABLE,
|
| 63 |
+
0.9 if (edited_table or _has_active_schema(state)) else 0.7,
|
| 64 |
+
"table_edit_terms",
|
| 65 |
+
)
|
| 66 |
|
| 67 |
if sql_tools.is_create_table_intent(message):
|
| 68 |
return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")
|
| 69 |
|
|
|
|
|
|
|
|
|
|
| 70 |
if sql_tools.is_sql_intent(message, state.active_schema):
|
| 71 |
return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
return IntentResult(UNKNOWN, 0.25, "no_intent_match")
|
|
|
model_io.py
CHANGED
|
@@ -1,41 +1,12 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import re
|
| 3 |
-
|
| 4 |
-
from chat_state import ConversationState, SchemaSuggestion
|
| 5 |
import sql_tools
|
| 6 |
|
| 7 |
|
| 8 |
-
CHAT_GENERATION = "chat"
|
| 9 |
-
SCHEMA_GENERATION = "schema"
|
| 10 |
SQL_GENERATION = "sql"
|
| 11 |
|
| 12 |
GENERATION_BUDGETS = {
|
| 13 |
-
CHAT_GENERATION: 120,
|
| 14 |
-
SCHEMA_GENERATION: 180,
|
| 15 |
SQL_GENERATION: 96,
|
| 16 |
}
|
| 17 |
|
| 18 |
-
CHAT_PROMPT_TEMPLATE = (
|
| 19 |
-
"<|user|>\n"
|
| 20 |
-
"You are a conversational SQL assistant. Reply naturally in Brazilian Portuguese unless the user writes in English.\n"
|
| 21 |
-
"You can chat normally, discuss table ideas, and help generate SQL, but do not generate SQL unless the user asks for it.\n"
|
| 22 |
-
"Current state:\n{state_summary}\n\n"
|
| 23 |
-
"{history_context}"
|
| 24 |
-
"User message: {message}<|end|>\n"
|
| 25 |
-
"<|assistant|>"
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
SCHEMA_SUGGESTION_PROMPT_TEMPLATE = (
|
| 29 |
-
"<|user|>\n"
|
| 30 |
-
"Create a practical SQL table proposal for the user's domain request.\n"
|
| 31 |
-
"Return JSON only with this shape: "
|
| 32 |
-
'{{"table_name":"name","columns":[{{"name":"id","type":"INTEGER"}}],"rationale":"short reason"}}.\n'
|
| 33 |
-
"Use simple SQL types: INTEGER, TEXT, NUMERIC, DATE, BOOLEAN.\n"
|
| 34 |
-
"{history_context}"
|
| 35 |
-
"Request: {message}<|end|>\n"
|
| 36 |
-
"<|assistant|>"
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
SQL_PROMPT_TEMPLATE = (
|
| 40 |
"<|user|>\n"
|
| 41 |
"Given the following SQL table, write one SQL query. Output SQL only.\n\n"
|
|
@@ -63,28 +34,6 @@ def _history_context(chat_history, max_exchanges=3):
|
|
| 63 |
return "Previous conversation:\n" + "\n".join(lines) + "\n\n"
|
| 64 |
|
| 65 |
|
| 66 |
-
def _state_summary(state):
|
| 67 |
-
state = ConversationState.from_value(state)
|
| 68 |
-
pending = state.pending_schema_suggestion.table_name if state.pending_schema_suggestion else "none"
|
| 69 |
-
active = "present" if state.active_schema else "none"
|
| 70 |
-
return f"- active_schema: {active}\n- pending_schema_suggestion: {pending}"
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def build_chat_prompt(message, state=None, chat_history=None):
|
| 74 |
-
return CHAT_PROMPT_TEMPLATE.format(
|
| 75 |
-
message=(message or "").strip(),
|
| 76 |
-
state_summary=_state_summary(state),
|
| 77 |
-
history_context=_history_context(chat_history),
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def build_schema_suggestion_prompt(message, state=None, chat_history=None):
|
| 82 |
-
return SCHEMA_SUGGESTION_PROMPT_TEMPLATE.format(
|
| 83 |
-
message=(message or "").strip(),
|
| 84 |
-
history_context=_history_context(chat_history),
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
def build_sql_prompt(schema, message, chat_history=None):
|
| 89 |
table_schema = (schema or "").strip() or "CREATE TABLE unknown (id INTEGER)"
|
| 90 |
return SQL_PROMPT_TEMPLATE.format(
|
|
@@ -119,14 +68,3 @@ def format_generation_result(text):
|
|
| 119 |
if is_sql_like(cleaned):
|
| 120 |
return str(cleaned), "", sql_tools.validate_sql(cleaned)
|
| 121 |
return "", str(cleaned), '<span class="validator-badge validator-empty">Chat response</span>'
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def parse_schema_suggestion(text):
|
| 125 |
-
cleaned = clean_generation(text)
|
| 126 |
-
match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
|
| 127 |
-
raw_json = match.group(0) if match else cleaned
|
| 128 |
-
try:
|
| 129 |
-
payload = json.loads(raw_json)
|
| 130 |
-
except json.JSONDecodeError:
|
| 131 |
-
return None
|
| 132 |
-
return SchemaSuggestion.from_value(payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sql_tools
|
| 2 |
|
| 3 |
|
|
|
|
|
|
|
| 4 |
SQL_GENERATION = "sql"
|
| 5 |
|
| 6 |
GENERATION_BUDGETS = {
|
|
|
|
|
|
|
| 7 |
SQL_GENERATION: 96,
|
| 8 |
}
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
SQL_PROMPT_TEMPLATE = (
|
| 11 |
"<|user|>\n"
|
| 12 |
"Given the following SQL table, write one SQL query. Output SQL only.\n\n"
|
|
|
|
| 34 |
return "Previous conversation:\n" + "\n".join(lines) + "\n\n"
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def build_sql_prompt(schema, message, chat_history=None):
|
| 38 |
table_schema = (schema or "").strip() or "CREATE TABLE unknown (id INTEGER)"
|
| 39 |
return SQL_PROMPT_TEMPLATE.format(
|
|
|
|
| 68 |
if is_sql_like(cleaned):
|
| 69 |
return str(cleaned), "", sql_tools.validate_sql(cleaned)
|
| 70 |
return "", str(cleaned), '<span class="validator-badge validator-empty">Chat response</span>'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/model_probe.py
CHANGED
|
@@ -20,7 +20,6 @@ def _scenario(name, message, history, active_schema, state):
|
|
| 20 |
history,
|
| 21 |
active_schema,
|
| 22 |
app.FINE_TUNED_MODEL_KEY,
|
| 23 |
-
None,
|
| 24 |
state,
|
| 25 |
)
|
| 26 |
return {
|
|
@@ -28,9 +27,9 @@ def _scenario(name, message, history, active_schema, state):
|
|
| 28 |
"message": message,
|
| 29 |
"assistant": _assistant_text(result),
|
| 30 |
"sql": result[4],
|
| 31 |
-
"status": result[
|
| 32 |
"active_schema": result[2],
|
| 33 |
-
"state": result[
|
| 34 |
"history": result[0],
|
| 35 |
}
|
| 36 |
|
|
@@ -45,19 +44,14 @@ def _grade(records):
|
|
| 45 |
by_name = {record["name"]: record for record in records}
|
| 46 |
|
| 47 |
checks.append({
|
| 48 |
-
"name": "
|
| 49 |
-
"pass":
|
| 50 |
-
"detail": "Greeting should
|
| 51 |
})
|
| 52 |
checks.append({
|
| 53 |
-
"name": "
|
| 54 |
-
"pass":
|
| 55 |
-
"detail": "
|
| 56 |
-
})
|
| 57 |
-
checks.append({
|
| 58 |
-
"name": "confirmation_generates_create_table",
|
| 59 |
-
"pass": "CREATE TABLE" in (by_name["confirm_generate"]["sql"] or "").upper(),
|
| 60 |
-
"detail": "Confirmation should generate CREATE TABLE SQL.",
|
| 61 |
})
|
| 62 |
checks.append({
|
| 63 |
"name": "edit_updates_schema",
|
|
@@ -71,8 +65,8 @@ def _grade(records):
|
|
| 71 |
})
|
| 72 |
checks.append({
|
| 73 |
"name": "smalltalk_with_schema_stays_chat",
|
| 74 |
-
"pass":
|
| 75 |
-
"detail": "Smalltalk with active schema should
|
| 76 |
})
|
| 77 |
return checks
|
| 78 |
|
|
@@ -87,8 +81,7 @@ def main():
|
|
| 87 |
|
| 88 |
for name, message in [
|
| 89 |
("greeting", "oi"),
|
| 90 |
-
("
|
| 91 |
-
("confirm_generate", "gera"),
|
| 92 |
("edit_schema", "troca capacidade por numero_animais"),
|
| 93 |
("query_schema", "liste zoologicos de Sao Paulo"),
|
| 94 |
("smalltalk_with_schema", "como voce esta hoje?"),
|
|
@@ -112,4 +105,3 @@ def main():
|
|
| 112 |
|
| 113 |
if __name__ == "__main__":
|
| 114 |
raise SystemExit(main())
|
| 115 |
-
|
|
|
|
| 20 |
history,
|
| 21 |
active_schema,
|
| 22 |
app.FINE_TUNED_MODEL_KEY,
|
|
|
|
| 23 |
state,
|
| 24 |
)
|
| 25 |
return {
|
|
|
|
| 27 |
"message": message,
|
| 28 |
"assistant": _assistant_text(result),
|
| 29 |
"sql": result[4],
|
| 30 |
+
"status": result[6],
|
| 31 |
"active_schema": result[2],
|
| 32 |
+
"state": result[7],
|
| 33 |
"history": result[0],
|
| 34 |
}
|
| 35 |
|
|
|
|
| 44 |
by_name = {record["name"]: record for record in records}
|
| 45 |
|
| 46 |
checks.append({
|
| 47 |
+
"name": "smalltalk_is_static_fallback",
|
| 48 |
+
"pass": app.FALLBACK_RESPONSE in by_name["greeting"]["assistant"] and not by_name["greeting"]["sql"],
|
| 49 |
+
"detail": "Greeting should use the static fallback and no SQL.",
|
| 50 |
})
|
| 51 |
checks.append({
|
| 52 |
+
"name": "create_table_is_deterministic",
|
| 53 |
+
"pass": "CREATE TABLE ZOOLOGICO" in (by_name["create_schema"]["sql"] or "").upper(),
|
| 54 |
+
"detail": "Explicit CREATE TABLE request should not need model generation.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
})
|
| 56 |
checks.append({
|
| 57 |
"name": "edit_updates_schema",
|
|
|
|
| 65 |
})
|
| 66 |
checks.append({
|
| 67 |
"name": "smalltalk_with_schema_stays_chat",
|
| 68 |
+
"pass": app.FALLBACK_RESPONSE in by_name["smalltalk_with_schema"]["assistant"] and not by_name["smalltalk_with_schema"]["sql"],
|
| 69 |
+
"detail": "Smalltalk with active schema should still use static fallback.",
|
| 70 |
})
|
| 71 |
return checks
|
| 72 |
|
|
|
|
| 81 |
|
| 82 |
for name, message in [
|
| 83 |
("greeting", "oi"),
|
| 84 |
+
("create_schema", "crie tabela zoologico com id nome cidade capacidade"),
|
|
|
|
| 85 |
("edit_schema", "troca capacidade por numero_animais"),
|
| 86 |
("query_schema", "liste zoologicos de Sao Paulo"),
|
| 87 |
("smalltalk_with_schema", "como voce esta hoje?"),
|
|
|
|
| 105 |
|
| 106 |
if __name__ == "__main__":
|
| 107 |
raise SystemExit(main())
|
|
|
sql_tools.py
CHANGED
|
@@ -118,32 +118,37 @@ def validate_sql(sql_text):
|
|
| 118 |
return '<span class="validator-badge validator-ok">Valid SQL</span>'
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def is_create_table_intent(message):
|
| 122 |
message = (message or "").strip().lower()
|
| 123 |
return bool(
|
| 124 |
-
re.search(
|
| 125 |
and re.search(r"\b(table|schema|tabela)\b", message)
|
| 126 |
)
|
| 127 |
|
| 128 |
|
| 129 |
def is_rename_intent(message):
|
| 130 |
-
|
| 131 |
-
return bool(
|
| 132 |
-
re.search(
|
| 133 |
-
r"\b(rename|edit|change|renomeie|renomear|renomeia|altere|mude|muda|troca|trocar)\s+\w+\s+(to|para|as|como|por)\s+\w+",
|
| 134 |
-
message,
|
| 135 |
-
flags=re.IGNORECASE,
|
| 136 |
-
)
|
| 137 |
-
)
|
| 138 |
|
| 139 |
|
| 140 |
def is_table_edit_intent(message):
|
| 141 |
message = (message or "").strip().lower()
|
| 142 |
-
edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|novo|nova|troca|trocar|coloque|colocar)\b"
|
| 143 |
-
direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar)\b"
|
| 144 |
-
direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b"
|
| 145 |
target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
|
| 146 |
-
sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by"}
|
| 147 |
add_match = re.search(direct_add_terms, message)
|
| 148 |
if add_match:
|
| 149 |
after_add = message[add_match.start() + len(add_match.group()) :].strip()
|
|
@@ -278,8 +283,8 @@ def format_create_table(table_name, columns):
|
|
| 278 |
def create_table_from_message(message):
|
| 279 |
message = (message or "").strip()
|
| 280 |
patterns = (
|
| 281 |
-
r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com)\s+(.+)$",
|
| 282 |
-
|
| 283 |
)
|
| 284 |
for pattern in patterns:
|
| 285 |
match = re.search(pattern, message, flags=re.IGNORECASE)
|
|
@@ -361,7 +366,7 @@ def extract_added_columns(message):
|
|
| 361 |
def extract_removed_columns(message):
|
| 362 |
message = (message or "").strip()
|
| 363 |
patterns = (
|
| 364 |
-
r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
|
| 365 |
)
|
| 366 |
for pattern in patterns:
|
| 367 |
match = re.search(pattern, message, flags=re.IGNORECASE)
|
|
@@ -376,15 +381,20 @@ def extract_removed_columns(message):
|
|
| 376 |
|
| 377 |
def extract_renamed_columns(message):
|
| 378 |
pattern = (
|
| 379 |
-
r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|mude)\s+"
|
| 380 |
r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
|
| 381 |
)
|
| 382 |
matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
|
| 383 |
troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE)
|
|
|
|
|
|
|
| 384 |
return [
|
| 385 |
(normalize_identifier(old), normalize_identifier(new))
|
| 386 |
for old, new in [*matches, *troca_matches]
|
| 387 |
-
if normalize_identifier(old)
|
|
|
|
|
|
|
|
|
|
| 388 |
]
|
| 389 |
|
| 390 |
|
|
@@ -392,8 +402,12 @@ def parse_compound_edit(message):
|
|
| 392 |
segment_pattern = (
|
| 393 |
r"\s+(?:and|e)\s+"
|
| 394 |
r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
|
| 395 |
-
r"adicione|adicionar|inclua|acrescente|
|
| 396 |
-
r"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
)
|
| 398 |
segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
|
| 399 |
added, removed, renamed = [], [], []
|
|
@@ -451,4 +465,3 @@ def create_table_from_suggestion(suggestion):
|
|
| 451 |
if identifier:
|
| 452 |
parsed.append((identifier, (column_type or "TEXT").upper()))
|
| 453 |
return format_create_table(normalize_identifier(table_name), parsed)
|
| 454 |
-
|
|
|
|
| 118 |
return '<span class="validator-badge validator-ok">Valid SQL</span>'
|
| 119 |
|
| 120 |
|
| 121 |
+
# Verb stems for create-table intent — keep in sync with create_table_from_message below.
|
| 122 |
+
_CREATE_VERBS = (
|
| 123 |
+
r"create|make|build|generate"
|
| 124 |
+
r"|criar|crie|cria|criando"
|
| 125 |
+
r"|gerar|gere|gera|gerando"
|
| 126 |
+
r"|faz|faca|fa\u00e7a|fazendo"
|
| 127 |
+
r"|monta|montar|monte"
|
| 128 |
+
r"|construa|construir|constroi"
|
| 129 |
+
r"|elabore|elaborar|elabora"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
def is_create_table_intent(message):
|
| 134 |
message = (message or "").strip().lower()
|
| 135 |
return bool(
|
| 136 |
+
re.search(rf"\b({_CREATE_VERBS})\b", message)
|
| 137 |
and re.search(r"\b(table|schema|tabela)\b", message)
|
| 138 |
)
|
| 139 |
|
| 140 |
|
| 141 |
def is_rename_intent(message):
|
| 142 |
+
return bool(extract_renamed_columns(message))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
def is_table_edit_intent(message):
|
| 146 |
message = (message or "").strip().lower()
|
| 147 |
+
edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|exclui|novo|nova|troca|trocar|coloque|colocar|coloca|insira|insere|bota|tira|retire|retira|apaga|apague)\b"
|
| 148 |
+
direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar|coloca|acrescenta|insere|inserir|insira|bota|botar|bote)\b"
|
| 149 |
+
direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
|
| 150 |
target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
|
| 151 |
+
sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by", "soma", "media", "contagem", "maximo", "minimo"}
|
| 152 |
add_match = re.search(direct_add_terms, message)
|
| 153 |
if add_match:
|
| 154 |
after_add = message[add_match.start() + len(add_match.group()) :].strip()
|
|
|
|
| 283 |
def create_table_from_message(message):
|
| 284 |
message = (message or "").strip()
|
| 285 |
patterns = (
|
| 286 |
+
r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
|
| 287 |
+
rf"\b(?:{_CREATE_VERBS})\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
|
| 288 |
)
|
| 289 |
for pattern in patterns:
|
| 290 |
match = re.search(pattern, message, flags=re.IGNORECASE)
|
|
|
|
| 366 |
def extract_removed_columns(message):
|
| 367 |
message = (message or "").strip()
|
| 368 |
patterns = (
|
| 369 |
+
r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
|
| 370 |
)
|
| 371 |
for pattern in patterns:
|
| 372 |
match = re.search(pattern, message, flags=re.IGNORECASE)
|
|
|
|
| 381 |
|
| 382 |
def extract_renamed_columns(message):
|
| 383 |
pattern = (
|
| 384 |
+
r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|alterar|altera|mude|mudar|muda|edita|editar)\s+"
|
| 385 |
r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
|
| 386 |
)
|
| 387 |
matches = re.findall(pattern, message or "", flags=re.IGNORECASE)
|
| 388 |
troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE)
|
| 389 |
+
invalid_old_names = {"ela", "ele", "isso", "isto", "essa", "esse", "this", "it"}
|
| 390 |
+
invalid_new_names = {"ter", "have", "having", "tambem", "tambem"}
|
| 391 |
return [
|
| 392 |
(normalize_identifier(old), normalize_identifier(new))
|
| 393 |
for old, new in [*matches, *troca_matches]
|
| 394 |
+
if normalize_identifier(old)
|
| 395 |
+
and normalize_identifier(new)
|
| 396 |
+
and normalize_identifier(old) not in invalid_old_names
|
| 397 |
+
and normalize_identifier(new) not in invalid_new_names
|
| 398 |
]
|
| 399 |
|
| 400 |
|
|
|
|
| 402 |
segment_pattern = (
|
| 403 |
r"\s+(?:and|e)\s+"
|
| 404 |
r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
|
| 405 |
+
r"adicione|adicionar|adicionando|inclua|incluir|acrescente|"
|
| 406 |
+
r"coloca|coloque|bota|insira|insere|"
|
| 407 |
+
r"remova|remover|deletar|exclua|excluir|exclui|"
|
| 408 |
+
r"tira|tirar|tire|retira|retire|apaga|apague|"
|
| 409 |
+
r"renomeie|renomear|renomeia|altere|alterar|altera|"
|
| 410 |
+
r"mude|mudar|muda|edita|editar|troca|trocar)\b)"
|
| 411 |
)
|
| 412 |
segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
|
| 413 |
added, removed, renamed = [], [], []
|
|
|
|
| 465 |
if identifier:
|
| 466 |
parsed.append((identifier, (column_type or "TEXT").upper()))
|
| 467 |
return format_create_table(normalize_identifier(table_name), parsed)
|
|
|
tests/e2e_flow_test.py
CHANGED
|
@@ -17,7 +17,7 @@ def sql_out(result):
|
|
| 17 |
return result[4]
|
| 18 |
|
| 19 |
def status(result):
|
| 20 |
-
return result[
|
| 21 |
|
| 22 |
def reset_model_state():
|
| 23 |
app._model = None
|
|
@@ -247,4 +247,4 @@ if __name__ == "__main__":
|
|
| 247 |
print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
|
| 248 |
print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
|
| 249 |
else:
|
| 250 |
-
run_all()
|
|
|
|
| 17 |
return result[4]
|
| 18 |
|
| 19 |
def status(result):
|
| 20 |
+
return result[6]
|
| 21 |
|
| 22 |
def reset_model_state():
|
| 23 |
app._model = None
|
|
|
|
| 247 |
print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
|
| 248 |
print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
|
| 249 |
else:
|
| 250 |
+
run_all()
|
tests/test_chatbot_behavior.py
CHANGED
|
@@ -24,7 +24,7 @@ def sql_output(result):
|
|
| 24 |
|
| 25 |
|
| 26 |
def status_html(result):
|
| 27 |
-
return result[
|
| 28 |
|
| 29 |
|
| 30 |
@pytest.fixture(autouse=True)
|
|
@@ -436,7 +436,7 @@ def test_model_id_mismatch_returns_inconsistency_error():
|
|
| 436 |
generation_config=types.SimpleNamespace(eos_token_id=0)
|
| 437 |
)
|
| 438 |
app._tokenizer = object()
|
| 439 |
-
app._current_model_id =
|
| 440 |
|
| 441 |
try:
|
| 442 |
result = app.generate_response(
|
|
@@ -493,7 +493,7 @@ def test_off_topic_message_returns_fallback(monkeypatch):
|
|
| 493 |
result = app.generate_response("me conte uma piada", [], "", None, None)
|
| 494 |
|
| 495 |
assert sql_output(result) == ""
|
| 496 |
-
assert
|
| 497 |
|
| 498 |
|
| 499 |
def test_greeting_returns_fallback(monkeypatch):
|
|
@@ -502,6 +502,7 @@ def test_greeting_returns_fallback(monkeypatch):
|
|
| 502 |
result = app.generate_response("oi", [], "", None, None)
|
| 503 |
|
| 504 |
assert sql_output(result) == ""
|
|
|
|
| 505 |
|
| 506 |
|
| 507 |
# ---------------------------------------------------------------------------
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def status_html(result):
|
| 27 |
+
return result[6]
|
| 28 |
|
| 29 |
|
| 30 |
@pytest.fixture(autouse=True)
|
|
|
|
| 436 |
generation_config=types.SimpleNamespace(eos_token_id=0)
|
| 437 |
)
|
| 438 |
app._tokenizer = object()
|
| 439 |
+
app._current_model_id = "microsoft/Phi-3-mini-4k-instruct"
|
| 440 |
|
| 441 |
try:
|
| 442 |
result = app.generate_response(
|
|
|
|
| 493 |
result = app.generate_response("me conte uma piada", [], "", None, None)
|
| 494 |
|
| 495 |
assert sql_output(result) == ""
|
| 496 |
+
assert app.FALLBACK_RESPONSE in assistant_text(result)
|
| 497 |
|
| 498 |
|
| 499 |
def test_greeting_returns_fallback(monkeypatch):
|
|
|
|
| 502 |
result = app.generate_response("oi", [], "", None, None)
|
| 503 |
|
| 504 |
assert sql_output(result) == ""
|
| 505 |
+
assert app.FALLBACK_RESPONSE in assistant_text(result)
|
| 506 |
|
| 507 |
|
| 508 |
# ---------------------------------------------------------------------------
|
tests/test_chatbot_core.py
CHANGED
|
@@ -1,32 +1,22 @@
|
|
| 1 |
import types
|
| 2 |
|
| 3 |
import app
|
| 4 |
-
from chat_state import ConversationState
|
| 5 |
-
from intent import
|
| 6 |
-
CREATE_TABLE_CONFIRM,
|
| 7 |
-
EDIT_TABLE,
|
| 8 |
-
SCHEMA_SUGGESTION,
|
| 9 |
-
SMALLTALK,
|
| 10 |
-
SQL_QUERY,
|
| 11 |
-
classify_intent,
|
| 12 |
-
)
|
| 13 |
|
| 14 |
|
| 15 |
def test_conversation_state_roundtrip_dict():
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
)
|
| 21 |
-
state = ConversationState(active_schema="", pending_schema_suggestion=suggestion, last_intent=SCHEMA_SUGGESTION)
|
| 22 |
|
| 23 |
restored = ConversationState.from_value(state.to_dict())
|
| 24 |
|
| 25 |
-
|
| 26 |
-
assert
|
| 27 |
-
assert
|
| 28 |
-
assert pending.columns == (("id", "INTEGER"), ("nome", "TEXT"))
|
| 29 |
-
assert restored.last_intent == SCHEMA_SUGGESTION
|
| 30 |
|
| 31 |
|
| 32 |
def test_intent_smalltalk_with_active_schema_is_not_sql():
|
|
@@ -37,69 +27,66 @@ def test_intent_smalltalk_with_active_schema_is_not_sql():
|
|
| 37 |
assert result.intent == SMALLTALK
|
| 38 |
|
| 39 |
|
| 40 |
-
def
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
| 44 |
-
pending = state.with_pending_schema(
|
| 45 |
-
SchemaSuggestion(table_name="zoologico", columns=(("id", "INTEGER"), ("nome", "TEXT")))
|
| 46 |
-
)
|
| 47 |
-
confirmation = classify_intent("gera", pending)
|
| 48 |
-
|
| 49 |
-
assert suggestion.intent == SCHEMA_SUGGESTION
|
| 50 |
-
assert confirmation.intent == CREATE_TABLE_CONFIRM
|
| 51 |
|
| 52 |
|
| 53 |
-
def
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
|
|
|
|
| 59 |
assert edit.intent == EDIT_TABLE
|
| 60 |
assert query.intent == SQL_QUERY
|
| 61 |
|
| 62 |
|
| 63 |
-
def
|
| 64 |
app._model = types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0))
|
| 65 |
app._tokenizer = types.SimpleNamespace(eos_token_id=0, pad_token_id=0)
|
| 66 |
app._current_model_id = app.FINE_TUNED_MODEL_ID
|
| 67 |
|
| 68 |
def fake_generate(prompt, generation_kind):
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
if generation_kind == app.model_core.SCHEMA_GENERATION:
|
| 72 |
-
return (
|
| 73 |
-
'{"table_name":"zoologico","columns":['
|
| 74 |
-
'{"name":"id","type":"INTEGER"},'
|
| 75 |
-
'{"name":"nome","type":"TEXT"},'
|
| 76 |
-
'{"name":"cidade","type":"TEXT"},'
|
| 77 |
-
'{"name":"capacidade","type":"INTEGER"}],'
|
| 78 |
-
'"rationale":"Tabela inicial para zoologicos."}',
|
| 79 |
-
1,
|
| 80 |
-
)
|
| 81 |
return "SELECT * FROM zoologico WHERE cidade = 'Sao Paulo';", 1
|
| 82 |
|
| 83 |
monkeypatch.setattr(app, "_generate_model_text", fake_generate)
|
| 84 |
|
| 85 |
-
r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY
|
| 86 |
-
assert app.EMPTY_CHAT_OUTPUT == ""
|
| 87 |
assert r1[4] == ""
|
| 88 |
-
assert
|
| 89 |
-
|
| 90 |
-
r2 = app.generate_response(
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
assert "CREATE TABLE zoologico" in
|
| 98 |
-
assert
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import types
|
| 2 |
|
| 3 |
import app
|
| 4 |
+
from chat_state import ConversationState
|
| 5 |
+
from intent import CREATE_TABLE, EDIT_TABLE, SMALLTALK, SQL_QUERY, UNKNOWN, classify_intent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def test_conversation_state_roundtrip_dict():
|
| 9 |
+
state = ConversationState(
|
| 10 |
+
active_schema="CREATE TABLE zoologico (id INTEGER)",
|
| 11 |
+
last_intent=SQL_QUERY,
|
| 12 |
+
debug={"intent": SQL_QUERY, "confidence": 0.86, "reason": "sql_query_terms"},
|
| 13 |
)
|
|
|
|
| 14 |
|
| 15 |
restored = ConversationState.from_value(state.to_dict())
|
| 16 |
|
| 17 |
+
assert restored.active_schema == "CREATE TABLE zoologico (id INTEGER)"
|
| 18 |
+
assert restored.last_intent == SQL_QUERY
|
| 19 |
+
assert restored.debug["reason"] == "sql_query_terms"
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def test_intent_smalltalk_with_active_schema_is_not_sql():
|
|
|
|
| 27 |
assert result.intent == SMALLTALK
|
| 28 |
|
| 29 |
|
| 30 |
+
def test_schema_request_without_columns_is_unknown_not_model_schema_task():
|
| 31 |
+
result = classify_intent("preciso de uma tabela sobre zoologico", ConversationState())
|
| 32 |
|
| 33 |
+
assert result.intent == UNKNOWN
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
+
def test_intent_create_edit_and_sql_query():
|
| 37 |
+
empty_state = ConversationState()
|
| 38 |
+
schema_state = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")
|
| 39 |
|
| 40 |
+
create = classify_intent("crie tabela zoologico com id nome cidade", empty_state)
|
| 41 |
+
edit = classify_intent("troca cidade por municipio", schema_state)
|
| 42 |
+
query = classify_intent("liste zoologicos por municipio", schema_state)
|
| 43 |
|
| 44 |
+
assert create.intent == CREATE_TABLE
|
| 45 |
assert edit.intent == EDIT_TABLE
|
| 46 |
assert query.intent == SQL_QUERY
|
| 47 |
|
| 48 |
|
| 49 |
+
def test_zoologico_transcript_with_mocked_sql_model(monkeypatch):
|
| 50 |
app._model = types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0))
|
| 51 |
app._tokenizer = types.SimpleNamespace(eos_token_id=0, pad_token_id=0)
|
| 52 |
app._current_model_id = app.FINE_TUNED_MODEL_ID
|
| 53 |
|
| 54 |
def fake_generate(prompt, generation_kind):
|
| 55 |
+
assert generation_kind == app.model_core.SQL_GENERATION
|
| 56 |
+
assert "CREATE TABLE zoologico" in prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return "SELECT * FROM zoologico WHERE cidade = 'Sao Paulo';", 1
|
| 58 |
|
| 59 |
monkeypatch.setattr(app, "_generate_model_text", fake_generate)
|
| 60 |
|
| 61 |
+
r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY)
|
|
|
|
| 62 |
assert r1[4] == ""
|
| 63 |
+
assert app.FALLBACK_RESPONSE in r1[0][-1]["content"]
|
| 64 |
+
|
| 65 |
+
r2 = app.generate_response(
|
| 66 |
+
"crie tabela zoologico com id nome cidade capacidade",
|
| 67 |
+
r1[0],
|
| 68 |
+
r1[2],
|
| 69 |
+
app.FINE_TUNED_MODEL_KEY,
|
| 70 |
+
r1[7],
|
| 71 |
+
)
|
| 72 |
+
assert "CREATE TABLE zoologico" in r2[4]
|
| 73 |
+
assert "CREATE TABLE zoologico" in r2[2]
|
| 74 |
+
|
| 75 |
+
r3 = app.generate_response(
|
| 76 |
+
"troca capacidade por numero_animais",
|
| 77 |
+
r2[0],
|
| 78 |
+
r2[2],
|
| 79 |
+
app.FINE_TUNED_MODEL_KEY,
|
| 80 |
+
r2[7],
|
| 81 |
+
)
|
| 82 |
+
assert "numero_animais TEXT" in r3[4]
|
| 83 |
+
assert "capacidade" not in r3[4]
|
| 84 |
+
|
| 85 |
+
r4 = app.generate_response(
|
| 86 |
+
"liste zoologicos de Sao Paulo",
|
| 87 |
+
r3[0],
|
| 88 |
+
r3[2],
|
| 89 |
+
app.FINE_TUNED_MODEL_KEY,
|
| 90 |
+
r3[7],
|
| 91 |
+
)
|
| 92 |
+
assert "SELECT * FROM zoologico" in r4[4]
|