# app.py — created by Prateek-Kacham (commit 04ce8b2, verified).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Model configuration and one-time loading.
# A lightweight Text-to-SQL model is used so that inference can run on CPU.
MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
print("Loading model...")
# Both the tokenizer and the seq2seq model are resolved from MODEL_NAME.
# They are loaded once at module level and reused by generate_sql() on every call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Put the model in evaluation mode (turns off training-only layers such as dropout).
model.eval()
print("Model ready.")
def generate_sql(schema: str, question: str) -> str:
    """Generate a SQL query for ``question`` against the given ``schema``.

    Args:
        schema: Database schema text (DDL such as ``CREATE TABLE ...``).
        question: Natural-language question to translate into SQL.

    Returns:
        The decoded model output, or a warning message when either input is
        missing, empty, or whitespace-only.
    """
    # A programmatic caller may pass None instead of an empty string; treat it
    # as empty so the user gets the friendly warning rather than an
    # AttributeError from ``None.strip()``.
    schema = (schema or "").strip()
    question = (question or "").strip()
    if not schema or not question:
        return "⚠️ Please provide both a schema and a question."

    # Prompt layout: schema under "tables:", then the question after "query for:".
    input_text = f"tables:\n{schema}\nquery for: {question}"

    # Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=4,
            early_stopping=True,
        )

    # Only the first returned sequence is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Sample inputs offered in the UI. Each entry is a [schema, question] pair,
# in the same order as the two input textboxes.
EXAMPLES = [
    [
        "CREATE TABLE employees (id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));",
        "Find all employees in the engineering department with salary above 80000.",
    ],
    [
        "CREATE TABLE orders (order_id INT, customer_id INT, product VARCHAR(100), amount DECIMAL(10,2), order_date DATE);",
        "What is the total revenue per product?",
    ],
    [
        "CREATE TABLE students (id INT, name VARCHAR(100), gpa FLOAT, major VARCHAR(50));",
        "List all students with GPA above 3.5 ordered by GPA descending.",
    ],
]
# Markdown rendered above the input form.
_HEADER_MD = """
# 🛢️ Text-to-SQL Demo
**Fine-tuned Mistral 7B (QLoRA) — 200% improvement in exact match over base model.**
This demo uses a lightweight SQL model for live inference.
The full Mistral 7B fine-tuned adapter is available on the [model card](https://huggingface.co/Prateek-Kacham/mistral-7b-text2sql-qlora).
Enter a database schema and a natural language question to generate SQL.
"""

# Markdown rendered below the examples (model/training details and links).
_FOOTER_MD = """
---
**Full Model:** Mistral-7B-Instruct-v0.3 + QLoRA adapter (r=16, α=32) |
**Trained on:** Gretel Synthetic Text-to-SQL (4,750 samples, ~61 min on A100) |
[GitHub](https://github.com/PrateekKacham/mistral-7b-text2sql-finetuning) |
[Model Card](https://huggingface.co/Prateek-Kacham/mistral-7b-text2sql-qlora)
"""

# UI definition: inputs on the left, generated SQL on the right.
# NOTE(review): the nesting below assumes a two-column layout inside a single
# row, with the event wiring inside the Blocks context — confirm against the
# deployed app.
with gr.Blocks(title="Text-to-SQL Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Input column: schema, question and the trigger button.
        with gr.Column():
            schema_box = gr.Textbox(
                label="Database Schema (DDL)",
                placeholder="CREATE TABLE employees (\n id INT,\n name VARCHAR(100),\n salary DECIMAL\n);",
                lines=8,
            )
            question_box = gr.Textbox(
                label="Natural Language Question",
                placeholder="Find all employees earning more than $80,000.",
                lines=3,
            )
            run_btn = gr.Button("Generate SQL ⚡", variant="primary")

        # Output column: read-only box holding the generated query.
        with gr.Column():
            output_box = gr.Textbox(label="Generated SQL", lines=10, interactive=False)

    # The same two inputs feed the button, the textbox submit event and the examples.
    query_inputs = [schema_box, question_box]

    # Run generation on button click or when the question box is submitted.
    run_btn.click(fn=generate_sql, inputs=query_inputs, outputs=output_box)
    question_box.submit(fn=generate_sql, inputs=query_inputs, outputs=output_box)

    # Clickable samples that populate the schema and question boxes.
    gr.Examples(examples=EXAMPLES, inputs=query_inputs)

    gr.Markdown(_FOOTER_MD)

# Start the Gradio server.
demo.launch()