# Source: commit 47affa0 by Shizu0n — "refactor: split chat flow from SQL routing"
import json
import sys
from pathlib import Path
# Presumed project root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]

# Prepend the root to the import path (only once) so the top-level ``app``
# module can be imported when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import app # noqa: E402
def _assistant_text(result):
history = result[0] or []
return history[-1]["content"] if history else ""
def _scenario(name, message, history, active_schema, state):
    """Send one user message through ``app.generate_response`` and summarize it.

    Args:
        name: label for this turn, copied into the record for grading.
        message: user text for this turn.
        history: chat history returned by the previous turn.
        active_schema: active schema returned by the previous turn.
        state: app state returned by the previous turn.

    Returns:
        A dict with the turn's ``name`` and ``message`` plus selected
        positional outputs of ``generate_response``: the last history
        message's content (``"assistant"``), ``"sql"``, ``"status"``,
        ``"active_schema"``, ``"state"`` and the updated ``"history"``.
    """
    outputs = app.generate_response(
        message,
        history,
        active_schema,
        app.FINE_TUNED_MODEL_KEY,
        # NOTE(review): this positional argument is always None here; its
        # meaning is defined by app.generate_response — confirm there.
        None,
        state,
    )
    # Positions below are the ones this harness consumes; the tuple layout
    # itself is owned by app.generate_response.
    new_history = outputs[0]
    new_schema = outputs[2]
    sql = outputs[4]
    status = outputs[7]
    new_state = outputs[8]
    return {
        "name": name,
        "message": message,
        "assistant": _assistant_text(outputs),
        "sql": sql,
        "status": status,
        "active_schema": new_schema,
        "state": new_state,
        "history": new_history,
    }
def _contains_any(text, needles):
text = (text or "").lower()
return any(needle.lower() in text for needle in needles)
def _grade(records):
    """Score the scripted conversation records against routing expectations.

    Args:
        records: per-turn dicts as built by ``main`` (``_scenario`` output
            without ``"history"``). Must contain one record for each of the
            scenario names ``greeting``, ``schema_request``,
            ``confirm_generate``, ``edit_schema``, ``query_schema`` and
            ``smalltalk_with_schema``.

    Returns:
        A list of ``{"name", "pass", "detail"}`` dicts, one per check, in a
        fixed order.

    Raises:
        KeyError: if an expected scenario record (or record field) is missing.
    """
    by_name = {record["name"]: record for record in records}
    greeting = by_name["greeting"]
    schema_request = by_name["schema_request"]
    confirm = by_name["confirm_generate"]
    edit = by_name["edit_schema"]
    query = by_name["query_schema"]
    smalltalk = by_name["smalltalk_with_schema"]

    def chat_only(record):
        # A chat turn must yield some assistant text and no SQL at all.
        return bool(record["assistant"]) and not record["sql"]

    def sql_upper(record):
        # SQL may be None/empty on turns that produced none.
        return (record["sql"] or "").upper()

    expectations = [
        (
            "smalltalk_is_conversational",
            chat_only(greeting),
            "Greeting should produce chat text and no SQL.",
        ),
        (
            "schema_suggestion_sets_pending",
            bool((schema_request["state"] or {}).get("pending_schema_suggestion")),
            "Domain table request should create a pending schema proposal.",
        ),
        (
            "confirmation_generates_create_table",
            "CREATE TABLE" in sql_upper(confirm),
            "Confirmation should generate CREATE TABLE SQL.",
        ),
        (
            "edit_updates_schema",
            _contains_any(edit["sql"], ["numero_animais", "num_animais"]),
            "Edit should replace capacidade with an animal-count column.",
        ),
        (
            "query_generates_select",
            "SELECT" in sql_upper(query),
            "Natural query should generate SELECT SQL.",
        ),
        (
            "smalltalk_with_schema_stays_chat",
            chat_only(smalltalk),
            "Smalltalk with active schema should not become SQL.",
        ),
    ]
    return [
        {"name": check_name, "pass": passed, "detail": detail}
        for check_name, passed, detail in expectations
    ]
def main():
    """Replay a fixed Portuguese conversation through ``app`` and grade it.

    Loads the fine-tuned model and sends each scripted message in order. Each
    turn receives the history, active schema and state returned by the
    previous turn. The collected records are graded with ``_grade`` and the
    report is printed as JSON to stdout.

    Returns:
        Process exit code: 0 if every check passed, 1 otherwise.
    """
    import copy  # only this entry point needs it

    app.load_model(app.FINE_TUNED_MODEL_ID)

    # (scenario name, user message) pairs, replayed in order. Later turns
    # rely on the schema/state produced by earlier ones.
    scenarios = (
        ("greeting", "oi"),
        ("schema_request", "preciso de uma tabela sobre zoologico"),
        ("confirm_generate", "gera"),
        ("edit_schema", "troca capacidade por numero_animais"),
        ("query_schema", "liste zoologicos de Sao Paulo"),
        ("smalltalk_with_schema", "como voce esta hoje?"),
    )

    history = []
    active_schema = ""
    state = app.chat_core.default_state()
    records = []
    for name, message in scenarios:
        record = _scenario(name, message, history, active_schema, state)
        # Thread the conversation forward. The full history is not reported.
        history = record.pop("history")
        active_schema = record["active_schema"]
        state = record["state"]
        # Snapshot this turn's state. The live object is fed into the next
        # turn, so if app mutates it in place an un-copied record would show
        # (and be graded on) a later turn's state. The report is
        # JSON-serialized below, so state is expected to be plain data.
        record["state"] = copy.deepcopy(state)
        records.append(record)

    checks = _grade(records)
    all_passed = all(check["pass"] for check in checks)
    report = {
        "model": app.FINE_TUNED_MODEL_ID,
        "passed": all_passed,
        "checks": checks,
        "records": records,
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))
    return 0 if all_passed else 1
if __name__ == "__main__":
    # Use main()'s return value (0 = all checks passed, 1 = failure) as the
    # process exit status.
    sys.exit(main())