import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face Hub identifier of the seq2seq checkpoint used for live inference.
MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"

# Load the tokenizer and model once at import time so every request handled by
# generate_sql() reuses the same in-memory objects.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()  # inference only: put layers such as dropout into eval behaviour
print("Model ready.")


def generate_sql(schema: str, question: str) -> str:
    """Generate a SQL query answering *question* for the given *schema*.

    Args:
        schema: Database schema text (e.g. ``CREATE TABLE`` statements).
        question: Natural-language question about that schema.

    Returns:
        The decoded model output with special tokens removed, or a warning
        message when either input is missing or contains only whitespace.
    """
    # Normalize once; treating None as empty keeps a missing value from
    # raising AttributeError on .strip().
    schema_text = (schema or "").strip()
    question_text = (question or "").strip()
    if not schema_text or not question_text:
        return "⚠️ Please provide both a schema and a question."

    # Prompt layout fed to the model: schema first, then the question.
    input_text = f"tables:\n{schema_text}\nquery for: {question_text}"
    # Inputs longer than 512 tokens are truncated rather than rejected.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # No gradients are needed for generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=4,
            early_stopping=True,
        )

    # Only the first returned sequence is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Sample inputs shown in the UI. Each entry is a [schema, question] pair, in the
# same order as the inputs of generate_sql().
EXAMPLES = [
    [
        "CREATE TABLE employees (id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));",
        "Find all employees in the engineering department with salary above 80000.",
    ],
    [
        "CREATE TABLE orders (order_id INT, customer_id INT, product VARCHAR(100), amount DECIMAL(10,2), order_date DATE);",
        "What is the total revenue per product?",
    ],
    [
        "CREATE TABLE students (id INT, name VARCHAR(100), gpa FLOAT, major VARCHAR(50));",
        "List all students with GPA above 3.5 ordered by GPA descending.",
    ],
]


# Assemble the Gradio UI: a header, a two-column input/output layout, the
# example picker and a footer. Components and event listeners are created
# inside the Blocks context so they are registered on `demo`.
with gr.Blocks(title="Text-to-SQL Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🛢️ Text-to-SQL Demo
        **Fine-tuned Mistral 7B (QLoRA) — 200% improvement in exact match over base model.**

        This demo uses a lightweight SQL model for live inference.
        The full Mistral 7B fine-tuned adapter is available on the [model card](https://huggingface.co/Prateek-Kacham/mistral-7b-text2sql-qlora).

        Enter a database schema and a natural language question to generate SQL.
        """
    )

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            schema_box = gr.Textbox(
                label="Database Schema (DDL)",
                placeholder="CREATE TABLE employees (\n id INT,\n name VARCHAR(100),\n salary DECIMAL\n);",
                lines=8,
            )
            question_box = gr.Textbox(
                label="Natural Language Question",
                placeholder="Find all employees earning more than $80,000.",
                lines=3,
            )
            run_btn = gr.Button("Generate SQL ⚡", variant="primary")

        # Right column: read-only box that receives the generated query.
        with gr.Column():
            output_box = gr.Textbox(label="Generated SQL", lines=10, interactive=False)

    # The button click and the question box's submit event run the same
    # handler with the same inputs/outputs, so the wiring is shared.
    generate_event = dict(
        fn=generate_sql,
        inputs=[schema_box, question_box],
        outputs=output_box,
    )
    run_btn.click(**generate_event)
    question_box.submit(**generate_event)

    # Selecting an example fills the two input boxes.
    gr.Examples(examples=EXAMPLES, inputs=[schema_box, question_box])

    gr.Markdown(
        """
        ---
        **Full Model:** Mistral-7B-Instruct-v0.3 + QLoRA adapter (r=16, α=32) |
        **Trained on:** Gretel Synthetic Text-to-SQL (4,750 samples, ~61 min on A100) |
        [GitHub](https://github.com/PrateekKacham/mistral-7b-text2sql-finetuning) |
        [Model Card](https://huggingface.co/Prateek-Kacham/mistral-7b-text2sql-qlora)
        """
    )


# Start the Gradio server for the interface built above.
demo.launch()