File size: 5,527 Bytes
47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 ad5be9b 47affa0 92923b1 ad5be9b 47affa0 ad5be9b 737eaac 47affa0 ad5be9b 47affa0 ad5be9b d88f966 ad5be9b d88f966 ad5be9b d88f966 ad5be9b d88f966 ad5be9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | import types
import app
from chat_state import ConversationState
from intent import CREATE_TABLE, EDIT_TABLE, SMALLTALK, SQL_QUERY, UNKNOWN, classify_intent
def test_conversation_state_roundtrip_dict():
    """A ConversationState keeps its fields through to_dict() -> from_value()."""
    schema_sql = "CREATE TABLE zoologico (id INTEGER)"
    debug_info = {"intent": SQL_QUERY, "confidence": 0.86, "reason": "sql_query_terms"}
    original = ConversationState(
        active_schema=schema_sql,
        last_intent=SQL_QUERY,
        debug=debug_info,
    )

    serialized = original.to_dict()
    restored = ConversationState.from_value(serialized)

    assert restored.active_schema == schema_sql
    assert restored.last_intent == SQL_QUERY
    assert restored.debug["reason"] == "sql_query_terms"
def test_intent_smalltalk_with_active_schema_is_not_sql():
    """Smalltalk is classified as SMALLTALK even while a schema is active."""
    with_schema = ConversationState(active_schema="CREATE TABLE employees (id INTEGER)")
    outcome = classify_intent("como voce esta hoje?", with_schema)
    assert outcome.intent == SMALLTALK
def test_schema_request_without_columns_is_unknown_not_model_schema_task():
    """Asking for a table without naming any columns is classified as UNKNOWN."""
    outcome = classify_intent("preciso de uma tabela sobre zoologico", ConversationState())
    assert outcome.intent == UNKNOWN
def test_intent_create_edit_and_sql_query():
    """Create, edit and query requests each map to their own intent."""
    no_schema = ConversationState()
    zoo_schema = ConversationState(active_schema="CREATE TABLE zoologico (id INTEGER, cidade TEXT)")

    # (message, conversation state, expected intent)
    cases = (
        ("crie tabela zoologico com id nome cidade", no_schema, CREATE_TABLE),
        ("troca cidade por municipio", zoo_schema, EDIT_TABLE),
        ("liste zoologicos por municipio", zoo_schema, SQL_QUERY),
    )
    for message, state, expected in cases:
        assert classify_intent(message, state).intent == expected, message
def test_destructive_data_mutation_is_unknown_with_active_schema():
    """Delete/update/insert/drop prompts against an active schema are UNKNOWN."""
    plural_state = ConversationState(
        active_schema="CREATE TABLE animals (id INTEGER, name TEXT, species TEXT)"
    )
    plural_prompts = (
        "delete all animals",
        "DELETE FROM animals",
        "UPDATE animals SET name = 'x'",
        "INSERT INTO animals VALUES (1)",
        "insert animal",
        "insert row into animals",
        "add row to animals",
        "add animal record",
        "drop animals table",
        "drop the animals table",
        "delete animals",
        "drop animals",
    )
    for prompt in plural_prompts:
        assert classify_intent(prompt, plural_state).intent == UNKNOWN, prompt

    # Plural wording must still be caught when the table name is singular.
    singular_state = ConversationState(active_schema="CREATE TABLE animal (id INTEGER, species TEXT)")
    for prompt in ("delete animals", "drop animals"):
        assert classify_intent(prompt, singular_state).intent == UNKNOWN, prompt
def test_data_mutation_uses_schema_from_history_when_active_schema_empty():
    """Mutation prompts are UNKNOWN when the schema exists only in chat history."""
    assistant_turn = {
        "role": "assistant",
        "content": "```sql\nCREATE TABLE animals (id INTEGER, species TEXT);\n```",
    }
    history = [assistant_turn]

    # A fresh, empty state per call so only the history supplies the schema.
    for prompt in ("delete animals", "drop animals"):
        assert classify_intent(prompt, ConversationState(), history).intent == UNKNOWN, prompt
def test_delete_existing_column_stays_schema_edit():
    """Dropping an existing column or adding a new one is EDIT_TABLE, not a data mutation."""
    animals = ConversationState(active_schema="CREATE TABLE animals (id INTEGER, species TEXT)")
    column_edits = ("delete species", "drop coluna id", "add habitat")
    for prompt in column_edits:
        assert classify_intent(prompt, animals).intent == EDIT_TABLE, prompt
def test_create_tables_named_rows_or_records_are_not_data_mutation():
    """Creating tables named ``rows`` or ``records`` is classified as CREATE_TABLE."""
    for message in ("create table rows with id name", "create table records with id name"):
        assert classify_intent(message, ConversationState()).intent == CREATE_TABLE, message
def test_zoologico_transcript_with_mocked_sql_model(monkeypatch):
    """Run a four-turn zoologico conversation through ``app.generate_response``.

    The loaded model and tokenizer are replaced with lightweight stand-ins, and
    ``app._generate_model_text`` is replaced with a fake. The turns are:
    greeting -> fallback, create table, rename a column, SQL query. Every
    patch goes through ``monkeypatch``, so the module-level model state in
    ``app`` is restored after the test instead of leaking into other tests.
    """
    # raising=False keeps the semantics of the previous plain assignment: the
    # attribute is set even if ``app`` does not pre-declare it. monkeypatch
    # undoes it on teardown.
    monkeypatch.setattr(
        app,
        "_model",
        types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0)),
        raising=False,
    )
    monkeypatch.setattr(
        app,
        "_tokenizer",
        types.SimpleNamespace(eos_token_id=0, pad_token_id=0),
        raising=False,
    )
    monkeypatch.setattr(app, "_current_model_id", app.FINE_TUNED_MODEL_ID, raising=False)

    sql_prompts = []

    def fake_generate(prompt, generation_kind):
        # Any model call other than SQL generation fails the test.
        assert generation_kind == app.model_core.SQL_GENERATION
        assert "CREATE TABLE zoologico" in prompt
        sql_prompts.append(prompt)
        return "SELECT * FROM zoologico WHERE city = 'Sao Paulo';", 1

    monkeypatch.setattr(app, "_generate_model_text", fake_generate)

    # Positions in the generate_response result tuple, as this test uses them:
    #   0 = chat history (list of message dicts),
    #   2 = schema text passed back in on the next turn,
    #   4 = generated code for this turn,
    #   7 = state passed back in on the next turn (presumably the serialized
    #       ConversationState -- confirm against app.generate_response).
    history_idx, schema_idx, code_idx, state_idx = 0, 2, 4, 7

    def next_turn(message, previous):
        """Send ``message`` as the turn following the result ``previous``."""
        return app.generate_response(
            message,
            previous[history_idx],
            previous[schema_idx],
            app.FINE_TUNED_MODEL_KEY,
            previous[state_idx],
        )

    # Turn 1: a greeting yields the fallback reply and no code.
    r1 = app.generate_response("oi", [], "", app.FINE_TUNED_MODEL_KEY)
    assert r1[code_idx] == ""
    assert app.FALLBACK_RESPONSE in r1[history_idx][-1]["content"]

    # Turn 2: create the table; it shows up as both the code and the schema.
    r2 = next_turn("create table zoologico with id name city capacity", r1)
    assert "CREATE TABLE zoologico" in r2[code_idx]
    assert "CREATE TABLE zoologico" in r2[schema_idx]

    # Turn 3: rename a column. The old column name must be gone, whether it
    # was kept in English or stored with its Portuguese spelling.
    r3 = next_turn("change capacity to animal_count", r2)
    assert "animal_count TEXT" in r3[code_idx]
    assert "capacity" not in r3[code_idx]
    assert "capacidade" not in r3[code_idx]

    # Turn 4: a data question is answered by the (fake) SQL model.
    r4 = next_turn("list zoologicos from Sao Paulo", r3)
    assert "SELECT * FROM zoologico" in r4[code_idx]
    assert sql_prompts, "expected the SQL query turn to call the model"
|