"""
Evaluation script that runs the Ru2SQL metrics on the PAUQ validation split.

Computes Exact Match (EM) and Execution Accuracy (EX) for the fine-tuned
Qwen2.5-Coder-3B + QLoRA model. Intended to run on Kaggle (T4 GPU); the
results are saved to eval_results.csv and eval_summary.json.

Source of the figures for section 4.3 of the explanatory note.

Usage on Kaggle:
    1. Upload the whole project (via a Kaggle Dataset or git clone).
    2. Set ADAPTER_ID to your own HF repo.
    3. python evaluate_pauq.py

SQL post-processing and normalization are imported from
src/models/postprocess.py so that metrics computed locally and on Kaggle
stay comparable.
"""
|
|
| import sys |
| from pathlib import Path |
|
|
| |
# Put the directory that contains this script at the front of sys.path so the
# `src` package imported below resolves regardless of the working directory.
_PROJECT_ROOT = Path(__file__).resolve().parent
if not any(entry == str(_PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_PROJECT_ROOT))
|
|
| import csv |
| import json |
| import sqlite3 |
|
|
| from tqdm import tqdm |
|
|
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
| from datasets import load_dataset |
|
|
| from src.models.postprocess import ( |
| normalize_sql, |
| strip_model_artifacts as strip_artifacts, |
| ) |
|
|
| |
|
|
# --- Model / generation settings ---
# Hugging Face id of the base checkpoint the adapter is attached to.
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Hugging Face id of the PEFT adapter loaded on top of the base model.
ADAPTER_ID = "Tyycha/qwen-coder-pauq-lora"
# Dataset split passed to `load_dataset` in `main`.
PAUQ_SPLIT = "validation"
# Upper bound on the number of tokens generated per question.
MAX_NEW_TOKENS = 256
# Device the tokenized prompt is moved to in `generate_sql`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Databases ---
# Root directory of the SQLite files; each database is looked up at
# <PAUQ_DB_DIR>/<db_id>/<db_id>.sqlite.
PAUQ_DB_DIR = Path("./pauq_databases")


# --- Output files ---
# Per-example CSV written by `main`.
RESULTS_FILE = "eval_results.csv"
# Aggregate metrics JSON written by `main`.
SUMMARY_FILE = "eval_summary.json"
|
|
| |
|
|
# --- Model loading ---
# NOTE: these statements run at import time, so importing this module already
# triggers the model load.
print("Loading model...")
# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
# NOTE(review): the module docstring targets a Kaggle T4 GPU; confirm that
# bfloat16 is the intended compute dtype on that hardware.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)


# The tokenizer comes from the base checkpoint, not from the adapter repo.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
# device_map="auto" leaves weight placement to the library.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Attach the adapter weights to the quantized base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
# Switch to evaluation mode before generating.
model.eval()
# DEVICE reflects CUDA availability only; actual placement is decided by
# device_map="auto" above.
print(f"Model loaded on {DEVICE}")
|
|
| |
|
|
def get_schema(db_id: str, *, db_dir: "Path | str | None" = None) -> str:
    """Build the schema text for one database, used as prompt context.

    For every table (ordered by name) the output contains the ``CREATE TABLE``
    statement stored in ``sqlite_master`` and, when the table has rows, up to
    three sample rows rendered as ``--`` SQL comments (a title line, a header
    line with the column names, then one line per row). All parts are joined
    with a blank line between them.

    Args:
        db_id: Database name; the file is looked up at
            ``<db_dir>/<db_id>/<db_id>.sqlite``.
        db_dir: Root directory of the databases. Defaults to ``PAUQ_DB_DIR``.

    Returns:
        The schema text, or ``"-- Database <db_id> not found"`` when the
        database file does not exist.

    Raises:
        sqlite3.Error: If the table listing itself cannot be read. Failures
            while sampling rows of a single table are skipped instead.
    """
    root = PAUQ_DB_DIR if db_dir is None else Path(db_dir)
    db_path = root / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return f"-- Database {db_id} not found"

    schema_parts = []
    # Read-only URI, like execute_sql: building a prompt must never modify the
    # database. as_uri() percent-encodes characters such as '#' and '?'.
    conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name"
        )
        for table, create_sql in cursor.fetchall():
            if create_sql:
                schema_parts.append(create_sql)

            # Double embedded quotes so any table name is a valid identifier.
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                cursor.execute(f"SELECT * FROM {quoted} LIMIT 3")
                rows = cursor.fetchall()
                if not rows:
                    continue
                cursor.execute(f"PRAGMA table_info({quoted})")
                cols = [col[1] for col in cursor.fetchall()]
            except sqlite3.Error:
                # Sample rows are best-effort context; keep the CREATE TABLE
                # text and move on to the next table.
                continue

            schema_parts.append(f"-- Sample data for {table}:")
            schema_parts.append("-- " + " | ".join(cols))
            schema_parts.extend(
                "-- " + " | ".join(str(v) for v in row) for row in rows
            )
    finally:
        # Close the handle even when a query above raises.
        conn.close()

    return "\n\n".join(schema_parts)
|
|
| |
|
|
# System message placed first in every chat prompt (see `build_prompt`).
# English gloss: "You are an assistant that converts questions in Russian into
# SQL queries. Answer ONLY with an SQL query, without explanations, comments or
# markdown markup."
# NOTE(review): the literal was restored from mis-encoded (mojibake) text;
# confirm it matches the system prompt used during fine-tuning exactly.
SYSTEM_PROMPT = (
    "Ты — ассистент, который преобразует вопросы на русском языке в SQL-запросы. "
    "Отвечай ТОЛЬКО SQL-запросом без объяснений, комментариев и markdown-разметки."
)
|
|
def build_prompt(question: str, schema: str) -> str:
    """Render the chat prompt for one question.

    The conversation consists of the ``SYSTEM_PROMPT`` system message and one
    user message holding the schema, the question and a trailing ``SQL:`` cue.
    It is rendered to text (not token ids) with the tokenizer's chat template,
    with the generation prompt appended.
    """
    chat = [
        dict(role="system", content=SYSTEM_PROMPT),
        dict(
            role="user",
            content=f"Schema:\n{schema}\n\nQuestion: {question}\n\nSQL:",
        ),
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
|
|
| |
| |
| |
| |
|
|
| |
|
|
def execute_sql(
    sql: str,
    db_id: str,
    *,
    db_dir: "Path | str | None" = None,
    timeout: "float | None" = 30.0,
) -> "frozenset | None":
    """Run ``sql`` against a database opened read-only and return its rows.

    Args:
        sql: SQL text to execute.
        db_id: Database name; the file is looked up at
            ``<db_dir>/<db_id>/<db_id>.sqlite``.
        db_dir: Root directory of the databases. Defaults to ``PAUQ_DB_DIR``.
        timeout: Wall-clock budget in seconds; a query that runs longer is
            aborted and treated as a failure. ``None`` disables the limit.

    Returns:
        A ``frozenset`` of rows, each a tuple of ``str(value)``, or ``None``
        when the database file is missing or the query cannot be executed
        (invalid SQL, write attempt, timeout, ...).

    Note:
        A frozenset ignores row order and collapses duplicate rows, so two
        results that differ only in those respects compare equal.
    """
    import time  # local import: only needed for the optional time budget

    root = PAUQ_DB_DIR if db_dir is None else Path(db_dir)
    db_path = root / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return None

    conn = None
    try:
        # as_uri() percent-encodes characters such as '#', '?' and '%', which
        # would otherwise be parsed as URI syntax and break the open.
        conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
        if timeout is not None:
            deadline = time.monotonic() + timeout
            # A truthy return value from the handler aborts the running query.
            conn.set_progress_handler(
                lambda: time.monotonic() > deadline, 10_000
            )
        rows = conn.execute(sql).fetchall()
        return frozenset(tuple(str(v) for v in row) for row in rows)
    except Exception:
        # Deliberately broad: any failure to run the query means "no result"
        # for the metric, whatever exception type the driver raises.
        return None
    finally:
        # The original skipped close() whenever execute() raised.
        if conn is not None:
            conn.close()
|
|
| |
|
|
def generate_sql(question: str, schema: str) -> str:
    """Generate the model's SQL answer for one question.

    The prompt from ``build_prompt`` is tokenized, moved to ``DEVICE`` and
    passed to ``model.generate`` with sampling disabled and at most
    ``MAX_NEW_TOKENS`` new tokens. Only the tokens after the prompt are
    decoded (special tokens skipped), and the text is passed through
    ``strip_artifacts`` before being returned.
    """
    encoded = tokenizer(build_prompt(question, schema), return_tensors="pt").to(DEVICE)
    # Number of prompt tokens; everything after it is the completion.
    prompt_len = encoded["input_ids"].shape[1]

    generation_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    completion_ids = sequences[0][prompt_len:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return strip_artifacts(completion)
|
|
| |
|
|
def _score_example(question, gold_sql, db_id, schema_cache):
    """Score one example; return ``(pred_sql, em, ex, error)``.

    ``em`` is 1 when the normalized prediction equals the normalized gold
    query. ``ex`` is 1/0 when the gold query executes (a prediction that could
    not be generated or executed counts as 0) and ``None`` when the gold query
    itself cannot be executed, in which case the example is not scorable.
    """
    pred_sql, pred_result, em, error = "", None, 0, ""
    try:
        # The schema depends only on the database, so build it once per db_id.
        if db_id not in schema_cache:
            schema_cache[db_id] = get_schema(db_id)
        pred_sql = generate_sql(question, schema_cache[db_id])
        em = int(normalize_sql(pred_sql) == normalize_sql(gold_sql))
        pred_result = execute_sql(pred_sql, db_id)
    except Exception as exc:
        # One failing example must not abort the whole run; record the error.
        pred_sql, pred_result, em, error = "", None, 0, str(exc)[:200]

    # Executed outside the block above so that a failed generation is still
    # counted in the EX denominator, the same way it is counted for EM.
    try:
        gold_result = execute_sql(gold_sql, db_id)
    except Exception as exc:
        gold_result = None
        error = error or str(exc)[:200]

    ex = None if gold_result is None else int(pred_result == gold_result)
    return pred_sql, em, ex, error


def main():
    """Evaluate the model on the PAUQ split and write the results to disk.

    Writes one row per example to ``RESULTS_FILE`` (CSV) and the aggregate
    metrics to ``SUMMARY_FILE`` (JSON), printing running metrics every 100
    examples and a final report.

    Exact Match is averaged over all examples. Execution Accuracy is averaged
    over the examples whose gold query executes; the number of excluded
    examples is reported as ``ex_skipped`` in the summary.
    """
    print(f"Loading PAUQ {PAUQ_SPLIT} split...")
    dataset = load_dataset("ai-forever/PAUQ", split=PAUQ_SPLIT)
    n = len(dataset)
    print(f"Loaded {n} examples")

    em_correct = 0
    ex_correct = 0
    ex_total = 0
    schema_cache = {}

    with open(RESULTS_FILE, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "idx", "db_id", "question", "gold_sql", "pred_sql",
            "em", "ex", "error"
        ])
        writer.writeheader()

        for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
            question = example.get("question_ru") or example.get("question", "")
            gold_sql = example.get("query", "")
            db_id = example.get("db_id", "")

            pred_sql, em, ex, error = _score_example(
                question, gold_sql, db_id, schema_cache
            )

            em_correct += em
            if ex is not None:
                ex_correct += ex
                ex_total += 1

            # Rows are streamed to the CSV; nothing is accumulated in memory.
            writer.writerow({
                "idx": idx, "db_id": db_id, "question": question,
                "gold_sql": gold_sql, "pred_sql": pred_sql,
                "em": em, "ex": ex, "error": error
            })

            if (idx + 1) % 100 == 0:
                # Flush so partial results survive an interrupted session.
                f.flush()
                cur_em = em_correct / (idx + 1)
                cur_ex = ex_correct / max(ex_total, 1)
                print(f"[{idx+1}/{n}] EM={cur_em:.3f} EX={cur_ex:.3f}")

    # max(..., 1) keeps an empty split from raising ZeroDivisionError.
    final_em = em_correct / max(n, 1)
    final_ex = ex_correct / max(ex_total, 1)

    summary = {
        "model": ADAPTER_ID,
        "split": PAUQ_SPLIT,
        "n_examples": n,
        "exact_match": round(final_em, 4),
        "execution_accuracy": round(final_ex, 4),
        "em_correct": em_correct,
        "ex_correct": ex_correct,
        "ex_total": ex_total,
        "ex_skipped": n - ex_total,
    }

    with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print("\n" + "="*50)
    print(f"RESULTS on PAUQ {PAUQ_SPLIT} ({n} examples)")
    print(f" Exact Match (EM): {final_em:.1%} ({em_correct}/{n})")
    print(f" Execution Accuracy (EX): {final_ex:.1%} ({ex_correct}/{ex_total})")
    print(f"\nDetailed results saved to: {RESULTS_FILE}")
    print(f"Summary saved to: {SUMMARY_FILE}")
|
|
# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|