Ru2SQL / src /evaluation /evaluate.py
Tyycha's picture
initial commit
8871df9
"""Run the model on a PAUQ split (train/dev/test; default: dev) and report metrics.

Usage:
    python -m src.evaluation.evaluate --split dev --limit 50
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

from tqdm import tqdm

from src.config import settings
from src.data.loader import load_pauq_split
from src.data.schema import SchemaRetriever
from src.evaluation.metrics import compute_metrics
from src.models.inference import InferenceEngine
def main():
    """Run inference over a PAUQ split, save predictions, and print metrics.

    Command-line options:
        --split   Split file to load from ``settings.pauq_data_dir``:
                  ``train``, ``dev`` (default) or ``test``.
        --limit   Evaluate only the first N examples. Omitted or ``0`` means
                  no limit; negative values are rejected.
        --output  JSONL file that receives one record per evaluated example
                  (default ``results/predictions.jsonl``).

    Examples whose schema cannot be rendered (``FileNotFoundError``) are left
    out of both the output file and the metrics. How many were skipped, and
    for which ``db_id`` values, is reported on stderr so that the metrics JSON
    printed on stdout stays machine-readable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="dev", choices=["train", "dev", "test"])
    parser.add_argument("--limit", type=int, default=None, help="Ограничить число примеров")
    parser.add_argument("--output", type=Path, default=Path("results/predictions.jsonl"))
    args = parser.parse_args()

    # A negative limit would slice off the *tail* of the split
    # (examples[:-n]) rather than cap its size, so refuse it up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    split_path = settings.pauq_data_dir / f"{args.split}.json"
    examples = load_pauq_split(split_path)
    if args.limit:
        examples = examples[: args.limit]

    schema_ret = SchemaRetriever(settings.databases_dir)
    engine = InferenceEngine()
    engine.load()

    # Parallel lists consumed by compute_metrics; `rows` mirrors them for the
    # JSONL dump and additionally keeps the question and the raw model output.
    predictions: list[str] = []
    golds: list[str] = []
    db_ids: list[str] = []
    rows = []

    skipped_count = 0
    skipped_db_ids = set()

    for ex in tqdm(examples, desc="Inference"):
        try:
            schema = schema_ret.render_schema(ex.db_id)
        except FileNotFoundError:
            # Keep going (best effort), but remember the gap so it can be
            # reported instead of silently shrinking the evaluated set.
            skipped_count += 1
            skipped_db_ids.add(ex.db_id)
            continue

        result = engine.generate(schema, ex.question)
        predictions.append(result.sql)
        golds.append(ex.query)
        db_ids.append(ex.db_id)
        rows.append(
            {
                "db_id": ex.db_id,
                "question": ex.question,
                "gold": ex.query,
                "pred": result.sql,
                "raw": result.raw_output,
            }
        )

    if skipped_count:
        missing = ", ".join(sorted(str(db_id) for db_id in skipped_db_ids))
        print(
            f"Skipped {skipped_count} example(s) with no schema file; "
            f"evaluated {len(rows)}. Affected db_id(s): {missing}",
            file=sys.stderr,
        )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir)
    print(json.dumps(metrics, indent=2, ensure_ascii=False))
# Script entry point: run the evaluation only when executed directly
# (e.g. ``python -m src.evaluation.evaluate``), not when imported.
if __name__ == "__main__":
    main()