"""Скрипт прогона модели на выбранном сплите PAUQ (train/dev/test) с подсчётом метрик.

Использование:
    python -m src.evaluation.evaluate --split dev --limit 50
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

from tqdm import tqdm

from src.config import settings
from src.data.loader import load_pauq_split
from src.data.schema import SchemaRetriever
from src.evaluation.metrics import compute_metrics
from src.models.inference import InferenceEngine
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the evaluation script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="dev", choices=["train", "dev", "test"])
    parser.add_argument(
        "--limit", type=int, default=None, help="Ограничить число примеров"
    )
    parser.add_argument(
        "--output", type=Path, default=Path("results/predictions.jsonl")
    )
    args = parser.parse_args(argv)
    # A negative value would turn ``examples[:limit]`` into "drop the last
    # |limit| examples" instead of capping the count, so reject it up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")
    return args


def main(argv: list[str] | None = None) -> None:
    """Run the inference engine over a PAUQ split and print evaluation metrics.

    For every example whose schema can be rendered, the model generates SQL.
    One JSON Lines record per example (db_id, question, gold, pred, raw) is
    streamed to ``--output``. The dict returned by ``compute_metrics`` is
    printed to stdout as JSON. Examples whose schema rendering raises
    ``FileNotFoundError`` are skipped and reported on stderr.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, preserving the
            original CLI behaviour.
    """
    args = _parse_args(argv)

    split_path = settings.pauq_data_dir / f"{args.split}.json"
    examples = load_pauq_split(split_path)
    # ``--limit 0`` keeps every example, the same as omitting the flag; this
    # preserves the script's historical truthiness check.
    if args.limit:
        examples = examples[: args.limit]

    schema_ret = SchemaRetriever(settings.databases_dir)
    engine = InferenceEngine()
    engine.load()

    predictions: list[str] = []
    golds: list[str] = []
    db_ids: list[str] = []
    skipped_db_ids: list[str] = []

    # Create the output location before the slow inference loop, so a bad
    # path fails immediately rather than after every generation has run.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        for ex in tqdm(examples, desc="Inference"):
            try:
                schema = schema_ret.render_schema(ex.db_id)
            except FileNotFoundError:
                # Best-effort: skip the example but remember it so the
                # metrics below are not silently computed on a subset.
                skipped_db_ids.append(ex.db_id)
                continue
            result = engine.generate(schema, ex.question)
            predictions.append(result.sql)
            golds.append(ex.query)
            db_ids.append(ex.db_id)
            row = {
                "db_id": ex.db_id,
                "question": ex.question,
                "gold": ex.query,
                "pred": result.sql,
                "raw": result.raw_output,
            }
            # Stream each record as soon as it exists. If a later example
            # raises, the ``with`` block still closes the file, so finished
            # rows are kept instead of being lost with an in-memory buffer.
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    if skipped_db_ids:
        # stderr keeps stdout clean for the JSON metrics printed below.
        print(
            f"Skipped {len(skipped_db_ids)} example(s): render_schema raised "
            f"FileNotFoundError for db_id(s) {', '.join(sorted(set(skipped_db_ids)))}",
            file=sys.stderr,
        )

    metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir)
    print(json.dumps(metrics, indent=2, ensure_ascii=False))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script/module
    # (e.g. ``python -m src.evaluation.evaluate``), not when imported.
    main()