File size: 3,634 Bytes
47affa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import sys
from pathlib import Path

# Put the directory one level above this file's folder at the front of
# sys.path (unless it is already listed) before the `app` import that follows.
# NOTE(review): presumably that directory is where app.py lives — confirm.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import app  # noqa: E402


def _assistant_text(result):
    history = result[0] or []
    return history[-1]["content"] if history else ""


def _scenario(name, message, history, active_schema, state):
    """Run one chat turn through ``app.generate_response`` and summarize it.

    Args:
        name: Label stored in the returned record.
        message: User message passed to the app.
        history: Chat history passed to the app.
        active_schema: Current schema value passed to the app.
        state: Conversation state passed to the app.

    Returns:
        A dict with the keys ``name``, ``message``, ``assistant``, ``sql``,
        ``status``, ``active_schema``, ``state`` and ``history``; the last
        six are read from positions 0, 4, 7, 2, 8 and 0 of the value returned
        by ``app.generate_response``.
    """
    call_args = (
        message,
        history,
        active_schema,
        app.FINE_TUNED_MODEL_KEY,
        None,
        state,
    )
    outputs = app.generate_response(*call_args)

    # Keys are inserted in a fixed order so the serialized report is stable.
    summary = {"name": name, "message": message}
    summary["assistant"] = _assistant_text(outputs)
    summary["sql"] = outputs[4]
    summary["status"] = outputs[7]
    summary["active_schema"] = outputs[2]
    summary["state"] = outputs[8]
    summary["history"] = outputs[0]
    return summary


def _contains_any(text, needles):
    text = (text or "").lower()
    return any(needle.lower() in text for needle in needles)


def _grade(records):
    """Evaluate the recorded scenario outputs against six fixed checks.

    Args:
        records: Iterable of scenario dicts. Each is indexed by its
            ``"name"`` value; the checks read the ``"assistant"``, ``"sql"``
            and ``"state"`` keys of specific named records.

    Returns:
        A list of dicts with the keys ``name``, ``pass`` and ``detail``, one
        per check, always in the same order.

    Raises:
        KeyError: If an expected scenario name or record key is missing.
    """
    lookup = {}
    for entry in records:
        lookup[entry["name"]] = entry

    outcomes = []

    def add(check_name, passed, detail):
        # Collect one check result in the report's dict shape.
        outcomes.append({"name": check_name, "pass": passed, "detail": detail})

    def chat_only(entry):
        # True when there is assistant text and the SQL field is empty/None.
        return bool(entry["assistant"]) and not entry["sql"]

    add(
        "smalltalk_is_conversational",
        chat_only(lookup["greeting"]),
        "Greeting should produce chat text and no SQL.",
    )

    # A missing/None state counts as "nothing pending".
    pending_state = lookup["schema_request"]["state"] or {}
    add(
        "schema_suggestion_sets_pending",
        bool(pending_state.get("pending_schema_suggestion")),
        "Domain table request should create a pending schema proposal.",
    )

    created_sql = lookup["confirm_generate"]["sql"] or ""
    add(
        "confirmation_generates_create_table",
        "CREATE TABLE" in created_sql.upper(),
        "Confirmation should generate CREATE TABLE SQL.",
    )

    add(
        "edit_updates_schema",
        _contains_any(lookup["edit_schema"]["sql"], ["numero_animais", "num_animais"]),
        "Edit should replace capacidade with an animal-count column.",
    )

    query_sql = lookup["query_schema"]["sql"] or ""
    add(
        "query_generates_select",
        "SELECT" in query_sql.upper(),
        "Natural query should generate SELECT SQL.",
    )

    add(
        "smalltalk_with_schema_stays_chat",
        chat_only(lookup["smalltalk_with_schema"]),
        "Smalltalk with active schema should not become SQL.",
    )
    return outcomes


def main():
    """Run the scripted conversation, grade it, and print a JSON report.

    Loads the model identified by ``app.FINE_TUNED_MODEL_ID``, feeds six
    fixed messages through ``_scenario`` in order (carrying history, active
    schema and state from one turn to the next), grades the collected
    records with ``_grade`` and prints the report to stdout.

    Returns:
        ``0`` when every check passed, ``1`` otherwise.
    """
    app.load_model(app.FINE_TUNED_MODEL_ID)

    script = (
        ("greeting", "oi"),
        ("schema_request", "preciso de uma tabela sobre zoologico"),
        ("confirm_generate", "gera"),
        ("edit_schema", "troca capacidade por numero_animais"),
        ("query_schema", "liste zoologicos de Sao Paulo"),
        ("smalltalk_with_schema", "como voce esta hoje?"),
    )

    history = []
    active_schema = ""
    state = app.chat_core.default_state()
    records = []

    for step_name, user_message in script:
        outcome = _scenario(step_name, user_message, history, active_schema, state)
        # The history feeds the next turn but is left out of the report.
        history = outcome.pop("history")
        records.append(outcome)
        active_schema = outcome["active_schema"]
        state = outcome["state"]

    checks = _grade(records)
    report = {
        "model": app.FINE_TUNED_MODEL_ID,
        "passed": all(item["pass"] for item in checks),
        "checks": checks,
        "records": records,
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))

    if report["passed"]:
        return 0
    return 1


# When run as a script, exit with main()'s return value as the status code
# (0 when every check passed, 1 otherwise).
if __name__ == "__main__":
    raise SystemExit(main())