SQL_chatbot_API / app.py
saadkhi's picture
Update app.py
5d261f7 verified
raw
history blame
2.6 kB
import warnings
warnings.filterwarnings("ignore")
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import threading
# Restrict PyTorch's intra-op CPU parallelism to a single thread
# (presumably to keep CPU usage predictable on a shared host — confirm).
torch.set_num_threads(1)

# Hugging Face Hub id of the causal LM used for every request.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

app = FastAPI()

# Load tokenizer and weights once, at import time; generate_sql() reuses these
# module-level objects for every request.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# nn.Module.eval() switches to inference mode and returns the module itself.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL).eval()
print("Model ready")
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Words that mark a question as SQL/database related (see is_sql_related).
# Compound product names must be listed explicitly (e.g. "mysql"), because
# "sql" only matches where it starts a word.
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query",
]


def is_sql_related(text):
    """Return True if *text* mentions any keyword from ``SQL_KEYWORDS``.

    Matching is case-insensitive and a keyword only counts when it starts a
    word, so "tables", "JOINed" or "PostgreSQL" match while "stable",
    "portable" or "vegetable" do not (a plain substring test let those
    through the filter). Multi-word keywords such as "group by" accept any
    run of whitespace between their words.

    Args:
        text: The user's question.

    Returns:
        True if at least one keyword matches, otherwise False.
    """
    import re  # local import: only this filter needs regular expressions

    for keyword in SQL_KEYWORDS:
        # \b anchors the keyword at the start of a word; trailing suffixes
        # ("tables", "queries" excluded since "query" != "quer") are allowed.
        pattern = r"\b" + r"\s+".join(map(re.escape, keyword.split()))
        if re.search(pattern, text, re.IGNORECASE):
            return True
    return False
SYSTEM_PROMPT = """
You are an expert SQL generator.
Only output SQL query.
"""


def _extract_sql(generated: str) -> str:
    """Return the first SQL statement from the model's raw continuation text."""
    # Drop Markdown fence lines (```sql / ```), keeping the code between them.
    lines = [ln for ln in generated.splitlines() if not ln.strip().startswith("```")]
    text = "\n".join(lines).strip()
    # The prompt has no end-of-turn marker, so the model may go on to invent
    # a new "User:" turn; nothing from there on belongs to the answer.
    text = text.split("\nUser:", 1)[0]
    # A blank line normally separates the query from any prose explanation.
    text = text.split("\n\n", 1)[0]
    # A terminated statement is kept whole, even when it spans several lines.
    statement, terminator, _ = text.partition(";")
    if terminator:
        return (statement + terminator).strip()
    # Unterminated output: fall back to the first line only.
    return text.split("\n", 1)[0].strip()


def generate_sql(user_input: str) -> str:
    """Translate a natural-language database question into SQL.

    Args:
        user_input: The user's question.

    Returns:
        The extracted SQL text (possibly empty if the model produced nothing
        usable), "Enter SQL question." for blank input, or
        "Only SQL/database questions allowed." when ``is_sql_related`` rejects
        the input.

    NOTE(review): TinyLlama-Chat models are usually prompted through
    ``tokenizer.apply_chat_template``; the plain "User:/SQL:" prompt below
    may under-perform it — worth evaluating.
    """
    if not user_input.strip():
        return "Enter SQL question."
    if not is_sql_related(user_input):
        return "Only SQL/database questions allowed."

    prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding. No `temperature`: it is ignored when
            # do_sample=False and transformers warns about the combination.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM, generate() returns prompt + continuation. Decode only
    # the new tokens so the echoed prompt (which the user may have filled with
    # "SQL:" too) cannot leak into or displace the answer.
    generated = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return _extract_sql(generated)
# ─────────────────────────
# FastAPI Routes
# ─────────────────────────
class Query(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    # The natural-language question to translate into SQL.
    text: str


@app.get("/")
def root():
    """Liveness endpoint: always answers with a fixed status payload."""
    return dict(status="API running")


@app.post("/generate")
def generate(query: Query):
    """Run ``generate_sql`` on the request text and wrap it as ``{"result": ...}``.

    Declared with plain ``def``, so FastAPI runs it in its threadpool rather
    than on the event loop — generate_sql is a blocking model call.
    """
    sql_text = generate_sql(query.text)
    return {"result": sql_text}
# ─────────────────────────
# Gradio UI (for testing)
# ─────────────────────────
def launch_gradio():
    """Build and serve a minimal Gradio UI around ``generate_sql`` for manual testing.

    The UI runs on its own server at 0.0.0.0:7861. With default arguments
    ``demo.launch()`` blocks the calling thread for as long as that server runs.
    """
    demo = gr.Interface(
        fn=generate_sql,
        inputs=gr.Textbox(lines=3, label="SQL Question"),
        outputs=gr.Textbox(lines=6, label="Generated SQL"),
        title="SQL Generator Test UI",
    )
    demo.launch(server_name="0.0.0.0", server_port=7861)


if __name__ == "__main__":
    # Run directly (`python app.py`): nothing else keeps the process alive,
    # so serve the UI on the main thread.
    launch_gradio()
else:
    # Imported by an ASGI server (e.g. `uvicorn app:app`): serve the UI from a
    # daemon thread. A non-daemon thread parked in demo.launch() would keep the
    # interpreter alive after the API server shuts down or reloads.
    threading.Thread(target=launch_gradio, name="gradio-ui", daemon=True).start()