abir-hr196 commited on
Commit
febdf85
ยท
1 Parent(s): 38a8f52

new version

Browse files
Files changed (1) hide show
  1. app.py +202 -43
app.py CHANGED
@@ -17,11 +17,9 @@ MODELS = {
17
  "BM3_CS3_Syn (1B)": "withmartian/sql_interp_bm3_cs3_experiment_9.3",
18
  }
19
 
20
- # Cache loaded models
21
  model_cache = {}
22
 
23
  def load_model(model_name):
24
- """Load model and tokenizer with caching"""
25
  if model_name not in model_cache:
26
  model_id = MODELS[model_name]
27
  tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -34,19 +32,18 @@ def load_model(model_name):
34
  return model_cache[model_name]
35
 
36
  def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
37
- """Generate SQL query from natural language"""
 
 
38
  try:
39
  tokenizer, model = load_model(model_name)
40
 
41
- # Format prompt
42
  prompt = f"""### Instruction: {instruction}
43
  ### Context: {schema}
44
  ### Response:"""
45
 
46
- # Tokenize
47
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
48
 
49
- # Generate
50
  outputs = model.generate(
51
  **inputs,
52
  max_length=max_length,
@@ -55,10 +52,8 @@ def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.
55
  pad_token_id=tokenizer.eos_token_id
56
  )
57
 
58
- # Decode
59
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
 
61
- # Extract only the SQL response
62
  if "### Response:" in generated:
63
  sql = generated.split("### Response:")[-1].strip()
64
  else:
@@ -67,12 +62,12 @@ def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.
67
  return sql
68
 
69
  except Exception as e:
70
- return f"Error: {str(e)}"
71
 
72
  # Example queries
73
  examples = [
74
  [
75
- "BM1_CS1 (33M)",
76
  "Show me the name and salary from employees",
77
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
78
  ],
@@ -82,44 +77,167 @@ examples = [
82
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
83
  ],
84
  [
85
- "BM3_CS3 (1B)",
86
  "Count how many employees in each department",
87
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
88
  ],
89
  ]
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # Create Gradio interface
92
- with gr.Blocks(title="TinySQL Demo") as demo:
93
- gr.Markdown("""
94
- # ๐Ÿ” TinySQL: Text-to-SQL Generation Demo
95
 
96
- Generate SQL queries from natural language using models trained on TinySQL.
97
- Select a model, provide a natural language instruction and database schema, then click **Generate**.
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- **Model Types:**
100
- - **BM1** (33M params): TinyStories-based, fastest
101
- - **BM2** (0.5B params): Qwen2.5-based, balanced
102
- - **BM3** (1B params): Llama-3.2-based, most accurate
103
- - **Syn** variants: Trained on synonym dataset (handles semantic mappings)
 
 
 
104
  """)
105
 
106
  with gr.Row():
107
- with gr.Column(scale=2):
 
 
108
  model_dropdown = gr.Dropdown(
109
  choices=list(MODELS.keys()),
110
- value="BM2_CS1_Syn (0.5B)",
111
- label="Select Model",
112
- info="Choose model size and training dataset"
113
  )
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  instruction = gr.Textbox(
116
- label="Natural Language Query",
117
- placeholder="e.g., Show me all employees with salary greater than 50000",
118
  lines=2
119
  )
120
 
121
  schema = gr.Textbox(
122
- label="Database Schema",
123
  placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
124
  lines=3,
125
  value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
@@ -131,39 +249,80 @@ with gr.Blocks(title="TinySQL Demo") as demo:
131
  maximum=512,
132
  value=256,
133
  step=32,
134
- label="Max Length"
 
135
  )
136
  temperature = gr.Slider(
137
  minimum=0.0,
138
  maximum=1.0,
139
  value=0.1,
140
  step=0.1,
141
- label="Temperature"
 
142
  )
143
 
144
- generate_btn = gr.Button("Generate SQL", variant="primary")
145
-
146
- with gr.Column(scale=1):
147
- output = gr.Textbox(
148
- label="Generated SQL",
149
- lines=10,
150
- placeholder="SQL query will appear here..."
151
  )
152
 
153
- gr.Markdown("### Example Queries")
154
  gr.Examples(
155
  examples=examples,
156
  inputs=[model_dropdown, instruction, schema],
157
  )
158
 
159
- gr.Markdown("""
160
- ---
161
- **Paper:** [TinySQL: A Progressive Text-to-SQL Dataset for Mechanistic Interpretability Research](https://arxiv.org/abs/2503.12730)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- **Resources:** [GitHub](https://github.com/withmartian/TinySQL) | [Dataset](https://huggingface.co/collections/withmartian/tinysql-6760e92748b63fa56a6ffc9f)
 
 
 
 
 
 
 
 
 
164
  """)
165
 
166
- # Connect button
167
  generate_btn.click(
168
  fn=generate_sql,
169
  inputs=[model_dropdown, instruction, schema, max_length, temperature],
 
17
  "BM3_CS3_Syn (1B)": "withmartian/sql_interp_bm3_cs3_experiment_9.3",
18
  }
19
 
 
20
  model_cache = {}
21
 
22
  def load_model(model_name):
 
23
  if model_name not in model_cache:
24
  model_id = MODELS[model_name]
25
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
32
  return model_cache[model_name]
33
 
34
  def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
35
+ if not model_name or not instruction or not schema:
36
+ return "โš ๏ธ Please fill in all fields and select a model"
37
+
38
  try:
39
  tokenizer, model = load_model(model_name)
40
 
 
41
  prompt = f"""### Instruction: {instruction}
42
  ### Context: {schema}
43
  ### Response:"""
44
 
 
45
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
 
 
47
  outputs = model.generate(
48
  **inputs,
49
  max_length=max_length,
 
52
  pad_token_id=tokenizer.eos_token_id
53
  )
54
 
 
55
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
 
 
57
  if "### Response:" in generated:
58
  sql = generated.split("### Response:")[-1].strip()
59
  else:
 
62
  return sql
63
 
64
  except Exception as e:
65
+ return f"โŒ Error: {str(e)}"
66
 
67
  # Example queries
68
  examples = [
69
  [
70
+ "BM1_CS1_Syn (33M)",
71
  "Show me the name and salary from employees",
72
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
73
  ],
 
77
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
78
  ],
79
  [
80
+ "BM3_CS3_Syn (1B)",
81
  "Count how many employees in each department",
82
  "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
83
  ],
84
  ]
85
 
86
+ # Custom CSS for beautiful styling
87
+ custom_css = """
88
+ .gradio-container {
89
+ font-family: 'Inter', sans-serif;
90
+ }
91
+
92
+ .header-section {
93
+ text-align: center;
94
+ padding: 2rem 0;
95
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
96
+ border-radius: 12px;
97
+ margin-bottom: 2rem;
98
+ color: white;
99
+ }
100
+
101
+ .logo-container {
102
+ display: flex;
103
+ justify-content: center;
104
+ align-items: center;
105
+ gap: 1rem;
106
+ margin-bottom: 1rem;
107
+ }
108
+
109
+ .martian-badge {
110
+ background: rgba(255, 255, 255, 0.2);
111
+ padding: 0.5rem 1rem;
112
+ border-radius: 20px;
113
+ font-size: 0.9rem;
114
+ backdrop-filter: blur(10px);
115
+ }
116
+
117
+ .info-box {
118
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
119
+ border-radius: 12px;
120
+ padding: 1.5rem;
121
+ margin: 1rem 0;
122
+ border-left: 4px solid #667eea;
123
+ }
124
+
125
+ .citation-box {
126
+ background: #f8f9fa;
127
+ border: 1px solid #dee2e6;
128
+ border-radius: 8px;
129
+ padding: 1.5rem;
130
+ margin: 2rem 0;
131
+ font-family: 'Monaco', 'Courier New', monospace;
132
+ font-size: 0.85rem;
133
+ }
134
+
135
+ .citation-header {
136
+ font-weight: bold;
137
+ color: #495057;
138
+ margin-bottom: 0.5rem;
139
+ display: flex;
140
+ align-items: center;
141
+ gap: 0.5rem;
142
+ }
143
+
144
+ .resource-links {
145
+ display: flex;
146
+ gap: 1rem;
147
+ justify-content: center;
148
+ margin: 1.5rem 0;
149
+ flex-wrap: wrap;
150
+ }
151
+
152
+ .resource-link {
153
+ background: white;
154
+ padding: 0.75rem 1.5rem;
155
+ border-radius: 8px;
156
+ text-decoration: none;
157
+ color: #667eea;
158
+ border: 2px solid #667eea;
159
+ font-weight: 500;
160
+ transition: all 0.3s ease;
161
+ }
162
+
163
+ .resource-link:hover {
164
+ background: #667eea;
165
+ color: white;
166
+ }
167
+
168
+ footer {
169
+ text-align: center;
170
+ padding: 2rem 0;
171
+ color: #6c757d;
172
+ border-top: 1px solid #dee2e6;
173
+ margin-top: 3rem;
174
+ }
175
+ """
176
+
177
  # Create Gradio interface
178
+ with gr.Blocks(css=custom_css, title="TinySQL Demo | Martian", theme=gr.themes.Soft()) as demo:
 
 
179
 
180
+ # Header with Martian branding
181
+ gr.HTML("""
182
+ <div class="header-section">
183
+ <div class="logo-container">
184
+ <h1 style="margin: 0; font-size: 2.5rem;">๐Ÿ”ฎ TinySQL Interactive Demo</h1>
185
+ </div>
186
+ <div class="martian-badge">
187
+ โšก Powered by Martian
188
+ </div>
189
+ <p style="font-size: 1.1rem; margin-top: 1rem; opacity: 0.9;">
190
+ Transform natural language into SQL queries using mechanistically interpretable models
191
+ </p>
192
+ </div>
193
+ """)
194
 
195
+ # Info box
196
+ gr.HTML("""
197
+ <div class="info-box">
198
+ <strong>๐ŸŽฏ How it works:</strong> Select a model from our collection of 11 fine-tuned transformers,
199
+ describe what you want in plain English, and watch as the model generates precise SQL queries.
200
+ Each model is trained on progressively complex SQL operationsโ€”from basic SELECT statements to
201
+ advanced JOINs and aggregations.
202
+ </div>
203
  """)
204
 
205
  with gr.Row():
206
+ with gr.Column(scale=1):
207
+ gr.Markdown("### ๐ŸŽ›๏ธ Configuration")
208
+
209
  model_dropdown = gr.Dropdown(
210
  choices=list(MODELS.keys()),
211
+ value="BM2_CS2_Syn (0.5B)",
212
+ label="๐Ÿค– Model Selection",
213
+ info="Larger models = better accuracy, slower inference"
214
  )
215
 
216
+ gr.Markdown("""
217
+ **Model Guide:**
218
+ - ๐ŸŸข **BM1 (33M)**: Lightning fast, great for simple queries
219
+ - ๐ŸŸก **BM2 (0.5B)**: Balanced performance and speed
220
+ - ๐Ÿ”ด **BM3 (1B)**: Most accurate, handles complex queries
221
+
222
+ **Dataset Complexity:**
223
+ - **CS1**: Basic SELECT-FROM queries
224
+ - **CS2**: Adds ORDER BY clauses
225
+ - **CS3**: Aggregations (COUNT, SUM, AVG)
226
+ - **CS4**: Adds WHERE filters
227
+ - **CS5**: Multi-table JOINs
228
+ """)
229
+
230
+ with gr.Column(scale=2):
231
+ gr.Markdown("### ๐Ÿ’ฌ Your Query")
232
+
233
  instruction = gr.Textbox(
234
+ label="What do you want to know?",
235
+ placeholder="e.g., Find all employees earning more than $50,000 sorted by name",
236
  lines=2
237
  )
238
 
239
  schema = gr.Textbox(
240
+ label="๐Ÿ“‹ Database Schema",
241
  placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
242
  lines=3,
243
  value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
 
249
  maximum=512,
250
  value=256,
251
  step=32,
252
+ label="Max Length",
253
+ info="Longer = more complex queries"
254
  )
255
  temperature = gr.Slider(
256
  minimum=0.0,
257
  maximum=1.0,
258
  value=0.1,
259
  step=0.1,
260
+ label="Temperature",
261
+ info="Higher = more creative (use 0.1 for accuracy)"
262
  )
263
 
264
+ generate_btn = gr.Button("โœจ Generate SQL", variant="primary", size="lg")
265
+
266
+ output = gr.Code(
267
+ label="๐ŸŽ‰ Generated SQL Query",
268
+ language="sql",
269
+ lines=8,
 
270
  )
271
 
272
+ gr.Markdown("### ๐Ÿ’ก Try These Examples")
273
  gr.Examples(
274
  examples=examples,
275
  inputs=[model_dropdown, instruction, schema],
276
  )
277
 
278
+ # Resource links
279
+ gr.HTML("""
280
+ <div class="resource-links">
281
+ <a href="https://arxiv.org/abs/2503.12730" class="resource-link" target="_blank">
282
+ ๐Ÿ“„ Read the Paper
283
+ </a>
284
+ <a href="https://github.com/withmartian/TinySQL" class="resource-link" target="_blank">
285
+ ๐Ÿ’ป View Code
286
+ </a>
287
+ <a href="https://huggingface.co/collections/withmartian/tinysql-6760e92748b63fa56a6ffc9f" class="resource-link" target="_blank">
288
+ ๐Ÿค— Get Dataset & Models
289
+ </a>
290
+ <a href="https://withmartian.com" class="resource-link" target="_blank">
291
+ ๐Ÿš€ Visit Martian
292
+ </a>
293
+ </div>
294
+ """)
295
+
296
+ # Citation box
297
+ gr.HTML("""
298
+ <div class="citation-box">
299
+ <div class="citation-header">
300
+ ๐Ÿ“š Citation
301
+ </div>
302
+ <pre style="margin: 0; overflow-x: auto;">@misc{harrasse2025tinysqlprogressivetexttosqldataset,
303
+ title={TinySQL: A Progressive Text-to-SQL Dataset for Mechanistic Interpretability Research},
304
+ author={Abir Harrasse and Philip Quirke and Clement Neo and Dhruv Nathawani and Luke Marks and Amir Abdullah},
305
+ year={2025},
306
+ eprint={2503.12730},
307
+ archivePrefix={arXiv},
308
+ primaryClass={cs.LG},
309
+ url={https://arxiv.org/abs/2503.12730}
310
+ }</pre>
311
+ </div>
312
+ """)
313
 
314
+ # Footer
315
+ gr.HTML("""
316
+ <footer>
317
+ <p style="margin: 0.5rem 0;">
318
+ Built with โค๏ธ by the Martian team
319
+ </p>
320
+ <p style="margin: 0; font-size: 0.9rem;">
321
+ Bridging the gap between toy tasks and real-world interpretability
322
+ </p>
323
+ </footer>
324
  """)
325
 
 
326
  generate_btn.click(
327
  fn=generate_sql,
328
  inputs=[model_dropdown, instruction, schema, max_length, temperature],