Spaces:
Running
Running
File size: 4,349 Bytes
dc59b01 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | import json
import sqlite3
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
PROJECT_ROOT = Path(__file__).resolve().parents[1]
DB_ROOT = PROJECT_ROOT / "data" / "database"
# Added CUDA fallback for consistency
DEVICE = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
# ================= LOAD MODEL =================
def load_model(adapter_path):
base_name = "Salesforce/codet5-base"
# 🐛 FIXED: Convert relative path to absolute path to prevent Hugging Face 404 errors
abs_path = (PROJECT_ROOT / adapter_path).resolve()
if not abs_path.exists():
raise FileNotFoundError(f"Adapter not found at: {abs_path}")
print(f"\nLoading model from: {abs_path}")
# 🐛 FIXED: Added fallback in case tokenizer isn't saved in the adapter folder
try:
tokenizer = AutoTokenizer.from_pretrained(str(abs_path), local_files_only=True)
except Exception:
print("Adapter tokenizer missing — using base tokenizer")
tokenizer = AutoTokenizer.from_pretrained(base_name)
base = AutoModelForSeq2SeqLM.from_pretrained(base_name).to(DEVICE)
model = PeftModel.from_pretrained(base, str(abs_path)).to(DEVICE)
model.eval()
return tokenizer, model
# ================= SCHEMA =================
def load_schema(db_id):
db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
tables = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
schema = ""
for (table,) in tables:
cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
col_names = [c[1] for c in cols]
schema += f"{table}({', '.join(col_names)})\n"
conn.close()
return schema
# ================= GENERATE =================
def generate_sql(tokenizer, model, question, db_id):
schema = load_schema(db_id)
prompt = f"""
Database Schema:
{schema}
Translate English to SQL:
{question}
SQL:
"""
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=120,
num_beams=4,
do_sample=False
)
sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "SQL:" in sql:
sql = sql.split("SQL:")[-1]
return sql.strip()
# ================= EXECUTE =================
def try_execute(sql, db_id):
db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
try:
conn = sqlite3.connect(db_path)
cur = conn.cursor()
cur.execute(sql)
cur.fetchall()
conn.close()
return True
except:
return False
# ================= MAIN =================
def main():
# paths (change if needed)
SFT_MODEL = "checkpoints/sft_adapter_codet5" # Ensure this matches your actual SFT folder name!
RLHF_MODEL = "checkpoints/best_rlhf_model"
tokenizer_sft, model_sft = load_model(SFT_MODEL)
tokenizer_rl, model_rl = load_model(RLHF_MODEL)
human_eval_path = PROJECT_ROOT / "data/human_eval.json"
with open(human_eval_path) as f:
questions = json.load(f)
sft_success = 0
rl_success = 0
print("\nRunning Human Evaluation...\n")
for i, q in enumerate(questions, 1):
db = q["db_id"]
question = q["question"]
sql_sft = generate_sql(tokenizer_sft, model_sft, question, db)
sql_rl = generate_sql(tokenizer_rl, model_rl, question, db)
ok_sft = try_execute(sql_sft, db)
ok_rl = try_execute(sql_rl, db)
if ok_sft:
sft_success += 1
if ok_rl:
rl_success += 1
print(f"\nQ{i}: {question}")
print(f"SFT : {'OK' if ok_sft else 'FAIL'}")
print(f"RLHF: {'OK' if ok_rl else 'FAIL'}")
print("\n=============================")
print("HUMAN EVALUATION RESULT")
print("=============================")
print(f"SFT Success: {sft_success}/{len(questions)} = {sft_success/len(questions)*100:.2f}%")
print(f"RLHF Success: {rl_success}/{len(questions)} = {rl_success/len(questions)*100:.2f}%")
print("=============================\n")
if __name__ == "__main__":
main() |