Ru2SQL / evaluate_pauq.py
Tyycha's picture
fix bugs
cc2ed2f
"""
Script that runs the Ru2SQL metrics on the PAUQ validation split.
Computes Exact Match (EM) and Execution Accuracy (EX) for the fine-tuned
Qwen2.5-Coder-3B + QLoRA model. Used on Kaggle (T4 GPU); the result
is saved to eval_results.csv and eval_summary.json.
Source of the figures for section 4.3 of the explanatory note.
Usage on Kaggle:
1. Upload the whole project (via a Kaggle Dataset or git clone).
2. Set ADAPTER_ID to your own HF repo.
3. python evaluate_pauq.py
SQL post-processing and normalization are imported from src/models/postprocess.py,
so that metrics computed locally and on Kaggle stay comparable.
"""
import sys
from pathlib import Path
# Make the src/ package importable when the script is launched from the
# project root: the directory holding this file goes to the front of sys.path
# unless it is already listed there.
_PROJECT_ROOT = Path(__file__).resolve().parent
if sys.path.count(str(_PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(_PROJECT_ROOT))
import csv
import json
import sqlite3
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset
from src.models.postprocess import (
normalize_sql,
strip_model_artifacts as strip_artifacts,
)
# ─── CONFIG ───────────────────────────────────────────────────────────────────
# Hugging Face identifiers of the base model and of the LoRA adapter on top of it.
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
ADAPTER_ID = "Tyycha/qwen-coder-pauq-lora"

# Dataset split to evaluate (1034 examples, per the original author's note).
PAUQ_SPLIT = "validation"

# Upper bound on the number of tokens generated per question.
MAX_NEW_TOKENS = 256

# Prefer the GPU when one is visible to torch; fall back to the CPU otherwise.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Path to the folder with the PAUQ SQLite databases (Spider databases).
# On Kaggle: download https://drive.google.com/uc?id=1iRDVHLr4mX2wQKSgA9VxUUFpj-3-Kj5B
# Or use the dataset: kaggle datasets download -d wikisql/spider
# The folder is expected to contain one sub-folder per database.
PAUQ_DB_DIR = Path("pauq_databases")

# ─── OUTPUT ───────────────────────────────────────────────────────────────────
# Per-example rows and the aggregate summary are written to these files.
RESULTS_FILE = "eval_results.csv"
SUMMARY_FILE = "eval_summary.json"
# ─── LOAD MODEL ───────────────────────────────────────────────────────────────
print("Loading model...")

# Quantization settings for the base model: 4-bit NF4 weights with double
# quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# The tokenizer comes from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_ID,
    trust_remote_code=True,
)

# Quantized base weights; device placement is delegated to device_map="auto".
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)

# Attach the adapter and switch to inference mode.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()

print(f"Model loaded on {DEVICE}")
# ─── SCHEMA RETRIEVER ─────────────────────────────────────────────────────────
def get_schema(db_id: str) -> str:
    """Render a database schema (DDL plus sample rows) as prompt text.

    Reads ``PAUQ_DB_DIR/<db_id>/<db_id>.sqlite`` and, for every table listed
    in ``sqlite_master`` (ordered by name), emits its ``CREATE TABLE``
    statement followed by up to three sample rows formatted as SQL comments.

    Args:
        db_id: Database identifier; it names both the folder and the file.

    Returns:
        The schema fragments joined by blank lines, or the placeholder comment
        ``-- Database <db_id> not found`` when the file does not exist.

    Raises:
        sqlite3.Error: If the table catalogue itself cannot be read (for
            example the file is not a SQLite database). A failure while
            sampling one table is tolerated and only drops that table's
            sample rows.
    """
    db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return f"-- Database {db_id} not found"

    schema_parts = []
    conn = sqlite3.connect(str(db_path))
    try:
        cursor = conn.cursor()
        # One pass over the catalogue yields both the name and the DDL.
        cursor.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name"
        )
        for table, create_sql in cursor.fetchall():
            if create_sql:
                schema_parts.append(create_sql)

            # Double any embedded quote so unusual table names stay valid
            # quoted identifiers.
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                cursor.execute(f"SELECT * FROM {quoted} LIMIT 3")
                rows = cursor.fetchall()
                cursor.execute(f"PRAGMA table_info({quoted})")
                cols = [col[1] for col in cursor.fetchall()]
            except sqlite3.Error:
                # Sample rows are best-effort; the DDL above is still emitted.
                continue

            if rows:
                schema_parts.append(f"-- Sample data for {table}:")
                schema_parts.append("-- " + " | ".join(cols))
                schema_parts.extend(
                    "-- " + " | ".join(str(v) for v in row) for row in rows
                )
    finally:
        # Close the handle even when a catalogue query raises.
        conn.close()

    return "\n\n".join(schema_parts)
# ─── PROMPT BUILDER ───────────────────────────────────────────────────────────
# System message placed in front of every request. In English: "You are an
# assistant that converts questions in Russian into SQL queries. Answer ONLY
# with the SQL query, without explanations, comments or markdown markup."
# NOTE(review): the literal was restored from a mis-encoded copy (UTF-8 bytes
# rendered through a Greek code page) - confirm it matches the prompt used
# during fine-tuning character for character.
SYSTEM_PROMPT = (
    "Ты — ассистент, который преобразует вопросы на русском языке в SQL-запросы. "
    "Отвечай ТОЛЬКО SQL-запросом без объяснений, комментариев и markdown-разметки."
)
def build_prompt(question: str, schema: str) -> str:
    """Format one question and its schema as a chat prompt string.

    Args:
        question: The natural-language question to translate.
        schema: Schema text to show to the model.

    Returns:
        The string produced by the tokenizer's chat template for a system
        turn (``SYSTEM_PROMPT``) followed by a user turn, untokenized and
        with the generation prompt appended.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {
        "role": "user",
        "content": f"Schema:\n{schema}\n\nQuestion: {question}\n\nSQL:",
    }
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
# ─── SQL POSTPROCESSING ───────────────────────────────────────────────────────
# strip_artifacts and normalize_sql are imported from src/models/postprocess.py
# to guarantee that metrics on Kaggle and in local tests are computed
# with one and the same logic.
# ─── EXECUTION ────────────────────────────────────────────────────────────────
def execute_sql(sql: str, db_id: str, timeout: float = 30.0):
    """Run a query against a PAUQ database and return its rows as a set.

    The database ``PAUQ_DB_DIR/<db_id>/<db_id>.sqlite`` is opened read-only,
    so a generated statement cannot modify the evaluation data.

    Args:
        sql: SQL text to execute.
        db_id: Database identifier; it names both the folder and the file.
        timeout: Wall-clock budget in seconds for the query. ``None`` or a
            non-positive value disables the limit.

    Returns:
        A ``frozenset`` of rows, each a tuple of ``str()``-converted values,
        or ``None`` when the database file is missing, the query fails, or the
        time budget is exceeded. Because the result is a set, duplicate rows
        and row order do not take part in comparisons.
    """
    import time  # local import: the module header does not import it

    db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return None

    # as_uri() percent-encodes characters such as '#', '?' and spaces that
    # would otherwise be read as URI syntax rather than as part of the path.
    uri = f"{db_path.resolve().as_uri()}?mode=ro"
    try:
        conn = sqlite3.connect(uri, uri=True)
    except sqlite3.Error:
        return None

    try:
        if timeout is not None and timeout > 0:
            deadline = time.monotonic() + timeout
            # A true return value makes SQLite abort the running statement.
            conn.set_progress_handler(
                lambda: time.monotonic() > deadline, 10_000
            )
        rows = conn.execute(sql).fetchall()
        return frozenset(tuple(str(v) for v in row) for row in rows)
    except Exception:
        # Deliberate boundary: any failure of (possibly model-generated) SQL
        # is reported to the caller as "no result", never as a crash.
        return None
    finally:
        # Release the handle on the error path as well.
        conn.close()
# ─── INFERENCE ────────────────────────────────────────────────────────────────
def generate_sql(question: str, schema: str) -> str:
    """Generate SQL text for one question with the loaded model.

    Args:
        question: The natural-language question to translate.
        schema: Schema text to include in the prompt.

    Returns:
        The decoded completion (prompt tokens excluded, special tokens
        skipped) after it has been passed through ``strip_artifacts``.
    """
    encoded = tokenizer(
        build_prompt(question, schema), return_tensors="pt"
    ).to(DEVICE)
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # sampling disabled
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; keep only the newly produced part.
    completion_ids = sequences[0][prompt_len:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return strip_artifacts(completion)
# ─── MAIN EVALUATION LOOP ─────────────────────────────────────────────────────
def main():
    """Evaluate the model on the PAUQ split and write the result files.

    For every example the function builds the schema text, generates SQL and
    scores it two ways:

    * Exact Match (EM) - the ``normalize_sql`` forms of the predicted and the
      gold query are equal;
    * Execution Accuracy (EX) - both queries return the same result from
      ``execute_sql``. Examples whose gold query yields no result are left
      out of the EX denominator.

    An example that raises is recorded with ``em=0``, an empty ``ex`` and the
    error text, and still counts towards the EM denominator. Per-example rows
    go to ``RESULTS_FILE`` (CSV), aggregates to ``SUMMARY_FILE`` (JSON).
    """
    print(f"Loading PAUQ {PAUQ_SPLIT} split...")
    dataset = load_dataset("ai-forever/PAUQ", split=PAUQ_SPLIT)
    n = len(dataset)
    print(f"Loaded {n} examples")

    em_correct = 0
    ex_correct = 0
    ex_total = 0  # only count where execution was possible

    fieldnames = [
        "idx", "db_id", "question", "gold_sql", "pred_sql",
        "em", "ex", "error",
    ]
    with open(RESULTS_FILE, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

        for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
            question = example.get("question_ru") or example.get("question", "")
            gold_sql = example.get("query", "")
            db_id = example.get("db_id", "")

            # Kept outside the try so a prediction that was already generated
            # still reaches the CSV when a later step fails.
            pred_sql = ""
            try:
                pred_sql = generate_sql(question, get_schema(db_id))

                # Exact Match
                em = int(normalize_sql(pred_sql) == normalize_sql(gold_sql))

                # Execution Accuracy; undefined when the gold query itself
                # yields no result.
                gold_result = execute_sql(gold_sql, db_id)
                if gold_result is None:
                    ex = None
                else:
                    ex = int(execute_sql(pred_sql, db_id) == gold_result)
                error = ""
            except Exception as exc:
                em = 0
                ex = None
                # Fall back to the exception type when the message is empty,
                # so a failed row is never mistaken for a clean one.
                error = (str(exc) or type(exc).__name__)[:200]
            else:
                em_correct += em
                if ex is not None:
                    ex_correct += ex
                    ex_total += 1

            writer.writerow({
                "idx": idx, "db_id": db_id, "question": question,
                "gold_sql": gold_sql, "pred_sql": pred_sql,
                "em": em, "ex": ex, "error": error,
            })

            # Progress every 100 examples
            if (idx + 1) % 100 == 0:
                # Flush so an interrupted session keeps the rows scored so far.
                f.flush()
                cur_em = em_correct / (idx + 1)
                cur_ex = ex_correct / max(ex_total, 1)
                print(f"[{idx+1}/{n}] EM={cur_em:.3f} EX={cur_ex:.3f}")

    # Final summary; both denominators are guarded against an empty split.
    final_em = em_correct / max(n, 1)
    final_ex = ex_correct / max(ex_total, 1)
    summary = {
        "model": ADAPTER_ID,
        "split": PAUQ_SPLIT,
        "n_examples": n,
        "exact_match": round(final_em, 4),
        "execution_accuracy": round(final_ex, 4),
        "em_correct": em_correct,
        "ex_correct": ex_correct,
        "ex_total": ex_total,
    }
    with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print("\n" + "="*50)
    print(f"RESULTS on PAUQ {PAUQ_SPLIT} ({n} examples)")
    print(f" Exact Match (EM): {final_em:.1%} ({em_correct}/{n})")
    print(f" Execution Accuracy (EX): {final_ex:.1%} ({ex_correct}/{ex_total})")
    print(f"\nDetailed results saved to: {RESULTS_FILE}")
    print(f"Summary saved to: {SUMMARY_FILE}")


if __name__ == "__main__":
    main()