# app.py — SQL_chatbot_API (Hugging Face Space)
# NOTE: Hugging Face Spaces page chrome (author, commit hash, file-size text)
# was scraped into this file; it was not code and has been removed.
import warnings
warnings.filterwarnings("ignore")  # suppress ALL warnings for a clean demo log (hides real issues too)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Cap torch intra-op parallelism to one thread — keeps CPU usage predictable
# on a small shared demo host.
torch.set_num_threads(1)

# ─────────────────────
# MODEL
# ─────────────────────
# Small 1.1B chat model — chosen so the demo can run CPU-only.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="cpu",          # force CPU placement (no GPU on the demo host)
    torch_dtype=torch.float32, # full precision; CPU has no fp16 speedup
    low_cpu_mem_usage=True,    # load weights incrementally to lower peak RAM
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()  # inference mode (disables dropout etc.)
print("Model ready")
# ─────────────────────
# GENERATION
# ─────────────────────
def generate_sql(question):
    """Convert a natural-language request into a SQL query string.

    Args:
        question: Free-form user request from the UI textbox. May be
            ``None`` or whitespace-only when the user submits nothing.

    Returns:
        The model-generated SQL text, or a short prompt message when
        the input is blank.
    """
    # Gradio can hand us None as well as "" — guard both before .strip().
    if not question or not question.strip():
        return "Enter SQL question."

    # Build the prompt without the stray leading newline / indentation the
    # old triple-quoted literal carried.
    prompt = (
        "You are a SQL expert.\n"
        "Convert the user request into SQL query only.\n"
        f"User: {question}\n"
        "SQL:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding: do NOT pass temperature here — with
            # do_sample=False it is ignored and transformers warns about it.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the prompt can never leak
    # into the answer (more robust than splitting the full text on "SQL:").
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return text.strip()
# ─────────────────────
# UI
# ─────────────────────
# One-shot Gradio interface wrapping generate_sql: a question box in,
# a SQL box out, with a few clickable example prompts.
_EXAMPLES = [
    ["Find duplicate emails in users table"],
    ["Top 5 highest paid employees"],
    ["Orders per customer last month"],
]

_question_box = gr.Textbox(lines=3, label="SQL Question")
_sql_box = gr.Textbox(lines=8, label="Generated SQL")

demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_box,
    outputs=_sql_box,
    title="SQL Generator (Portfolio Demo)",
    description="Fast CPU model for portfolio demo.",
    examples=_EXAMPLES,
)

# Bind to all interfaces so the Space/container exposes the app externally.
demo.launch(server_name="0.0.0.0")