File size: 2,209 Bytes
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
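"""Run the model on a PAUQ split (train/dev/test; dev by default) and report metrics.

Predictions are written as JSONL to the path given by ``--output``.

Usage:
    python -m src.evaluation.evaluate --split dev --limit 50
"""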

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

from tqdm import tqdm

from src.config import settings
from src.data.loader import load_pauq_split
from src.data.schema import SchemaRetriever
from src.evaluation.metrics import compute_metrics
from src.models.inference import InferenceEngine


def main() -> None:
    """Run inference on a PAUQ split, save predictions and print metrics.

    Command-line options:
      --split   name of the split file ``<split>.json`` inside
                ``settings.pauq_data_dir``; one of train/dev/test (default dev).
      --limit   if given, only the first N examples are evaluated
                (0 evaluates nothing; negative values are rejected).
      --output  JSONL file receiving one record per evaluated example with
                keys ``db_id``, ``question``, ``gold``, ``pred``, ``raw``.

    Examples for which ``SchemaRetriever.render_schema`` raises
    ``FileNotFoundError`` are skipped. How many were skipped (and for which
    ``db_id`` values) is reported on stderr, so the metrics are not silently
    computed on a subset. The metrics dict returned by ``compute_metrics``
    is printed to stdout as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="dev", choices=["train", "dev", "test"])
    parser.add_argument("--limit", type=int, default=None, help="Ограничить число примеров")
    parser.add_argument("--output", type=Path, default=Path("results/predictions.jsonl"))
    args = parser.parse_args()
    # A negative limit would slice from the END of the list (examples[:-k]),
    # silently dropping the last k examples -- never what the user intends.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    split_path = settings.pauq_data_dir / f"{args.split}.json"
    examples = load_pauq_split(split_path)
    # Compare against None so that `--limit 0` is honoured instead of being
    # treated as falsy (which previously disabled the limit entirely).
    if args.limit is not None:
        examples = examples[: args.limit]

    schema_ret = SchemaRetriever(settings.databases_dir)
    engine = InferenceEngine()
    engine.load()

    predictions: list[str] = []
    golds: list[str] = []
    db_ids: list[str] = []
    skipped_db_ids: list[str] = []

    args.output.parent.mkdir(parents=True, exist_ok=True)
    # Stream each record to disk as soon as it is produced: if inference fails
    # or is interrupted part-way, the `with` block still closes (and flushes)
    # the file, so the predictions computed so far are preserved.
    with args.output.open("w", encoding="utf-8") as f:
        for ex in tqdm(examples, desc="Inference"):
            try:
                schema = schema_ret.render_schema(ex.db_id)
            except FileNotFoundError:
                skipped_db_ids.append(ex.db_id)
                continue
            result = engine.generate(schema, ex.question)
            predictions.append(result.sql)
            golds.append(ex.query)
            db_ids.append(ex.db_id)
            record = {
                "db_id": ex.db_id,
                "question": ex.question,
                "gold": ex.query,
                "pred": result.sql,
                "raw": result.raw_output,
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    if skipped_db_ids:
        # stderr keeps stdout reserved for the metrics JSON below.
        missing = ", ".join(sorted(set(skipped_db_ids)))
        print(
            f"Warning: skipped {len(skipped_db_ids)} of {len(examples)} examples "
            f"(render_schema raised FileNotFoundError for db_id: {missing})",
            file=sys.stderr,
        )

    metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir)
    print(json.dumps(metrics, indent=2, ensure_ascii=False))

if __name__ == "__main__":
    # Script entry point (e.g. `python -m src.evaluation.evaluate`).
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())