File size: 11,375 Bytes
cc2ed2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""
Скрипт прогона метрик Ru2SQL на валидационной выборке PAUQ.

Считает Exact Match (EM) и Execution Accuracy (EX) для дообученной модели
Qwen2.5-Coder-3B + QLoRA. Используется на Kaggle (T4 GPU), результат
сохраняется в eval_results.csv и eval_summary.json.

Источник цифр для раздела 4.3 пояснительной записки.

Использование на Kaggle:
  1. Загрузить проект целиком (через Kaggle Dataset или git clone).
  2. Установить ADAPTER_ID на свой HF-репо.
  3. python evaluate_pauq.py

Постобработка и нормализация SQL импортируются из src/models/postprocess.py,
чтобы метрики локально и на Kaggle оставались сопоставимыми.
"""

import sys
from pathlib import Path

# Делаем пакет src/ импортируемым, когда скрипт запускается из корня.
_PROJECT_ROOT = Path(__file__).resolve().parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

import csv
import json
import sqlite3

from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset

from src.models.postprocess import (
    normalize_sql,
    strip_model_artifacts as strip_artifacts,
)

# ─── CONFIG ───────────────────────────────────────────────────────────────────

BASE_MODEL_ID  = "Qwen/Qwen2.5-Coder-3B-Instruct"
ADAPTER_ID     = "Tyycha/qwen-coder-pauq-lora"
PAUQ_SPLIT     = "validation"   # 1034 примера
MAX_NEW_TOKENS = 256
DEVICE         = "cuda" if torch.cuda.is_available() else "cpu"

# Путь к папке с SQLite-базами PAUQ (Spider databases)
# На Kaggle: скачай https://drive.google.com/uc?id=1iRDVHLr4mX2wQKSgA9VxUUFpj-3-Kj5B
# Или используй датасет: kaggle datasets download -d wikisql/spider
PAUQ_DB_DIR    = Path("./pauq_databases")   # папка, где лежат папки баз данных

# ─── OUTPUT ───────────────────────────────────────────────────────────────────

RESULTS_FILE   = "eval_results.csv"
SUMMARY_FILE   = "eval_summary.json"

# ─── LOAD MODEL ───────────────────────────────────────────────────────────────

print("Loading model...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()
print(f"Model loaded on {DEVICE}")

# ─── SCHEMA RETRIEVER ─────────────────────────────────────────────────────────

def get_schema(db_id: str) -> str:
    """Extract CREATE TABLE statements from SQLite database."""
    db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return f"-- Database {db_id} not found"
    
    conn = sqlite3.connect(str(db_path))
    cursor = conn.cursor()
    
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall()]
    
    schema_parts = []
    for table in tables:
        cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table,))
        create_sql = cursor.fetchone()
        if create_sql and create_sql[0]:
            schema_parts.append(create_sql[0])
            # Add sample rows
            try:
                cursor.execute(f"SELECT * FROM \"{table}\" LIMIT 3")
                rows = cursor.fetchall()
                cursor.execute(f"PRAGMA table_info(\"{table}\")")
                cols = [col[1] for col in cursor.fetchall()]
                if rows:
                    schema_parts.append(f"-- Sample data for {table}:")
                    schema_parts.append("-- " + " | ".join(cols))
                    for row in rows[:3]:
                        schema_parts.append("-- " + " | ".join(str(v) for v in row))
            except Exception:
                pass
    
    conn.close()
    return "\n\n".join(schema_parts)

# ─── PROMPT BUILDER ───────────────────────────────────────────────────────────

SYSTEM_PROMPT = (
    "Ты — ассистент, который преобразует вопросы на русском языке в SQL-запросы. "
    "Отвечай ТОЛЬКО SQL-запросом без объяснений, комментариев и markdown-разметки."
)

def build_prompt(question: str, schema: str) -> str:
    user_msg = f"Schema:\n{schema}\n\nQuestion: {question}\n\nSQL:"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",   "content": user_msg},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

# ─── SQL POSTPROCESSING ───────────────────────────────────────────────────────
# strip_artifacts и normalize_sql импортируются из src/models/postprocess.py
# для гарантии того, что метрики на Kaggle и в локальных тестах считаются
# по одной и той же логике.

# ─── EXECUTION ────────────────────────────────────────────────────────────────

def execute_sql(sql: str, db_id: str):
    """Execute SQL and return result set as frozenset of tuples."""
    db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
    if not db_path.exists():
        return None
    try:
        uri = f"file:{db_path}?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        cursor = conn.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
        conn.close()
        return frozenset(tuple(str(v) for v in row) for row in rows)
    except Exception:
        return None

# ─── INFERENCE ────────────────────────────────────────────────────────────────

def generate_sql(question: str, schema: str) -> str:
    prompt = build_prompt(question, schema)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return strip_artifacts(generated)

# ─── MAIN EVALUATION LOOP ─────────────────────────────────────────────────────

def main():
    print(f"Loading PAUQ {PAUQ_SPLIT} split...")
    dataset = load_dataset("ai-forever/PAUQ", split=PAUQ_SPLIT)
    print(f"Loaded {len(dataset)} examples")
    
    results = []
    em_correct = 0
    ex_correct = 0
    ex_total   = 0  # only count where execution was possible
    
    with open(RESULTS_FILE, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "idx", "db_id", "question", "gold_sql", "pred_sql",
            "em", "ex", "error"
        ])
        writer.writeheader()
        
        for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
            question  = example.get("question_ru") or example.get("question", "")
            gold_sql  = example.get("query", "")
            db_id     = example.get("db_id", "")
            
            try:
                schema   = get_schema(db_id)
                pred_sql = generate_sql(question, schema)
                
                # Exact Match
                norm_pred = normalize_sql(pred_sql)
                norm_gold = normalize_sql(gold_sql)
                em = int(norm_pred == norm_gold)
                
                # Execution Accuracy
                pred_result = execute_sql(pred_sql, db_id)
                gold_result = execute_sql(gold_sql, db_id)
                
                if gold_result is not None:
                    ex = int(pred_result == gold_result)
                    ex_correct += ex
                    ex_total   += 1
                else:
                    ex = None
                
                em_correct += em
                error = ""
                
            except Exception as e:
                pred_sql = ""
                em = 0
                ex = None
                error = str(e)[:200]
            
            row = {
                "idx": idx, "db_id": db_id, "question": question,
                "gold_sql": gold_sql, "pred_sql": pred_sql,
                "em": em, "ex": ex, "error": error
            }
            writer.writerow(row)
            results.append(row)
            
            # Progress every 100 examples
            if (idx + 1) % 100 == 0:
                cur_em = em_correct / (idx + 1)
                cur_ex = ex_correct / max(ex_total, 1)
                print(f"[{idx+1}/{len(dataset)}] EM={cur_em:.3f}  EX={cur_ex:.3f}")
    
    # Final summary
    n = len(dataset)
    final_em = em_correct / n
    final_ex = ex_correct / max(ex_total, 1)
    
    summary = {
        "model": ADAPTER_ID,
        "split": PAUQ_SPLIT,
        "n_examples": n,
        "exact_match": round(final_em, 4),
        "execution_accuracy": round(final_ex, 4),
        "em_correct": em_correct,
        "ex_correct": ex_correct,
        "ex_total": ex_total
    }
    
    with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    
    print("\n" + "="*50)
    print(f"RESULTS on PAUQ {PAUQ_SPLIT} ({n} examples)")
    print(f"  Exact Match (EM):          {final_em:.1%}  ({em_correct}/{n})")
    print(f"  Execution Accuracy (EX):   {final_ex:.1%}  ({ex_correct}/{ex_total})")
    print(f"\nDetailed results saved to: {RESULTS_FILE}")
    print(f"Summary saved to:          {SUMMARY_FILE}")

if __name__ == "__main__":
    main()