Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Dropdown label -> Hugging Face Hub repository id for each TinySQL checkpoint.
# Label format: BM<base model>_CS<dataset complexity>_Syn (<parameter count>).
# Dict order is preserved, and the dropdown lists labels in this order.
_HUB_ORG = "withmartian"

MODELS = {
    label: f"{_HUB_ORG}/{repo}"
    for label, repo in (
        ("BM1_CS1_Syn (33M)", "sql_interp_bm1_cs1_experiment_1.10"),
        ("BM1_CS2_Syn (33M)", "sql_interp_bm1_cs2_experiment_2.10"),
        ("BM1_CS3_Syn (33M)", "sql_interp_bm1_cs3_experiment_3.10"),
        ("BM1_CS4_Syn (33M)", "sql_interp_bm1_cs4_dataset_synonyms_experiment_1.1"),
        ("BM1_CS5_Syn (33M)", "sql_interp_bm1_cs5_dataset_synonyms_experiment_1.2"),
        ("BM2_CS1_Syn (0.5B)", "sql_interp_bm2_cs1_experiment_4.3"),
        ("BM2_CS2_Syn (0.5B)", "sql_interp_bm2_cs2_experiment_5.3"),
        ("BM2_CS3_Syn (0.5B)", "sql_interp_bm2_cs3_experiment_6.3"),
        ("BM3_CS1_Syn (1B)", "sql_interp_bm3_cs1_experiment_7.3"),
        ("BM3_CS2_Syn (1B)", "sql_interp_bm3_cs2_experiment_8.3"),
        ("BM3_CS3_Syn (1B)", "sql_interp_bm3_cs3_experiment_9.3"),
    )
}
# Loaded (tokenizer, model) pairs keyed by MODELS display label. Insertion order
# doubles as recency order: a cache hit is moved to the end, and eviction
# removes entries from the front.
model_cache = {}

# Maximum number of checkpoints kept resident at once. Without a bound, every
# model a visitor tries (up to 1B parameters each) would stay loaded for the
# lifetime of the process.
MAX_CACHED_MODELS = 2


def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for a MODELS label, loading it on first use.

    Loaded pairs are cached in ``model_cache``. Once more than
    ``MAX_CACHED_MODELS`` pairs are held, the least recently used ones are
    dropped; the pair just loaded is never evicted.

    Args:
        model_name: Display label; must be a key of ``MODELS``.

    Returns:
        Tuple of the Hugging Face tokenizer and the causal-LM model.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    if model_name in model_cache:
        # Re-insert so this entry becomes the most recently used.
        entry = model_cache.pop(model_name)
        model_cache[model_name] = entry
        return entry

    model_id = MODELS[model_name]
    # Use float16 only on GPU: some PyTorch builds lack float16 matmul kernels
    # on CPU, which makes generation fail on CPU-only hardware.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    entry = (tokenizer, model)
    model_cache[model_name] = entry

    # Evict least-recently-used entries from the front of the dict. The bound is
    # at least 1, so the entry just appended at the end always survives.
    evicted = False
    while len(model_cache) > max(1, MAX_CACHED_MODELS):
        del model_cache[next(iter(model_cache))]
        evicted = True
    if evicted and torch.cuda.is_available():
        # Hand cached-but-unused GPU memory back after dropping a model.
        torch.cuda.empty_cache()
    return entry
def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
    """Generate a SQL query for a natural-language request with a TinySQL model.

    Args:
        model_name: Display label of the checkpoint (a key of ``MODELS``).
        instruction: The user's request in plain English.
        schema: Table definition(s) given to the model as context.
        max_length: Maximum number of tokens to generate after the prompt.
            Coerced to ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature; values ``<= 0`` select greedy decoding.

    Returns:
        The generated SQL text. If an input is missing or blank, returns a
        "⚠️" hint. Any failure while loading or generating yields a
        "❌ Error: ..." string, because the result is shown directly in the UI
        rather than raised.
    """
    if not model_name or not (instruction or "").strip() or not (schema or "").strip():
        return "⚠️ Please fill in all fields and select a model"
    try:
        tokenizer, model = load_model(model_name)
        prompt = (
            f"### Instruction: {instruction}\n"
            f"### Context: {schema}\n"
            "### Response:"
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        do_sample = temperature > 0
        gen_kwargs = {
            # max_new_tokens bounds only the completion. max_length would also
            # count the prompt, so a long schema could leave no room for SQL.
            "max_new_tokens": int(max_length),
            "do_sample": do_sample,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            # Temperature has no effect under greedy decoding (transformers
            # warns about it), so only pass it when sampling.
            gen_kwargs["temperature"] = float(temperature)
        outputs = model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, so marker text inside the
        # user's instruction or schema cannot confuse answer extraction.
        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # If the model runs on into another "### ..." section, keep only the answer.
        return completion.split("###", 1)[0].strip()
    except Exception as e:  # UI boundary: show the message instead of failing the request.
        return f"❌ Error: {e}"
# Rows for gr.Examples, ordered like its inputs:
# [model label, natural-language request, schema DDL].
# Every example runs against the same single-table employees schema.
_EXAMPLE_SCHEMA = "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"

examples = [
    [label, request, _EXAMPLE_SCHEMA]
    for label, request in (
        ("BM1_CS1_Syn (33M)", "Show me the name and salary from employees"),
        ("BM2_CS2_Syn (0.5B)", "List worker earnings from highest to lowest"),
        ("BM3_CS3_Syn (1B)", "Count how many employees in each department"),
    )
]
# Stylesheet passed to gr.Blocks(css=...); the class names match the markup in
# the gr.HTML components of the layout.
# The light-background panels (.info-box, .citation-box) pin a dark text colour
# so they stay legible when the theme's inherited body text is light (e.g. in
# dark mode). The badge is inline-block so the rounded pill hugs its text
# instead of stretching across the whole header.
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.header-section {
    text-align: center;
    padding: 2rem 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 12px;
    margin-bottom: 2rem;
    color: white;
}
.logo-container {
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 1rem;
    margin-bottom: 1rem;
}
.martian-badge {
    display: inline-block;
    background: rgba(255, 255, 255, 0.2);
    padding: 0.5rem 1rem;
    border-radius: 20px;
    font-size: 0.9rem;
    backdrop-filter: blur(10px);
}
.info-box {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    color: #1f2937;
    border-radius: 12px;
    padding: 1.5rem;
    margin: 1rem 0;
    border-left: 4px solid #667eea;
}
.citation-box {
    background: #f8f9fa;
    color: #212529;
    border: 1px solid #dee2e6;
    border-radius: 8px;
    padding: 1.5rem;
    margin: 2rem 0;
    font-family: 'Monaco', 'Courier New', monospace;
    font-size: 0.85rem;
}
.citation-header {
    font-weight: bold;
    color: #495057;
    margin-bottom: 0.5rem;
    display: flex;
    align-items: center;
    gap: 0.5rem;
}
.resource-links {
    display: flex;
    gap: 1rem;
    justify-content: center;
    margin: 1.5rem 0;
    flex-wrap: wrap;
}
.resource-link {
    background: white;
    padding: 0.75rem 1.5rem;
    border-radius: 8px;
    text-decoration: none;
    color: #667eea;
    border: 2px solid #667eea;
    font-weight: 500;
    transition: all 0.3s ease;
}
.resource-link:hover {
    background: #667eea;
    color: white;
}
footer {
    text-align: center;
    padding: 2rem 0;
    color: #6c757d;
    border-top: 1px solid #dee2e6;
    margin-top: 3rem;
}
"""
# Build the Gradio interface. Gradio Blocks places components in the order they
# are created inside this `with` context, so the statement order (and nesting)
# below is the page layout. `demo` is the app object launched at the bottom of
# the file.
# NOTE(review): nesting below was reconstructed from an unindented copy of this
# file; confirm the button/output/examples placement against the deployed app.
with gr.Blocks(css=custom_css, title="TinySQL Demo | Martian", theme=gr.themes.Soft()) as demo:
    # Header with Martian branding; the class names refer to rules in custom_css.
    gr.HTML("""
<div class="header-section">
<div class="logo-container">
<h1 style="margin: 0; font-size: 2.5rem;">🔮 TinySQL Interactive Demo</h1>
</div>
<div class="martian-badge">
⚡ Powered by Martian
</div>
<p style="font-size: 1.1rem; margin-top: 1rem; opacity: 0.9;">
Transform natural language into SQL queries using mechanistically interpretable models
</p>
</div>
""")
    # Info box explaining the workflow.
    # NOTE(review): the model count "11" is hard-coded; keep it in sync with MODELS.
    gr.HTML("""
<div class="info-box">
<strong>🎯 How it works:</strong> Select a model from our collection of 11 fine-tuned transformers,
describe what you want in plain English, and watch as the model generates precise SQL queries.
Each model is trained on progressively complex SQL operations—from basic SELECT statements to
advanced JOINs and aggregations.
</div>
""")
    with gr.Row():
        # Left column: model picker plus a static guide to the BM*/CS* naming.
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Configuration")
            # Choices are the MODELS display labels; the selected label is what
            # generate_sql receives as model_name.
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="BM2_CS2_Syn (0.5B)",
                label="🤖 Model Selection",
                info="Larger models = better accuracy, slower inference"
            )
            # NOTE(review): Markdown needs a blank line before "**Dataset Complexity:**";
            # without one it is folded into the preceding bullet — check rendering.
            gr.Markdown("""
**Model Guide:**
- 🟢 **BM1 (33M)**: Lightning fast, great for simple queries
- 🟡 **BM2 (0.5B)**: Balanced performance and speed
- 🔴 **BM3 (1B)**: Most accurate, handles complex queries
**Dataset Complexity:**
- **CS1**: Basic SELECT-FROM queries
- **CS2**: Adds ORDER BY clauses
- **CS3**: Aggregations (COUNT, SUM, AVG)
- **CS4**: Adds WHERE filters
- **CS5**: Multi-table JOINs
""")
        # Right column: the user's request, the schema context, generation
        # settings, the trigger button and the SQL output.
        with gr.Column(scale=2):
            gr.Markdown("### 💬 Your Query")
            instruction = gr.Textbox(
                label="What do you want to know?",
                placeholder="e.g., Find all employees earning more than $50,000 sorted by name",
                lines=2
            )
            # Pre-filled with the same employees DDL the examples use.
            schema = gr.Textbox(
                label="📋 Database Schema",
                placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
                lines=3,
                value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
            )
            with gr.Row():
                # Forwarded to generate_sql as its max_length argument.
                max_length = gr.Slider(
                    minimum=64,
                    maximum=512,
                    value=256,
                    step=32,
                    label="Max Length",
                    info="Longer = more complex queries"
                )
                # Forwarded to generate_sql as temperature; 0.0 means greedy
                # decoding there (it samples only when temperature > 0).
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative (use 0.1 for accuracy)"
                )
            generate_btn = gr.Button("✨ Generate SQL", variant="primary", size="lg")
            # Displays generate_sql's return value (SQL, or a warning/error string).
            output = gr.Code(
                label="🎉 Generated SQL Query",
                language="sql",
                lines=8,
            )
    # Clickable examples. No fn/outputs are given, so clicking one only fills
    # the three inputs; the user still presses the button to generate.
    gr.Markdown("### 💡 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=[model_dropdown, instruction, schema],
    )
    # Resource links (paper, code, Hub collection, company site), opened in a new tab.
    gr.HTML("""
<div class="resource-links">
<a href="https://arxiv.org/abs/2503.12730" class="resource-link" target="_blank">
📄 Read the Paper
</a>
<a href="https://github.com/withmartian/TinySQL" class="resource-link" target="_blank">
💻 View Code
</a>
<a href="https://huggingface.co/collections/withmartian/tinysql-6760e92748b63fa56a6ffc9f" class="resource-link" target="_blank">
🤗 Get Dataset & Models
</a>
<a href="https://withmartian.com" class="resource-link" target="_blank">
🚀 Visit Martian
</a>
</div>
""")
    # BibTeX citation for the TinySQL paper, shown in a <pre> block.
    gr.HTML("""
<div class="citation-box">
<div class="citation-header">
📚 Citation
</div>
<pre style="margin: 0; overflow-x: auto;">@misc{harrasse2025tinysqlprogressivetexttosqldataset,
title={TinySQL: A Progressive Text-to-SQL Dataset for Mechanistic Interpretability Research},
author={Abir Harrasse and Philip Quirke and Clement Neo and Dhruv Nathawani and Luke Marks and Amir Abdullah},
year={2025},
eprint={2503.12730},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2503.12730}
}</pre>
</div>
""")
    # Footer; note the `footer` rule in custom_css matches any <footer> element on the page.
    gr.HTML("""
<footer>
<p style="margin: 0.5rem 0;">
Built with ❤️ by the Martian team
</p>
<p style="margin: 0; font-size: 0.9rem;">
Bridging the gap between toy tasks and real-world interpretability
</p>
</footer>
""")
    # Wire the button: the inputs list order matches generate_sql's positional
    # parameters (model_name, instruction, schema, max_length, temperature).
    generate_btn.click(
        fn=generate_sql,
        inputs=[model_dropdown, instruction, schema, max_length, temperature],
        outputs=output
    )
def main():
    """Start the Gradio server for the ``demo`` app built above."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()