# SQL_chatbot_API / app.py
# Hugging Face file view - saadkhi: "Update app.py" (cc1250f, verified), 2.84 kB
import warnings
warnings.filterwarnings("ignore")
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Limit PyTorch to a single thread to reduce CPU pressure.
torch.set_num_threads(1)

# Use a lightweight model (IMPORTANT).
BASE_MODEL = "distilgpt2"


def _load_model(checkpoint):
    """Load the tokenizer and causal LM for *checkpoint* and return both.

    Progress is reported on stdout before and after loading. The model is
    switched to evaluation mode before it is returned.
    """
    print("Loading model...")
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    loaded_model.eval()
    print("Model ready")
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(BASE_MODEL)
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Lower-case keywords whose presence marks a question as SQL/database related.
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query",
]


def is_sql_related(text):
    """Return True when *text* mentions at least one entry of SQL_KEYWORDS.

    Matching is case-insensitive and anchored at the start of a word, so
    inflected forms such as "tables", "updated" or "PostgreSQL" still
    count, while unrelated words that merely contain a keyword
    ("comfortable", "stable", "adjoining") do not.

    Args:
        text: The user's question. ``None``, non-string and empty values
            are treated as not SQL-related.

    Returns:
        bool: True if a keyword is found, False otherwise.
    """
    # Imported locally so this block does not depend on the file's import
    # section being edited.
    import re

    if not isinstance(text, str) or not text:
        return False

    # NOTE: only the left edge is anchored. Compounds that embed a keyword
    # mid-word (e.g. "subquery") need their own entry in SQL_KEYWORDS.
    pattern = r"\b(?:" + "|".join(re.escape(k) for k in SQL_KEYWORDS) + ")"
    return re.search(pattern, text.lower()) is not None
# ─────────────────────────
# PROMPT
# ─────────────────────────
# Instruction preamble that generate_sql places ahead of the user's question.
# The text begins and ends with a newline.
SYSTEM_PROMPT = (
    "\n"
    "You are an expert SQL generator.\n"
    "Rules:\n"
    "- Only respond to SQL or database related questions.\n"
    "- Output ONLY SQL query.\n"
    "- No explanation.\n"
)
# ─────────────────────────
# GENERATION
# ─────────────────────────
def generate_sql(user_input):
    """Generate a single-line SQL answer for a natural-language question.

    Args:
        user_input: The question entered by the user. ``None`` and
            whitespace-only values are treated as empty input.

    Returns:
        str: The first line of the model's completion, or a short guidance
        message when the input is empty or not SQL-related.
    """
    # Guard against None as well as blank strings so a missing value does
    # not raise AttributeError on .strip().
    if not user_input or not user_input.strip():
        return "Enter SQL question."

    if not is_sql_related(user_input):
        return "Only SQL/database questions are allowed."

    prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=80,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt. Splitting the full
    # text on "SQL:" and keeping the last piece would return the wrong part
    # whenever the model itself emits another "SQL:" marker.
    completion = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Keep just the first line of the completion.
    return completion.split("\n")[0]
# ─────────────────────────
# UI
# ─────────────────────────
# Single-textbox web UI around generate_sql.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
    ),
    outputs=gr.Textbox(
        lines=6,
        label="Generated SQL",
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="Only SQL/database queries are supported.",
    # generate_sql rejects any question that contains none of SQL_KEYWORDS,
    # so every SQL example below includes at least one of them; otherwise
    # clicking it would only show the rejection message.
    examples=[
        ["Find duplicate emails in users table"],
        ["SQL query for the top 5 highest paid employees"],
        ["SQL query to count orders per customer last month"],
        # Off-topic example: contains no SQL keyword, so it shows the
        # "Only SQL/database questions are allowed." response.
        ["Write a joke about cats"],
    ],
)

# Listen on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)