import torch
import sqlite3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --------------------------------------------------
# PATH
# --------------------------------------------------
# Directory holding the fine-tuned seq2seq checkpoint (tokenizer + weights).
MODEL_PATH = "outputs/model"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("Loading fine-tuned model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
# Inference only: disable dropout / switch batch-norm-style layers to eval mode.
model.eval()
# --------------------------------------------------
# CONNECT DATABASE
# --------------------------------------------------
print("Connecting to database...")
# NOTE(review): path is relative to the current working directory — the script
# apparently must be launched from the repo root; confirm before deploying.
conn = sqlite3.connect("data/database/department_management/department_management.sqlite")
cursor = conn.cursor()
print("Database connected ✔")
# --------------------------------------------------
# BUILD PROMPT
# --------------------------------------------------
def build_prompt(question):
    """Build the text-to-SQL prompt for *question*.

    The database schema is embedded verbatim so the model always sees the
    available tables and columns alongside the user's question.
    """
    table_info = """
Table department columns = Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees.
Table head columns = head_ID, name, born_state, age.
Table management columns = department_ID, head_ID, temporary_acting.
"""
    prefix = "translate English to SQL: "
    return prefix + table_info + " question: " + question
# --------------------------------------------------
# GENERATE SQL
# --------------------------------------------------
def generate_sql(question):
    """Translate a natural-language *question* into a SQL string.

    Tokenizes the schema-augmented prompt, runs beam-search generation on the
    module-level fine-tuned model, and returns the decoded SQL text.
    """
    batch = tokenizer(
        build_prompt(question),
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # No gradients needed at inference time.
    with torch.no_grad():
        generated = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=256,
            num_beams=5,
            early_stopping=True,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()
# --------------------------------------------------
# EVALUATE SQL (REWARD FUNCTION)
# --------------------------------------------------
def evaluate_sql(sql, cur=None):
    """Execute *sql* and score the outcome as a scalar reward.

    Parameters
    ----------
    sql : str
        SQL statement to run (typically produced by the model).
    cur : sqlite3.Cursor, optional
        Cursor to execute against. Defaults to the module-level ``cursor``
        for backward compatibility; passing one explicitly makes the
        function reusable and unit-testable.

    Returns
    -------
    tuple
        ``(1.0, rows)`` when the query returns rows,
        ``(-0.2, [])`` when it executes but returns nothing, and
        ``(-1.0, error_message)`` when execution fails.
    """
    if cur is None:
        cur = cursor  # fall back to the script-level DB connection
    try:
        cur.execute(sql)
        rows = cur.fetchall()
    except Exception as e:
        # Deliberately broad: the model can emit arbitrary text, and the
        # reward loop must never crash — any failure is the worst reward.
        return -1.0, str(e)
    # Executed but produced no useful result -> mild negative reward.
    if not rows:
        return -0.2, rows
    return 1.0, rows
# --------------------------------------------------
# INTERACTIVE LOOP
# --------------------------------------------------
# Interactive REPL: question -> predicted SQL -> execution-based reward.
while (question := input("\nAsk question (type exit to quit): ")).lower() != "exit":
    predicted = generate_sql(question)
    print("\nPredicted SQL:")
    print(predicted)
    # Run the SQL against the live database and score the result.
    reward, result = evaluate_sql(predicted)
    print("\nReward:", reward)
    if reward == -1.0:
        # Query failed to execute; result holds the error message.
        print("SQL Error:", result)
    elif reward == -0.2:
        print("No results found")
    else:
        print("\nAnswer:")
        for row in result:
            print(row)