import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Display name -> Hugging Face model id for each interpretability checkpoint.
# BM1/BM2/BM3 = base-model size tier (33M / 0.5B / 1B); CS1..CS5 = dataset
# complexity split the checkpoint was trained on (see the in-app Model Guide).
MODELS = {
    "BM1_CS1_Syn (33M)": "withmartian/sql_interp_bm1_cs1_experiment_1.10",
    "BM1_CS2_Syn (33M)": "withmartian/sql_interp_bm1_cs2_experiment_2.10",
    "BM1_CS3_Syn (33M)": "withmartian/sql_interp_bm1_cs3_experiment_3.10",
    "BM1_CS4_Syn (33M)": "withmartian/sql_interp_bm1_cs4_dataset_synonyms_experiment_1.1",
    "BM1_CS5_Syn (33M)": "withmartian/sql_interp_bm1_cs5_dataset_synonyms_experiment_1.2",
    "BM2_CS1_Syn (0.5B)": "withmartian/sql_interp_bm2_cs1_experiment_4.3",
    "BM2_CS2_Syn (0.5B)": "withmartian/sql_interp_bm2_cs2_experiment_5.3",
    "BM2_CS3_Syn (0.5B)": "withmartian/sql_interp_bm2_cs3_experiment_6.3",
    "BM3_CS1_Syn (1B)": "withmartian/sql_interp_bm3_cs1_experiment_7.3",
    "BM3_CS2_Syn (1B)": "withmartian/sql_interp_bm3_cs2_experiment_8.3",
    "BM3_CS3_Syn (1B)": "withmartian/sql_interp_bm3_cs3_experiment_9.3",
}

# In-process cache: display name -> (tokenizer, model). Ensures each
# checkpoint is downloaded and loaded at most once per app lifetime.
model_cache = {}


def load_model(model_name):
    """Return the (tokenizer, model) pair for *model_name*, loading lazily.

    Args:
        model_name: A display-name key of ``MODELS``.

    Returns:
        Tuple of ``(AutoTokenizer, AutoModelForCausalLM)``; the model is
        loaded in fp16 with ``device_map="auto"`` (GPU if available).

    Raises:
        KeyError: If *model_name* is not a key of ``MODELS``.
    """
    if model_name not in model_cache:
        model_id = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]


def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
    """Generate a SQL query from a natural-language instruction and schema.

    Args:
        model_name: Display-name key into ``MODELS``.
        instruction: Natural-language description of the desired query.
        schema: ``CREATE TABLE ...`` context the model should query against.
        max_length: Budget of *generated* tokens (see note below).
        temperature: Sampling temperature; ``0`` selects greedy decoding.

    Returns:
        The generated SQL string, or a human-readable error message — this
        function never raises, so the Gradio UI always gets displayable text.
    """
    if not model_name or not instruction or not schema:
        return "Please fill in all fields and select a model"
    try:
        tokenizer, model = load_model(model_name)

        # Alpaca-style prompt matching the "### Response:" split below.
        prompt = f"""### Instruction:
{instruction}

### Context:
{schema}

### Response:"""

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # BUG FIX: the original passed ``max_length=max_length``, which counts
        # the prompt tokens against the budget — a long schema could leave no
        # room for generation. ``max_new_tokens`` budgets only the output.
        gen_kwargs = {
            "max_new_tokens": max_length,
            "pad_token_id": tokenizer.eos_token_id,
        }
        # Only pass sampling parameters when actually sampling: recent
        # transformers versions warn (or reject) ``temperature`` combined
        # with ``do_sample=False``.
        if temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            gen_kwargs["do_sample"] = False

        outputs = model.generate(**inputs, **gen_kwargs)
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the final response marker; fall back to
        # the full decode if the marker was somehow dropped.
        if "### Response:" in generated:
            sql = generated.split("### Response:")[-1].strip()
        else:
            sql = generated.strip()
        return sql
    except Exception as e:
        # Surface the failure in the output box instead of crashing the app.
        return f"Error: {str(e)}"


# Only ONE example: (model display name, instruction, schema).
examples = [
    [
        "BM2_CS2_Syn (0.5B)",
        "List worker earnings from highest to lowest",
        "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))",
    ],
]


def model_demo(shared_instruction, shared_schema):
    """Build the interactive SQL-generation tab inside a Gradio Blocks context.

    Args:
        shared_instruction: Gradio component whose value is mirrored into this
            tab's instruction textbox on change.
        shared_schema: Gradio component mirrored into this tab's schema editor.

    Returns:
        Dict exposing the tab's key components:
        ``{'instruction': ..., 'schema': ..., 'output': ...}``.

    NOTE(review): the inline HTML markup in the original file was lost during
    source extraction; the visible text content is preserved exactly, but the
    tags/styles below are reconstructed — confirm against the deployed app.
    """
    # Page header.
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Interactive SQL Generation</h1>
            <p>Transform natural language into SQL using mechanistically interpretable models</p>
        </div>
        """
    )
    gr.HTML(
        """
        <div>
            <strong>How it works:</strong> Select a model, describe your query in plain English, and watch the model generate SQL.
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="BM2_CS2_Syn (0.5B)",
                label="Model Selection",
                info="Larger models = better accuracy",
            )
            # Better formatted guide.
            gr.HTML(
                """
                <div>
                    <h4>Model Guide</h4>
                    <p><strong>BM1 (33M parameters)</strong><br>Lightning fast, ideal for simple queries</p>
                    <p><strong>BM2 (0.5B parameters)</strong><br>Balanced performance and speed</p>
                    <p><strong>BM3 (1B parameters)</strong><br>Most accurate for complex queries</p>
                    <h4>Dataset Complexity</h4>
                    <ul>
                        <li>CS1: Basic SELECT-FROM</li>
                        <li>CS2: Adds ORDER BY</li>
                        <li>CS3: Aggregations</li>
                        <li>CS4: WHERE filters</li>
                        <li>CS5: Multi-table JOINs</li>
                    </ul>
                </div>
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Your Query")
            instruction = gr.Textbox(
                label="Natural Language Query",
                placeholder="e.g., Find all employees earning more than 50000 sorted by name",
                lines=2,
            )
            # Database schema as Code (SQL formatted).
            schema = gr.Code(
                label="Database Schema",
                language="sql",
                value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))",
                lines=3,
            )
            with gr.Row():
                max_length = gr.Slider(
                    64, 512, value=256, step=32,
                    label="Max Length", info="Token limit",
                )
                temperature = gr.Slider(
                    0.0, 1.0, value=0.1, step=0.1,
                    label="Temperature", info="Creativity level",
                )
            generate_btn = gr.Button("Generate SQL", variant="primary", size="lg")
            # Output as Code (SQL formatted).
            output = gr.Code(label="Generated SQL Query", language="sql", lines=6)

            gr.Markdown("### Example Query")
            # Custom example with 2-line schema display.
            gr.HTML(
                """
                <div>
                    <p><strong>Model:</strong> BM2_CS2_Syn (0.5B)</p>
                    <p><strong>Query:</strong> List worker earnings from highest to lowest</p>
                    <p><strong>Schema:</strong></p>
                    <pre>CREATE TABLE employees (
    name VARCHAR(100), salary INT, department VARCHAR(100)
)</pre>
                </div>
                """
            )

    # Mirror the cross-tab shared inputs into this tab's widgets.
    shared_instruction.change(
        fn=lambda x: x,
        inputs=shared_instruction,
        outputs=instruction,
    )
    shared_schema.change(
        fn=lambda x: x,
        inputs=shared_schema,
        outputs=schema,
    )
    generate_btn.click(
        fn=generate_sql,
        inputs=[model_dropdown, instruction, schema, max_length, temperature],
        outputs=output,
    )

    return {'instruction': instruction, 'schema': schema, 'output': output}