SQL_chatbot_API / app.py
saadkhi's picture
Update app.py
17bc164 verified
raw
history blame
2.04 kB
import warnings
warnings.filterwarnings("ignore")
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# Cap PyTorch at a single intra-op CPU thread for this process.
torch.set_num_threads(1)
app = FastAPI()
# Hugging Face model id used for both the tokenizer and the causal LM below.
BASE_MODEL = "distilgpt2"
print("Loading model...")
# Loaded once at import time; the request handler reuses these module-level
# objects instead of reloading them per call.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
# Put the model in evaluation mode; this file only runs generation.
model.eval()
print("Model ready")
# ─────────────────────────
# Request schema
# ─────────────────────────
class Query(BaseModel):
    """Request body accepted by the ``/generate-sql`` endpoint."""

    # Free-text question from the caller; the handler rejects it when blank.
    question: str
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Lower-case fragments whose presence marks a question as SQL-related.
SQL_KEYWORDS = [
    "sql",
    "database",
    "table",
    "select",
    "insert",
    "update",
    "delete",
    "join",
    "group by",
    "postgres",
    "mysql",
    "sqlite",
    "query",
]


def is_sql_related(text):
    """Return True when *text* contains any entry of ``SQL_KEYWORDS``.

    The comparison is case-insensitive and substring-based, so a keyword
    embedded in a longer word (e.g. "table" inside "vegetable") also counts.
    """
    lowered = text.lower()
    for keyword in SQL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
# ─────────────────────────
# Endpoint
# ─────────────────────────
@app.post("/generate-sql")
def generate_sql(data: Query):
    """Generate a one-line SQL statement for a natural-language question.

    Args:
        data: Request body carrying the user's question.

    Returns:
        ``{"sql": <text>}`` with the first non-empty line the model produced,
        or ``{"error": <message>}`` when the question is blank, is not
        SQL-related, or does not fit in the model's context window.
    """
    max_new_tokens = 80

    user_input = data.question
    if not user_input.strip():
        return {"error": "Empty input"}
    if not is_sql_related(user_input):
        return {"error": "Only SQL-related queries allowed"}

    prompt = (
        "\n"
        "You are an expert SQL generator.\n"
        "Only output SQL query.\n"
        f"User: {user_input}\n"
        "SQL:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # Reject prompts that cannot fit together with the requested completion;
    # otherwise generation fails inside the model on over-long questions.
    context_limit = getattr(model.config, "max_position_embeddings", None)
    if context_limit is not None and prompt_len + max_new_tokens > context_limit:
        return {"error": "Input too long"}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated, so a "SQL:" marker echoed by the user or repeated by the
    # model cannot shift which line is picked as the answer.
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    result = completion.strip().split("\n")[0].strip()
    return {"sql": result}