tinysql-demo / app.py
abir-hr196's picture
new version
febdf85
raw
history blame
10.5 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Checkpoints offered in the dropdown, as {display label: Hugging Face Hub repo id}.
# BM<n> is the base-model size tier and CS<n> the dataset complexity level (both
# explained in the UI's model guide). Insertion order is the dropdown order.
_HF_ORG = "withmartian"
MODELS = {
    label: f"{_HF_ORG}/{repo}"
    for label, repo in (
        ("BM1_CS1_Syn (33M)", "sql_interp_bm1_cs1_experiment_1.10"),
        ("BM1_CS2_Syn (33M)", "sql_interp_bm1_cs2_experiment_2.10"),
        ("BM1_CS3_Syn (33M)", "sql_interp_bm1_cs3_experiment_3.10"),
        ("BM1_CS4_Syn (33M)", "sql_interp_bm1_cs4_dataset_synonyms_experiment_1.1"),
        ("BM1_CS5_Syn (33M)", "sql_interp_bm1_cs5_dataset_synonyms_experiment_1.2"),
        ("BM2_CS1_Syn (0.5B)", "sql_interp_bm2_cs1_experiment_4.3"),
        ("BM2_CS2_Syn (0.5B)", "sql_interp_bm2_cs2_experiment_5.3"),
        ("BM2_CS3_Syn (0.5B)", "sql_interp_bm2_cs3_experiment_6.3"),
        ("BM3_CS1_Syn (1B)", "sql_interp_bm3_cs1_experiment_7.3"),
        ("BM3_CS2_Syn (1B)", "sql_interp_bm3_cs2_experiment_8.3"),
        ("BM3_CS3_Syn (1B)", "sql_interp_bm3_cs3_experiment_9.3"),
    )
}

# Filled lazily by load_model: {label: (tokenizer, model)}.
model_cache = {}
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for a ``MODELS`` label, loading it on first use.

    Loaded pairs are memoized in the module-level ``model_cache`` dict, keyed by
    label, so repeat calls for the same label are a dictionary lookup.

    Args:
        model_name: A key of ``MODELS`` (the dropdown label).

    Returns:
        ``(tokenizer, model)`` from ``AutoTokenizer.from_pretrained`` and
        ``AutoModelForCausalLM.from_pretrained`` for the mapped repo id.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``. Errors raised by
            ``from_pretrained`` (download/config problems) also propagate.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached

    model_id = MODELS[model_name]
    # Half precision only pays off on a GPU. On CPU-only hosts float16 inference
    # is slow and some half-precision CPU kernels are missing on older PyTorch
    # builds, so fall back to float32 when CUDA is not available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]
def _has_text(value):
    """Return True when *value* is a string containing non-whitespace characters."""
    return bool(value and value.strip())


def _build_prompt(instruction, schema):
    """Format the instruction/context/response prompt the SQL models are queried with."""
    return (
        f"### Instruction: {instruction}\n"
        f"### Context: {schema}\n"
        "### Response:"
    )


def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
    """Generate a SQL query for *instruction* over *schema* with the selected model.

    Args:
        model_name: A key of ``MODELS`` (the dropdown label).
        instruction: Natural-language request.
        schema: Table definition(s) given to the model as context.
        max_length: Maximum number of tokens to generate for the answer. The
            prompt is not counted, so a long schema cannot use up the budget.
        temperature: Sampling temperature; ``0`` (or less) selects greedy decoding.

    Returns:
        The generated SQL text. Never raises: missing/blank inputs return a
        "⚠️ ..." warning string, and any exception raised while loading or
        generating is returned as a "❌ Error: ..." string for display.
    """
    if not model_name or not _has_text(instruction) or not _has_text(schema):
        return "⚠️ Please fill in all fields and select a model"
    try:
        tokenizer, model = load_model(model_name)
        inputs = tokenizer(_build_prompt(instruction, schema), return_tensors="pt").to(model.device)

        gen_kwargs = {
            # Slider values may arrive as floats; generation lengths must be ints.
            "max_new_tokens": int(max_length),
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=float(temperature))
        else:
            # Greedy decoding: don't pass a temperature, which is ignored
            # (and warned about) when sampling is off.
            gen_kwargs["do_sample"] = False
        outputs = model.generate(**inputs, **gen_kwargs)

        # The returned sequence is prompt + continuation; decode only the
        # continuation instead of string-splitting the echoed prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure in the output box instead of crashing.
        return f"❌ Error: {e}"
# Rows for the "Try These Examples" table: [model label, question, schema].
# Every example runs against the same single-table employees schema.
_EMPLOYEES_SCHEMA = "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
examples = [
    [model_label, question, _EMPLOYEES_SCHEMA]
    for model_label, question in (
        ("BM1_CS1_Syn (33M)", "Show me the name and salary from employees"),
        ("BM2_CS2_Syn (0.5B)", "List worker earnings from highest to lowest"),
        ("BM3_CS3_Syn (1B)", "Count how many employees in each department"),
    )
]
# Stylesheet passed to gr.Blocks(css=custom_css). The class selectors
# (.header-section, .logo-container, .martian-badge, .info-box, .citation-box,
# .citation-header, .resource-links, .resource-link) style the gr.HTML fragments
# in the layout below. .gradio-container is not defined in this file -- presumably
# Gradio's own root container. NOTE(review): the bare `footer` type selector
# matches every <footer> element on the page, not only the custom footer HTML.
custom_css = """
.gradio-container {
font-family: 'Inter', sans-serif;
}
.header-section {
text-align: center;
padding: 2rem 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 12px;
margin-bottom: 2rem;
color: white;
}
.logo-container {
display: flex;
justify-content: center;
align-items: center;
gap: 1rem;
margin-bottom: 1rem;
}
.martian-badge {
background: rgba(255, 255, 255, 0.2);
padding: 0.5rem 1rem;
border-radius: 20px;
font-size: 0.9rem;
backdrop-filter: blur(10px);
}
.info-box {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 12px;
padding: 1.5rem;
margin: 1rem 0;
border-left: 4px solid #667eea;
}
.citation-box {
background: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 1.5rem;
margin: 2rem 0;
font-family: 'Monaco', 'Courier New', monospace;
font-size: 0.85rem;
}
.citation-header {
font-weight: bold;
color: #495057;
margin-bottom: 0.5rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.resource-links {
display: flex;
gap: 1rem;
justify-content: center;
margin: 1.5rem 0;
flex-wrap: wrap;
}
.resource-link {
background: white;
padding: 0.75rem 1.5rem;
border-radius: 8px;
text-decoration: none;
color: #667eea;
border: 2px solid #667eea;
font-weight: 500;
transition: all 0.3s ease;
}
.resource-link:hover {
background: #667eea;
color: white;
}
footer {
text-align: center;
padding: 2rem 0;
color: #6c757d;
border-top: 1px solid #dee2e6;
margin-top: 3rem;
}
"""
# Create Gradio interface

# Static page fragments, kept out of the layout code below so the component tree
# stays readable. Their class names are styled by custom_css.
_HEADER_HTML = """
<div class="header-section">
    <div class="logo-container">
        <h1 style="margin: 0; font-size: 2.5rem;">🔮 TinySQL Interactive Demo</h1>
    </div>
    <div class="martian-badge">
        ⚡ Powered by Martian
    </div>
    <p style="font-size: 1.1rem; margin-top: 1rem; opacity: 0.9;">
        Transform natural language into SQL queries using mechanistically interpretable models
    </p>
</div>
"""

# The model count comes from MODELS so this blurb can't drift out of sync with
# the dropdown when checkpoints are added or removed.
_INFO_HTML = f"""
<div class="info-box">
    <strong>🎯 How it works:</strong> Select a model from our collection of {len(MODELS)} fine-tuned
    transformers, describe what you want in plain English, and watch as the model generates precise
    SQL queries. The models are trained on datasets of progressively more complex SQL—from basic
    SELECT statements up to aggregations, WHERE filters and multi-table JOINs.
</div>
"""

# The blank lines are load-bearing: in Markdown a line directly after a list
# item is folded into that item, so without them "Dataset Complexity" would
# render inside the BM3 bullet instead of as its own heading.
_MODEL_GUIDE_MD = """
**Model Guide:**

- 🟢 **BM1 (33M)**: Lightning fast, great for simple queries
- 🟡 **BM2 (0.5B)**: Balanced performance and speed
- 🔴 **BM3 (1B)**: Most accurate, handles complex queries

**Dataset Complexity:**

- **CS1**: Basic SELECT-FROM queries
- **CS2**: Adds ORDER BY clauses
- **CS3**: Aggregations (COUNT, SUM, AVG)
- **CS4**: Adds WHERE filters
- **CS5**: Multi-table JOINs
"""

_RESOURCE_LINKS_HTML = """
<div class="resource-links">
    <a href="https://arxiv.org/abs/2503.12730" class="resource-link" target="_blank">
        📄 Read the Paper
    </a>
    <a href="https://github.com/withmartian/TinySQL" class="resource-link" target="_blank">
        💻 View Code
    </a>
    <a href="https://huggingface.co/collections/withmartian/tinysql-6760e92748b63fa56a6ffc9f" class="resource-link" target="_blank">
        🤗 Get Dataset & Models
    </a>
    <a href="https://withmartian.com" class="resource-link" target="_blank">
        🚀 Visit Martian
    </a>
</div>
"""

# Plain (non-f) string: the BibTeX braces must not be treated as placeholders.
# Whitespace inside <pre> is rendered verbatim.
_CITATION_HTML = """
<div class="citation-box">
    <div class="citation-header">
        📚 Citation
    </div>
    <pre style="margin: 0; overflow-x: auto;">@misc{harrasse2025tinysqlprogressivetexttosqldataset,
      title={TinySQL: A Progressive Text-to-SQL Dataset for Mechanistic Interpretability Research},
      author={Abir Harrasse and Philip Quirke and Clement Neo and Dhruv Nathawani and Luke Marks and Amir Abdullah},
      year={2025},
      eprint={2503.12730},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2503.12730}
}</pre>
</div>
"""

_FOOTER_HTML = """
<footer>
    <p style="margin: 0.5rem 0;">
        Built with ❤️ by the Martian team
    </p>
    <p style="margin: 0; font-size: 0.9rem;">
        Bridging the gap between toy tasks and real-world interpretability
    </p>
</footer>
"""

with gr.Blocks(css=custom_css, title="TinySQL Demo | Martian", theme=gr.themes.Soft()) as demo:
    # Header with Martian branding, then the "how it works" blurb.
    gr.HTML(_HEADER_HTML)
    gr.HTML(_INFO_HTML)

    with gr.Row():
        # Left column: model choice plus a legend for the BM/CS label scheme.
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Configuration")
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="BM2_CS2_Syn (0.5B)",
                label="🤖 Model Selection",
                info="Larger models = better accuracy, slower inference",
            )
            gr.Markdown(_MODEL_GUIDE_MD)

        # Right column: the request, its schema, decoding controls and the result.
        with gr.Column(scale=2):
            gr.Markdown("### 💬 Your Query")
            instruction = gr.Textbox(
                label="What do you want to know?",
                placeholder="e.g., Find all employees earning more than $50,000 sorted by name",
                lines=2,
            )
            schema = gr.Textbox(
                label="📋 Database Schema",
                placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
                lines=3,
                value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))",
            )
            # These feed generate_sql's max_length and temperature arguments.
            with gr.Row():
                max_length = gr.Slider(
                    minimum=64,
                    maximum=512,
                    value=256,
                    step=32,
                    label="Max Length",
                    info="Longer = more complex queries",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative (use 0.1 for accuracy)",
                )
            generate_btn = gr.Button("✨ Generate SQL", variant="primary", size="lg")
            output = gr.Code(
                label="🎉 Generated SQL Query",
                language="sql",
                lines=8,
            )

    # Clicking an example fills the model/question/schema inputs; sliders are untouched.
    gr.Markdown("### 💡 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=[model_dropdown, instruction, schema],
    )

    gr.HTML(_RESOURCE_LINKS_HTML)
    gr.HTML(_CITATION_HTML)
    gr.HTML(_FOOTER_HTML)

    generate_btn.click(
        fn=generate_sql,
        inputs=[model_dropdown, instruction, schema, max_length, temperature],
        outputs=output,
    )
if __name__ == "__main__":
    # Serve the Blocks app only when this file is executed directly
    # (e.g. `python app.py`); importing the module just builds `demo`.
    demo.launch()