File size: 2,754 Bytes
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Метрики Text-to-SQL: Exact Match и Execution Accuracy."""

from __future__ import annotations

import sqlite3
from pathlib import Path

from src.models.postprocess import normalize_sql


def exact_match(predicted: str, gold: str, dialect: str = "sqlite") -> bool:
    """Сравнение нормализованных SQL посимвольно. Грубая, но честная метрика."""
    return normalize_sql(predicted, dialect) == normalize_sql(gold, dialect)


def execution_accuracy(
    predicted_sql: str,
    gold_sql: str,
    db_path: Path | str,
    timeout_seconds: float = 5.0,
) -> bool:
    """Прогон обоих SQL на SQLite. True если результаты совпадают как множества."""
    db_path = Path(db_path)
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=timeout_seconds)
    try:
        conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
        try:
            pred_rows = _run(conn, predicted_sql)
        except sqlite3.Error:
            return False
        try:
            gold_rows = _run(conn, gold_sql)
        except sqlite3.Error:
            return False
        return _rows_equal(pred_rows, gold_rows)
    finally:
        conn.close()


def _run(conn: sqlite3.Connection, sql: str) -> list[tuple]:
    cur = conn.cursor()
    cur.execute(sql)
    return cur.fetchall()


def _rows_equal(a: list[tuple], b: list[tuple]) -> bool:
    """Сравнение как мультимножеств — порядок не важен (если в SQL нет ORDER BY)."""
    if len(a) != len(b):
        return False
    return sorted(map(_row_key, a)) == sorted(map(_row_key, b))


def _row_key(row: tuple) -> tuple:
    return tuple(str(x) for x in row)


def compute_metrics(
    predictions: list[str],
    golds: list[str],
    db_ids: list[str],
    databases_dir: Path | str,
) -> dict:
    """Прогон по всему датасету. Возвращает dict с EM, EX, и счётчиками."""
    databases_dir = Path(databases_dir)
    n = len(predictions)
    assert n == len(golds) == len(db_ids), "Mismatched lengths"

    em_count = 0
    ex_count = 0
    parse_fail = 0

    for pred, gold, db_id in zip(predictions, golds, db_ids):
        if exact_match(pred, gold):
            em_count += 1

        db_path = databases_dir / db_id / f"{db_id}.sqlite"
        if not db_path.exists():
            parse_fail += 1
            continue

        if execution_accuracy(pred, gold, db_path):
            ex_count += 1

    return {
        "n": n,
        "exact_match": em_count / n if n else 0.0,
        "execution_accuracy": ex_count / n if n else 0.0,
        "parse_fail": parse_fail,
    }