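"""Evaluate the model on a PAUQ data split (train, dev, or test; dev by default).

Predictions are written as JSON lines to ``--output`` and the computed metrics
are printed to stdout as JSON.

Usage:
    python -m src.evaluation.evaluate --split dev --limit 50
"""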
|
|
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

from tqdm import tqdm

from src.config import settings
from src.data.loader import load_pauq_split
from src.data.schema import SchemaRetriever
from src.evaluation.metrics import compute_metrics
from src.models.inference import InferenceEngine
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--split", default="dev", choices=["train", "dev", "test"]) |
| parser.add_argument("--limit", type=int, default=None, help="Ограничить число примеров") |
| parser.add_argument("--output", type=Path, default=Path("results/predictions.jsonl")) |
| args = parser.parse_args() |
|
|
| split_path = settings.pauq_data_dir / f"{args.split}.json" |
| examples = load_pauq_split(split_path) |
| if args.limit: |
| examples = examples[: args.limit] |
|
|
| schema_ret = SchemaRetriever(settings.databases_dir) |
| engine = InferenceEngine() |
| engine.load() |
|
|
| predictions: list[str] = [] |
| golds: list[str] = [] |
| db_ids: list[str] = [] |
| rows = [] |
|
|
| for ex in tqdm(examples, desc="Inference"): |
| try: |
| schema = schema_ret.render_schema(ex.db_id) |
| except FileNotFoundError: |
| continue |
| result = engine.generate(schema, ex.question) |
| predictions.append(result.sql) |
| golds.append(ex.query) |
| db_ids.append(ex.db_id) |
| rows.append( |
| { |
| "db_id": ex.db_id, |
| "question": ex.question, |
| "gold": ex.query, |
| "pred": result.sql, |
| "raw": result.raw_output, |
| } |
| ) |
|
|
| args.output.parent.mkdir(parents=True, exist_ok=True) |
| with args.output.open("w", encoding="utf-8") as f: |
| for r in rows: |
| f.write(json.dumps(r, ensure_ascii=False) + "\n") |
|
|
| metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir) |
| print(json.dumps(metrics, indent=2, ensure_ascii=False)) |
|
|
|
|
if __name__ == "__main__":
    # Entry point when executed as a script/module (python -m ...);
    # importing this module does not trigger an evaluation run.
    main()
|
|