File size: 5,527 Bytes
47affa0
 
 
ad5be9b
 
47affa0
 
 
ad5be9b
 
 
 
47affa0
 
 
 
ad5be9b
 
 
47affa0
 
 
 
 
 
 
 
 
 
ad5be9b
 
47affa0
ad5be9b
47affa0
 
ad5be9b
 
 
47affa0
ad5be9b
 
 
47affa0
ad5be9b
47affa0
 
 
 
92923b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5be9b
47affa0
 
 
 
 
ad5be9b
 
737eaac
47affa0
 
 
ad5be9b
47affa0
ad5be9b
 
 
d88f966
ad5be9b
 
 
 
 
 
 
 
 
d88f966
ad5be9b
 
 
 
 
d88f966
ad5be9b
 
 
d88f966
ad5be9b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import types

import app
from chat_state import ConversationState
from intent import CREATE_TABLE, EDIT_TABLE, SMALLTALK, SQL_QUERY, UNKNOWN, classify_intent


def test_conversation_state_roundtrip_dict():
    """A ConversationState survives a ``to_dict()`` -> ``from_value()`` round trip."""
    schema = "CREATE TABLE zoologico (id INTEGER)"
    original = ConversationState(
        active_schema=schema,
        last_intent=SQL_QUERY,
        debug={"intent": SQL_QUERY, "confidence": 0.86, "reason": "sql_query_terms"},
    )

    serialized = original.to_dict()
    restored = ConversationState.from_value(serialized)

    assert restored.active_schema == schema
    assert restored.last_intent == SQL_QUERY
    assert restored.debug["reason"] == "sql_query_terms"


def test_intent_smalltalk_with_active_schema_is_not_sql():
    """Small talk is classified SMALLTALK even while a schema is active."""
    schema_state = ConversationState(active_schema="CREATE TABLE employees (id INTEGER)")

    assert classify_intent("como voce esta hoje?", schema_state).intent == SMALLTALK


def test_schema_request_without_columns_is_unknown_not_model_schema_task():
    """A table request that names no columns is classified UNKNOWN."""
    message = "preciso de uma tabela sobre zoologico"
    classification = classify_intent(message, ConversationState())

    assert classification.intent == UNKNOWN


def test_intent_create_edit_and_sql_query():
    """Create, edit and query messages each map to their own intent."""
    no_schema = ConversationState()
    with_schema = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")

    # (message, state it is classified against, expected intent), in call order.
    cases = [
        ("crie tabela zoologico com id nome cidade", no_schema, CREATE_TABLE),
        ("troca cidade por municipio", with_schema, EDIT_TABLE),
        ("liste zoologicos por municipio", with_schema, SQL_QUERY),
    ]

    # Classify everything first, then inspect, exactly as the original did.
    results = [classify_intent(text, state) for text, state, _ in cases]

    assert [result.intent for result in results] == [expected for _, _, expected in cases]


def test_destructive_data_mutation_is_unknown_with_active_schema():
    """Row- and table-level mutation requests are classified UNKNOWN when a schema is active."""
    plural_state = ConversationState(active_schema="CREATE TABLE animals (id INTEGER, name TEXT, species TEXT)")
    mutation_messages = (
        "delete all animals",
        "DELETE FROM animals",
        "UPDATE animals SET name = 'x'",
        "INSERT INTO animals VALUES (1)",
        "insert animal",
        "insert row into animals",
        "add row to animals",
        "add animal record",
        "drop animals table",
        "drop the animals table",
        "delete animals",
        "drop animals",
    )
    for message in mutation_messages:
        assert classify_intent(message, plural_state).intent == UNKNOWN, message

    # The table is named in the singular while the messages use the plural.
    singular_state = ConversationState(active_schema="CREATE TABLE animal (id INTEGER, species TEXT)")
    for message in ("delete animals", "drop animals"):
        assert classify_intent(message, singular_state).intent == UNKNOWN, message


def test_data_mutation_uses_schema_from_history_when_active_schema_empty():
    """Mutation requests are UNKNOWN when the schema only appears in the chat history."""
    assistant_reply = "```sql\nCREATE TABLE animals (id INTEGER, species TEXT);\n```"
    history = [{"role": "assistant", "content": assistant_reply}]

    for message in ("delete animals", "drop animals"):
        # A fresh, empty state each time so only the history carries the schema.
        assert classify_intent(message, ConversationState(), history).intent == UNKNOWN


def test_delete_existing_column_stays_schema_edit():
    """Column-level delete/drop/add requests against the active schema are EDIT_TABLE."""
    state = ConversationState(active_schema="CREATE TABLE animals (id INTEGER, species TEXT)")

    for message in ("delete species", "drop coluna id", "add habitat"):
        assert classify_intent(message, state).intent == EDIT_TABLE, message


def test_create_tables_named_rows_or_records_are_not_data_mutation():
    """Creating tables literally named "rows" or "records" is still CREATE_TABLE."""
    for message in ("create table rows with id name", "create table records with id name"):
        assert classify_intent(message, ConversationState()).intent == CREATE_TABLE


def test_zoologico_transcript_with_mocked_sql_model(monkeypatch):
    """Drive a four-turn chat (greeting, create, rename column, query) with a stubbed model.

    Any model call must be an SQL generation: ``fake_generate`` asserts this and
    returns a canned query, so no real model is loaded. The module globals
    ``app._model``, ``app._tokenizer`` and ``app._current_model_id`` are patched
    through ``monkeypatch`` so they are restored after the test instead of
    leaking fake objects into later tests.

    NOTE(review): the response tuple indices are read from how turns are chained
    below. r[0] is the chat history (a list of message dicts). r[2] is passed
    back as the third argument. r[4] is the SQL/schema output. r[7] is passed
    back as the fifth argument (presumably serialized conversation state).
    Confirm against ``app.generate_response``.
    """
    # raising=False mirrors the original plain assignment, which also worked if
    # app did not predefine these attributes; monkeypatch still undoes it.
    monkeypatch.setattr(
        app,
        "_model",
        types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0)),
        raising=False,
    )
    monkeypatch.setattr(
        app,
        "_tokenizer",
        types.SimpleNamespace(eos_token_id=0, pad_token_id=0),
        raising=False,
    )
    monkeypatch.setattr(app, "_current_model_id", app.FINE_TUNED_MODEL_ID, raising=False)

    def fake_generate(prompt, generation_kind):
        # Only SQL generation may reach the model, and only once a schema exists.
        assert generation_kind == app.model_core.SQL_GENERATION
        assert "CREATE TABLE zoologico" in prompt
        return "SELECT * FROM zoologico WHERE city = 'Sao Paulo';", 1

    monkeypatch.setattr(app, "_generate_model_text", fake_generate)

    # Turn 1: greeting -> fallback reply, no SQL output.
    r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY)
    assert r1[4] == ""
    assert app.FALLBACK_RESPONSE in r1[0][-1]["content"]

    # Turn 2: create the table; the schema shows up in the output and in r[2].
    r2 = app.generate_response(
        "create table zoologico with id name city capacity",
        r1[0],
        r1[2],
        app.FINE_TUNED_MODEL_KEY,
        r1[7],
    )
    assert "CREATE TABLE zoologico" in r2[4]
    assert "CREATE TABLE zoologico" in r2[2]

    # Turn 3: rename a column; the new name appears and the old column is gone.
    r3 = app.generate_response(
        "change capacity to animal_count",
        r2[0],
        r2[2],
        app.FINE_TUNED_MODEL_KEY,
        r2[7],
    )
    assert "animal_count TEXT" in r3[4]
    # The renamed column was "capacity" (typed TEXT like its replacement), so
    # check that it is gone; the older "capacidade" check could never fail here.
    assert "capacity TEXT" not in r3[4]
    assert "capacidade" not in r3[4]

    # Turn 4: a data question goes through the (stubbed) SQL model.
    r4 = app.generate_response(
        "list zoologicos from Sao Paulo",
        r3[0],
        r3[2],
        app.FINE_TUNED_MODEL_KEY,
        r3[7],
    )
    assert "SELECT * FROM zoologico" in r4[4]