File size: 1,668 Bytes
"""Tests for the EM and EX metrics."""
import sqlite3
from pathlib import Path
import pytest
from src.evaluation.metrics import exact_match, execution_accuracy
def test_exact_match_simple():
    """Expect exact_match to accept two queries that differ only in letter case."""
    upper_query = "SELECT * FROM t"
    lower_query = "select * from t"
    matched = exact_match(upper_query, lower_query)
    assert matched
def test_exact_match_whitespace():
    """Expect exact_match to accept the whitespace test pair."""
    # NOTE(review): both literals are identical as they appear here; if the
    # test was meant to compare queries with differing whitespace, the extra
    # spaces may have been collapsed in transit — confirm against the
    # repository copy.
    first_query = "SELECT * FROM t"
    second_query = "SELECT * FROM t"
    matched = exact_match(first_query, second_query)
    assert matched
def test_exact_match_negative():
    """Expect exact_match to reject queries that select different columns."""
    column_a_query, column_b_query = "SELECT a FROM t", "SELECT b FROM t"
    matched = exact_match(column_a_query, column_b_query)
    assert not matched
@pytest.fixture
def tmp_sqlite(tmp_path: Path) -> Path:
    """Create a small SQLite database file for the execution-accuracy tests.

    The database holds a single table, ``users (id INT, name TEXT)``,
    populated with the rows ``(1, "a")`` and ``(2, "b")``.

    Args:
        tmp_path: Directory in which ``tiny.sqlite`` is created.

    Returns:
        Path to the populated database file.
    """
    db = tmp_path / "tiny.sqlite"
    conn = sqlite3.connect(db)
    try:
        conn.execute("CREATE TABLE users (id INT, name TEXT)")
        conn.executemany(
            "INSERT INTO users VALUES (?, ?)",
            [(1, "a"), (2, "b")],
        )
        conn.commit()
    finally:
        # Close even when setup fails so the connection is never leaked.
        conn.close()
    return db
def test_execution_accuracy_match(tmp_sqlite: Path):
    """Expect a prediction identical to the gold query to count as a match."""
    predicted_sql = "SELECT id FROM users ORDER BY id"
    gold_sql = "SELECT id FROM users ORDER BY id"
    is_match = execution_accuracy(predicted_sql, gold_sql, tmp_sqlite)
    assert is_match
def test_execution_accuracy_set_equal(tmp_sqlite: Path):
    """Expect opposite sort orders over the same rows to count as a match."""
    descending_pred = "SELECT id FROM users ORDER BY id DESC"
    ascending_gold = "SELECT id FROM users ORDER BY id ASC"
    # The test expects row order to be ignored: compared as sets, the two
    # results are equal.
    is_match = execution_accuracy(descending_pred, ascending_gold, tmp_sqlite)
    assert is_match
def test_execution_accuracy_mismatch(tmp_sqlite: Path):
    """Expect queries that filter on different ids to count as a mismatch."""
    predicted_sql = "SELECT id FROM users WHERE id = 1"
    gold_sql = "SELECT id FROM users WHERE id = 2"
    is_match = execution_accuracy(predicted_sql, gold_sql, tmp_sqlite)
    assert not is_match
def test_execution_accuracy_invalid_pred(tmp_sqlite: Path):
    """Expect a malformed predicted query to be scored as a non-match."""
    malformed_sql = "SELEC bad sql"
    gold_sql = "SELECT id FROM users"
    is_match = execution_accuracy(malformed_sql, gold_sql, tmp_sqlite)
    assert not is_match
|