abir-hr196 committed on
Commit
e07e1fe
·
1 Parent(s): e24804d
Files changed (3) hide show
  1. app.py +8 -441
  2. tinysql_dataset_viewer.py +153 -0
  3. tinysql_model_demo.py +199 -0
app.py CHANGED
@@ -1,443 +1,10 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
 
5
- # Model configurations
6
- MODELS = {
7
- "BM1_CS1_Syn (33M)": "withmartian/sql_interp_bm1_cs1_experiment_1.10",
8
- "BM1_CS2_Syn (33M)": "withmartian/sql_interp_bm1_cs2_experiment_2.10",
9
- "BM1_CS3_Syn (33M)": "withmartian/sql_interp_bm1_cs3_experiment_3.10",
10
- "BM1_CS4_Syn (33M)": "withmartian/sql_interp_bm1_cs4_dataset_synonyms_experiment_1.1",
11
- "BM1_CS5_Syn (33M)": "withmartian/sql_interp_bm1_cs5_dataset_synonyms_experiment_1.2",
12
- "BM2_CS1_Syn (0.5B)": "withmartian/sql_interp_bm2_cs1_experiment_4.3",
13
- "BM2_CS2_Syn (0.5B)": "withmartian/sql_interp_bm2_cs2_experiment_5.3",
14
- "BM2_CS3_Syn (0.5B)": "withmartian/sql_interp_bm2_cs3_experiment_6.3",
15
- "BM3_CS1_Syn (1B)": "withmartian/sql_interp_bm3_cs1_experiment_7.3",
16
- "BM3_CS2_Syn (1B)": "withmartian/sql_interp_bm3_cs2_experiment_8.3",
17
- "BM3_CS3_Syn (1B)": "withmartian/sql_interp_bm3_cs3_experiment_9.3",
18
- }
19
 
20
- model_cache = {}
21
-
22
- def load_model(model_name):
23
- if model_name not in model_cache:
24
- model_id = MODELS[model_name]
25
- tokenizer = AutoTokenizer.from_pretrained(model_id)
26
- model = AutoModelForCausalLM.from_pretrained(
27
- model_id,
28
- torch_dtype=torch.float16,
29
- device_map="auto"
30
- )
31
- model_cache[model_name] = (tokenizer, model)
32
- return model_cache[model_name]
33
-
34
- def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
35
- if not model_name or not instruction or not schema:
36
- return "Please fill in all fields and select a model"
37
-
38
- try:
39
- tokenizer, model = load_model(model_name)
40
-
41
- prompt = f"""### Instruction: {instruction}
42
- ### Context: {schema}
43
- ### Response:"""
44
-
45
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
-
47
- outputs = model.generate(
48
- **inputs,
49
- max_length=max_length,
50
- temperature=temperature,
51
- do_sample=temperature > 0,
52
- pad_token_id=tokenizer.eos_token_id
53
- )
54
-
55
- generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
-
57
- if "### Response:" in generated:
58
- sql = generated.split("### Response:")[-1].strip()
59
- else:
60
- sql = generated.strip()
61
-
62
- return sql
63
-
64
- except Exception as e:
65
- return f"Error: {str(e)}"
66
-
67
- # Example queries
68
- examples = [
69
- [
70
- "BM1_CS1_Syn (33M)",
71
- "Show me the name and salary from employees",
72
- "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
73
- ],
74
- [
75
- "BM2_CS2_Syn (0.5B)",
76
- "List worker earnings from highest to lowest",
77
- "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
78
- ],
79
- [
80
- "BM3_CS3_Syn (1B)",
81
- "Count how many employees in each department",
82
- "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
83
- ],
84
- ]
85
-
86
- # Custom CSS with Martian colors (Orange/Black/Dark Gray only)
87
- custom_css = """
88
- :root {
89
- --martian-orange: #FF6B4A;
90
- --martian-dark: #1A1A1A;
91
- --martian-gray-dark: #3A3A3A;
92
- --martian-gray-medium: #4A4A4A;
93
- --martian-gray-light: #5A5A5A;
94
- --martian-bg: #2A2A2A;
95
- }
96
-
97
- .gradio-container {
98
- font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
99
- background-color: var(--martian-bg) !important;
100
- }
101
-
102
- .header-section {
103
- text-align: center;
104
- padding: 3rem 2rem;
105
- background: linear-gradient(135deg, var(--martian-dark) 0%, var(--martian-gray-dark) 100%);
106
- border-radius: 16px;
107
- margin-bottom: 2rem;
108
- color: white;
109
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
110
- }
111
-
112
- .header-section h1 {
113
- font-size: 2.5rem;
114
- font-weight: 700;
115
- margin-bottom: 1rem;
116
- color: white;
117
- }
118
-
119
- .header-section .subtitle {
120
- font-size: 1.2rem;
121
- opacity: 0.9;
122
- line-height: 1.6;
123
- color: white;
124
- }
125
-
126
- .orange-accent {
127
- color: var(--martian-orange);
128
- font-weight: 600;
129
- }
130
-
131
- .info-box {
132
- background: var(--martian-gray-dark);
133
- border-radius: 12px;
134
- padding: 1.5rem;
135
- margin: 1.5rem 0;
136
- border-left: 4px solid var(--martian-orange);
137
- color: #E0E0E0;
138
- }
139
-
140
- .model-guide {
141
- background: var(--martian-gray-dark);
142
- border-radius: 8px;
143
- padding: 1rem;
144
- margin-top: 1rem;
145
- font-size: 0.9rem;
146
- color: #D0D0D0;
147
- }
148
-
149
- /* Remove all purple/blue colors from Gradio components */
150
- .primary.svelte-cmf5ev {
151
- background: var(--martian-orange) !important;
152
- border-color: var(--martian-orange) !important;
153
- }
154
-
155
- button.primary {
156
- background: var(--martian-orange) !important;
157
- border: none !important;
158
- color: white !important;
159
- }
160
-
161
- button.primary:hover {
162
- background: #FF5733 !important;
163
- }
164
-
165
- /* Fix label colors */
166
- label {
167
- color: #D0D0D0 !important;
168
- }
169
-
170
- .label-wrap span {
171
- color: var(--martian-orange) !important;
172
- }
173
-
174
- /* Input fields - dark theme */
175
- .input-text, textarea, select, input {
176
- background: var(--martian-gray-medium) !important;
177
- border-color: var(--martian-gray-light) !important;
178
- color: #E0E0E0 !important;
179
- }
180
-
181
- textarea::placeholder, input::placeholder {
182
- color: #888 !important;
183
- }
184
-
185
- /* Slider colors */
186
- input[type="range"] {
187
- background: var(--martian-gray-medium) !important;
188
- }
189
-
190
- input[type="range"]::-webkit-slider-thumb {
191
- background: var(--martian-orange) !important;
192
- }
193
-
194
- input[type="range"]::-moz-range-thumb {
195
- background: var(--martian-orange) !important;
196
- }
197
-
198
- input[type="range"]::-webkit-slider-runnable-track {
199
- background: var(--martian-gray-light) !important;
200
- }
201
-
202
- /* Citation box - medium gray with light text */
203
- .citation-box {
204
- background: var(--martian-gray-medium);
205
- border: 1px solid var(--martian-gray-light);
206
- border-radius: 12px;
207
- padding: 1.5rem;
208
- margin: 2rem 0;
209
- font-family: 'Monaco', 'Courier New', monospace;
210
- font-size: 0.85rem;
211
- }
212
-
213
- .citation-header {
214
- font-weight: 700;
215
- color: #E0E0E0;
216
- margin-bottom: 1rem;
217
- font-size: 1.1rem;
218
- }
219
-
220
- .citation-box pre {
221
- color: #D0D0D0;
222
- background: transparent;
223
- }
224
-
225
- .resource-links {
226
- display: flex;
227
- gap: 1rem;
228
- justify-content: center;
229
- margin: 2rem 0;
230
- flex-wrap: wrap;
231
- }
232
-
233
- .resource-link {
234
- background: var(--martian-gray-dark);
235
- color: white;
236
- padding: 0.75rem 1.5rem;
237
- border-radius: 8px;
238
- text-decoration: none;
239
- font-weight: 500;
240
- transition: all 0.3s ease;
241
- border: 2px solid var(--martian-gray-dark);
242
- }
243
-
244
- .resource-link:hover {
245
- background: var(--martian-orange);
246
- border-color: var(--martian-orange);
247
- transform: translateY(-2px);
248
- box-shadow: 0 4px 8px rgba(255, 107, 74, 0.3);
249
- }
250
-
251
- footer {
252
- text-align: center;
253
- padding: 2rem 0;
254
- color: #999;
255
- border-top: 1px solid var(--martian-gray-dark);
256
- margin-top: 3rem;
257
- font-size: 0.9rem;
258
- background: var(--martian-bg);
259
- }
260
-
261
- /* Remove light backgrounds everywhere */
262
- .block, .panel {
263
- background: var(--martian-gray-dark) !important;
264
- }
265
-
266
- .form {
267
- background: var(--martian-gray-medium) !important;
268
- }
269
-
270
- /* Dropdown styling */
271
- .dropdown {
272
- background: var(--martian-gray-medium) !important;
273
- color: #E0E0E0 !important;
274
- }
275
-
276
- /* Code output styling */
277
- .code {
278
- background: var(--martian-gray-dark) !important;
279
- color: #E0E0E0 !important;
280
- }
281
-
282
- /* Examples section */
283
- .example {
284
- background: var(--martian-gray-medium) !important;
285
- border-color: var(--martian-gray-light) !important;
286
- }
287
-
288
- /* Markdown sections */
289
- .markdown {
290
- color: #D0D0D0 !important;
291
- }
292
-
293
- h1, h2, h3, h4, h5, h6 {
294
- color: #E0E0E0 !important;
295
- }
296
- """
297
- # Create Gradio interface
298
- with gr.Blocks(css=custom_css, title="TinySQL Demo", theme=gr.themes.Soft()) as demo:
299
-
300
- # Header
301
- gr.HTML("""
302
- <div class="header-section">
303
- <h1>TinySQL Interactive Demo</h1>
304
- <p class="subtitle">
305
- Transform natural language into SQL queries using <span class="orange-accent">mechanistically interpretable</span> models
306
- </p>
307
- </div>
308
- """)
309
-
310
- # Info box
311
- gr.HTML("""
312
- <div class="info-box">
313
- <strong>How it works:</strong> Select a model from our collection of 11 fine-tuned transformers,
314
- describe what you want in plain English, and watch as the model generates precise SQL queries.
315
- Each model is trained on progressively complex SQL operations—from basic SELECT statements to
316
- advanced JOINs and aggregations.
317
- </div>
318
- """)
319
-
320
- with gr.Row():
321
- with gr.Column(scale=1):
322
- gr.Markdown("### Configuration")
323
-
324
- model_dropdown = gr.Dropdown(
325
- choices=list(MODELS.keys()),
326
- value="BM2_CS2_Syn (0.5B)",
327
- label="Model Selection",
328
- info="Larger models = better accuracy, slower inference"
329
- )
330
-
331
- gr.HTML("""
332
- <div class="model-guide">
333
- <strong>Model Guide:</strong><br><br>
334
- <strong>BM1 (33M)</strong> - Lightning fast, great for simple queries<br>
335
- <strong>BM2 (0.5B)</strong> - Balanced performance and speed<br>
336
- <strong>BM3 (1B)</strong> - Most accurate, handles complex queries<br><br>
337
- <strong>Dataset Complexity:</strong><br>
338
- CS1: Basic SELECT-FROM queries<br>
339
- CS2: Adds ORDER BY clauses<br>
340
- CS3: Aggregations (COUNT, SUM, AVG)<br>
341
- CS4: Adds WHERE filters<br>
342
- CS5: Multi-table JOINs
343
- </div>
344
- """)
345
-
346
- with gr.Column(scale=2):
347
- gr.Markdown("### Your Query")
348
-
349
- instruction = gr.Textbox(
350
- label="What do you want to know?",
351
- placeholder="e.g., Find all employees earning more than $50,000 sorted by name",
352
- lines=2
353
- )
354
-
355
- schema = gr.Textbox(
356
- label="Database Schema",
357
- placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
358
- lines=3,
359
- value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
360
- )
361
-
362
- with gr.Row():
363
- max_length = gr.Slider(
364
- minimum=64,
365
- maximum=512,
366
- value=256,
367
- step=32,
368
- label="Max Length",
369
- info="Longer = more complex queries"
370
- )
371
- temperature = gr.Slider(
372
- minimum=0.0,
373
- maximum=1.0,
374
- value=0.1,
375
- step=0.1,
376
- label="Temperature",
377
- info="Higher = more creative (use 0.1 for accuracy)"
378
- )
379
-
380
- generate_btn = gr.Button("Generate SQL", variant="primary", size="lg", elem_classes="primary-button")
381
-
382
- output = gr.Code(
383
- label="Generated SQL Query",
384
- language="sql",
385
- lines=8,
386
- )
387
-
388
- gr.Markdown("### Example Queries")
389
- gr.Examples(
390
- examples=examples,
391
- inputs=[model_dropdown, instruction, schema],
392
- )
393
-
394
- # Resource links
395
- gr.HTML("""
396
- <div class="resource-links">
397
- <a href="https://arxiv.org/abs/2503.12730" class="resource-link" target="_blank">
398
- Read the Paper
399
- </a>
400
- <a href="https://github.com/withmartian/TinySQL" class="resource-link" target="_blank">
401
- View Code
402
- </a>
403
- <a href="https://huggingface.co/collections/withmartian/tinysql-6760e92748b63fa56a6ffc9f" class="resource-link" target="_blank">
404
- Dataset & Models
405
- </a>
406
- <a href="https://withmartian.com" class="resource-link" target="_blank">
407
- Martian
408
- </a>
409
- </div>
410
- """)
411
-
412
- # Citation box
413
- gr.HTML("""
414
- <div class="citation-box">
415
- <div class="citation-header">Citation</div>
416
- <pre style="margin: 0; overflow-x: auto; background: transparent;">@misc{harrasse2025tinysqlprogressivetexttosqldataset,
417
- title={TinySQL: A Progressive Text-to-SQL Dataset for Mechanistic Interpretability Research},
418
- author={Abir Harrasse and Philip Quirke and Clement Neo and Dhruv Nathawani and Luke Marks and Amir Abdullah},
419
- year={2025},
420
- eprint={2503.12730},
421
- archivePrefix={arXiv},
422
- primaryClass={cs.LG},
423
- url={https://arxiv.org/abs/2503.12730}
424
- }</pre>
425
- </div>
426
- """)
427
-
428
- # Footer
429
- gr.HTML("""
430
- <footer>
431
- <p>Brought to you with ❤️ from the Martian science team</p>
432
- <p style="margin-top: 0.5rem;">Bridging the gap between toy tasks and real-world interpretability</p>
433
- </footer>
434
- """)
435
-
436
- generate_btn.click(
437
- fn=generate_sql,
438
- inputs=[model_dropdown, instruction, schema, max_length, temperature],
439
- outputs=output
440
- )
441
-
442
- if __name__ == "__main__":
443
- demo.launch()
 
1
"""Entry point: combine the TinySQL model demo and dataset viewer in one tabbed app."""
import gradio as gr  # required: `gr.Blocks` / `gr.Tab` are used below

from tinysql_dataset_viewer import dataset_viewer
from tinysql_model_demo import model_demo

# Each helper builds its own UI while the enclosing Tab is the active context.
# NOTE(review): both helpers create a nested `gr.Blocks(css=...)`; the css given
# to a nested Blocks is presumably not applied to the outer page — confirm, and
# pass the stylesheet to this outer `gr.Blocks` if the theme is missing.
with gr.Blocks() as app:
    with gr.Tab("Model Demo"):
        model_demo()
    with gr.Tab("Dataset Viewer"):
        dataset_viewer()

# Launch only when executed as a script so that importing this module
# (tests, tooling) does not start a web server.
if __name__ == "__main__":
    app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tinysql_dataset_viewer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tinysql_dataset_viewer.py
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ import pandas as pd
5
+ import urllib.parse
6
+ import html
7
+ import traceback
8
+
9
+ HF_DATASETS = {
10
+ "CS1": "withmartian/cs1_dataset",
11
+ "CS2": "withmartian/cs2_dataset",
12
+ "CS3": "withmartian/cs3_dataset",
13
+ "CS2_synonyms": "withmartian/cs2_dataset_synonyms",
14
+ "CS3_synonyms": "withmartian/cs3_dataset_synonyms",
15
+ "CS4_synonyms": "withmartian/cs4_dataset_synonyms",
16
+ }
17
+
18
+ DEMO_URL = "https://huggingface.co/spaces/abir-hr196/tinysql-demo"
19
+ PREVIEW_LIMIT = 500
20
+ FIELDS = ["english_prompt", "create_statement", "sql_statement"]
21
+ dataset_cache = {}
22
+
23
+ # ---------------- Helpers ----------------
24
def load_preview(dataset_id, limit=PREVIEW_LIMIT):
    """Load up to ``limit`` rows of a Hub dataset as a DataFrame.

    The ``train`` split is tried first. If that request fails for any reason,
    the whole dataset is loaded and the first available split (or the dataset
    itself when it is not a mapping of splits) is truncated to ``limit`` rows.

    Returns a frame with the columns ``example_index`` followed by ``FIELDS``;
    fields absent from the dataset are filled with empty strings.
    """
    try:
        subset = load_dataset(dataset_id, split=f"train[:{limit}]")
    except Exception:
        # Fallback path: no usable "train" slice, so pick whatever split exists.
        loaded = load_dataset(dataset_id)
        split_name = list(loaded.keys())[0] if isinstance(loaded, dict) else None
        source = loaded[split_name] if split_name else loaded
        subset = source.select(range(min(len(source), limit)))

    frame = pd.DataFrame(subset)
    for column in FIELDS:
        if column not in frame.columns:
            frame[column] = ""

    # reset_index() materialises the row position as an "index" column, which
    # is then exposed under the name the callbacks look rows up by.
    return frame[FIELDS].reset_index().rename(columns={"index": "example_index"})
42
+
43
def get_dataset_preview(name):
    """Return the preview frame for dataset variant ``name``, loading it once.

    ``name`` must be a key of ``HF_DATASETS``; the loaded frame is memoised in
    the module-level ``dataset_cache`` so later calls skip the download.
    """
    if name not in dataset_cache:
        dataset_cache[name] = load_preview(HF_DATASETS[name])
    return dataset_cache[name]
49
+
50
def make_dropdown_options(df):
    """Build ``(label, value)`` dropdown choices from a preview frame.

    Args:
        df: frame with ``example_index`` and ``english_prompt`` columns.

    Returns:
        A list of ``("<index> — <prompt>", index)`` tuples. The prompt has its
        whitespace collapsed and is cut to 120 characters, with an ellipsis
        appended only when text was actually cut.
    """
    opts = []
    for raw_idx, prompt in zip(df["example_index"], df["english_prompt"]):
        idx = int(raw_idx)
        # Missing prompts may arrive as None or NaN; NaN is truthy, so an
        # `or ""` fallback is not enough to keep `.split()` from failing.
        text = "" if pd.isna(prompt) else str(prompt)
        normalized = " ".join(text.split())
        # Decide on the ellipsis from the text that is displayed, not from the
        # raw prompt, whose length may be inflated by runs of whitespace.
        short = normalized[:120] + ("…" if len(normalized) > 120 else "")
        opts.append((f"{idx} — {short}", idx))
    return opts
58
+
59
def filter_dataframe(df, query):
    """Return the rows whose prompt or SQL contains ``query`` (case-insensitive).

    Args:
        df: frame with ``english_prompt`` and ``sql_statement`` columns.
        query: free text typed by the user; a falsy value returns ``df`` as is.

    Returns:
        The matching rows with a fresh 0-based index.
    """
    if not query:
        return df
    needle = str(query).lower()

    def _contains(column):
        # regex=False: the search box holds literal text. SQL fragments such as
        # "count(*)" or "a+b" are not valid or intended regular expressions.
        return df[column].fillna("").str.lower().str.contains(needle, regex=False)

    mask = _contains("english_prompt") | _contains("sql_statement")
    return df[mask].reset_index(drop=True)
65
+
66
+ # ---------------- Gradio callbacks ----------------
67
def on_dataset_change(dataset_name):
    """Gradio callback: show the preview for the newly selected dataset variant.

    Returns a 3-tuple matching the wired outputs: the table data, an update for
    the example dropdown (choices plus first item selected), and the status
    HTML (empty on success, an escaped error report on failure).
    """
    columns = ["example_index", "english_prompt", "sql_statement", "create_statement"]
    try:
        displayed = get_dataset_preview(dataset_name)[columns]
        opts = make_dropdown_options(displayed)
        first_value = opts[0][1] if opts else None
        # gr.update() is the version-independent way to patch a component;
        # the per-class `gr.Dropdown.update` helper is absent from Gradio 4+.
        return displayed, gr.update(choices=opts, value=first_value), ""
    except Exception as e:
        tb = traceback.format_exc()
        # Same columns as the success path so the table header does not change.
        empty = pd.DataFrame([], columns=columns)
        # The status component renders HTML, so the report must be escaped.
        report = html.escape(f"Error loading dataset: {e}\n{tb}")
        return (
            empty,
            gr.update(choices=[], value=None),
            f"<pre style='color:#ffb3a7'>{report}</pre>",
        )
76
+
77
def on_search(dataset_name, query):
    """Gradio callback: filter the current preview by the search box text.

    Returns a 2-tuple matching the wired outputs: the filtered table data and
    an update for the example dropdown. On failure an empty table and an empty
    dropdown are returned (best effort) and the traceback is printed.
    """
    columns = ["example_index", "english_prompt", "sql_statement", "create_statement"]
    try:
        filtered = filter_dataframe(get_dataset_preview(dataset_name), query)
        displayed = filtered[columns]
        opts = make_dropdown_options(displayed)
        first_value = opts[0][1] if opts else None
        # gr.update() works across Gradio versions; `gr.Dropdown.update` does not.
        return displayed, gr.update(choices=opts, value=first_value)
    except Exception:
        # There is no status output wired to this event, so at least leave a
        # trace in the server log instead of failing silently.
        traceback.print_exc()
        return pd.DataFrame([], columns=columns), gr.update(choices=[], value=None)
86
+
87
def send_to_demo(dataset_name, selected_index):
    """Gradio callback: build a link that opens the model demo for one example.

    Args:
        dataset_name: key of ``HF_DATASETS`` currently shown.
        selected_index: ``example_index`` chosen in the dropdown, or ``None``.

    Returns:
        An update for the status HTML component: a link to ``DEMO_URL`` with
        ``instruction`` and ``schema`` query parameters, or an error message.
    """
    import json  # local import: only needed to emit a JavaScript string literal

    def _text(value):
        # Missing cells may be None or NaN; both become an empty string.
        return "" if pd.isna(value) else str(value)

    try:
        if selected_index is None:
            return gr.update(value=html.escape("Select an example first."))

        df = get_dataset_preview(dataset_name)
        row = df[df["example_index"] == int(selected_index)]
        if row.empty:
            return gr.update(value=html.escape("Selected example not found."))

        record = row.iloc[0]
        params = {
            "instruction": _text(record["english_prompt"]),
            "schema": _text(record["create_statement"]),
        }
        # urlencode() applies quote_plus to every value.
        url = f"{DEMO_URL}?{urllib.parse.urlencode(params)}"

        # Two different escaping contexts:
        #  * HTML attribute -> html.escape (the browser decodes &amp; back to &)
        #  * <script> body  -> JS string literal; entities are NOT decoded
        #    there, so an HTML-escaped URL would carry a literal "&amp;".
        attr_url = html.escape(url, quote=True)
        js_url = json.dumps(url).replace("<", "\\u003c")

        # NOTE(review): scripts injected into an existing page via innerHTML
        # are not executed by browsers, so the visible link is presumably the
        # path users actually take — confirm against the Gradio version in use.
        html_out = f"""
<script>
window.open({js_url}, "_blank");
</script>
<div style="color: #E0E0E0; font-family: Inter, sans-serif;">
Opened the demo in a new tab. If your browser blocked the popup, <a href="{attr_url}" target="_blank" rel="noreferrer">click here</a>.
</div>
"""
        return gr.update(value=html_out)
    except Exception as e:
        tb = traceback.format_exc()
        return gr.update(
            value=f"<pre style='color:#ffb3a7'>Error: {html.escape(str(e))}\n{html.escape(tb)}</pre>"
        )
111
+
112
+ # ---------------- Dataset viewer function ----------------
113
+ def dataset_viewer():
114
+ custom_css = """
115
+ :root {
116
+ --martian-orange: #FF6B4A;
117
+ --martian-dark: #0E0E0E;
118
+ --martian-gray-dark: #1A1A1A;
119
+ --martian-gray-medium: #2A2A2A;
120
+ --martian-gray-light: #3A3A3A;
121
+ --martian-bg: #0E0E0E;
122
+ }
123
+ .gradio-container { background-color: var(--martian-bg) !important; font-family: 'Inter', sans-serif; }
124
+ .header-section { text-align: center; padding: 2rem; background: linear-gradient(135deg, var(--martian-dark) 0%, var(--martian-gray-dark) 100%); border-radius: 12px; margin-bottom: 1rem; color:white;}
125
+ .header-section h1 { font-size: 2rem; margin-bottom: 0.5rem; }
126
+ .info-box { background: var(--martian-gray-dark); border-left: 4px solid var(--martian-orange); border-radius: 10px; padding:1rem; margin:1rem 0; color:#E0E0E0;}
127
+ button, .gr-button { background: var(--martian-orange) !important; color:white !important; border:none !important;}
128
+ .input-text, textarea, select, input, .gradio-dataframe { background: var(--martian-gray-medium) !important; border-color: var(--martian-gray-light) !important; color: #E0E0E0 !important; }
129
+ a { color: var(--martian-orange) !important; }
130
+ """
131
+
132
+ with gr.Blocks(css=custom_css) as viewer:
133
+ gr.HTML("""<div class="header-section"><h1>TinySQL — Dataset Viewer</h1><div class="subtitle">Browse dataset variants, filter examples, and send a selected example to the TinySQL model demo.</div></div>""")
134
+ gr.HTML("""<div class="info-box"><strong>Note:</strong> Previews load the first 500 examples for fast exploration. Use the search box to filter prompts or SQL statements.</div>""")
135
+ with gr.Row():
136
+ with gr.Column(scale=1):
137
+ dataset_dropdown = gr.Dropdown(choices=list(HF_DATASETS.keys()), value=list(HF_DATASETS.keys())[0], label="Dataset Variant")
138
+ search_box = gr.Textbox(label="Search (prompt or SQL)", placeholder="Type keywords to filter prompts or SQL...")
139
+ select_dropdown = gr.Dropdown(choices=[], label="Select example to try")
140
+ try_button = gr.Button("Try in Model Demo", variant="primary")
141
+ status_html = gr.HTML("")
142
+ with gr.Column(scale=3):
143
+ df_display = gr.Dataframe(
144
+ headers=["id", "english_prompt", "sql_statement", "create_statement"],
145
+ value=pd.DataFrame(columns=["id", "english_prompt", "sql_statement", "create_statement"]),
146
+ label="Preview (first 500 rows)"
147
+ )
148
+
149
+ dataset_dropdown.change(fn=on_dataset_change, inputs=[dataset_dropdown], outputs=[df_display, select_dropdown, status_html])
150
+ search_box.change(fn=on_search, inputs=[dataset_dropdown, search_box], outputs=[df_display, select_dropdown])
151
+ try_button.click(fn=send_to_demo, inputs=[dataset_dropdown, select_dropdown], outputs=[status_html])
152
+
153
+ return viewer
tinysql_model_demo.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ # ---------------- Model Setup ----------------
6
+ MODELS = {
7
+ "BM1_CS1_Syn (33M)": "withmartian/sql_interp_bm1_cs1_experiment_1.10",
8
+ "BM1_CS2_Syn (33M)": "withmartian/sql_interp_bm1_cs2_experiment_2.10",
9
+ "BM1_CS3_Syn (33M)": "withmartian/sql_interp_bm1_cs3_experiment_3.10",
10
+ "BM1_CS4_Syn (33M)": "withmartian/sql_interp_bm1_cs4_dataset_synonyms_experiment_1.1",
11
+ "BM1_CS5_Syn (33M)": "withmartian/sql_interp_bm1_cs5_dataset_synonyms_experiment_1.2",
12
+ "BM2_CS1_Syn (0.5B)": "withmartian/sql_interp_bm2_cs1_experiment_4.3",
13
+ "BM2_CS2_Syn (0.5B)": "withmartian/sql_interp_bm2_cs2_experiment_5.3",
14
+ "BM2_CS3_Syn (0.5B)": "withmartian/sql_interp_bm2_cs3_experiment_6.3",
15
+ "BM3_CS1_Syn (1B)": "withmartian/sql_interp_bm3_cs1_experiment_7.3",
16
+ "BM3_CS2_Syn (1B)": "withmartian/sql_interp_bm3_cs2_experiment_8.3",
17
+ "BM3_CS3_Syn (1B)": "withmartian/sql_interp_bm3_cs3_experiment_9.3",
18
+ }
19
+
20
+ model_cache = {}
21
+
22
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for a display name, loading it once.

    Args:
        model_name: key of ``MODELS``.

    Returns:
        The cached ``(tokenizer, model)`` tuple from ``model_cache``.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODELS``.
    """
    if model_name not in model_cache:
        model_id = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Half precision is only worthwhile on a GPU; on CPU-only hosts fall
        # back to float32, where every operator is implemented.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]
33
+
34
def generate_sql(model_name, instruction, schema, max_length=256, temperature=0.7):
    """Generate a SQL statement for an English instruction and a table schema.

    Args:
        model_name: key of ``MODELS`` selecting the checkpoint.
        instruction: natural-language request.
        schema: ``CREATE TABLE`` statement(s) giving the context.
        max_length: budget of newly generated tokens (the prompt is not
            counted, so a long schema cannot use up the whole budget).
        temperature: sampling temperature; ``0`` selects greedy decoding.

    Returns:
        The text following the last ``### Response:`` marker, a validation
        message when a field is empty, or ``"Error: ..."`` if generation fails.
    """
    # Whitespace-only text is as unusable as an empty field.
    if not model_name or not (instruction or "").strip() or not (schema or "").strip():
        return "Please fill in all fields and select a model"

    marker = "### Response:"
    try:
        tokenizer, model = load_model(model_name)

        prompt = f"### Instruction: {instruction}\n### Context: {schema}\n{marker}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        gen_kwargs = {
            # UI sliders may deliver floats; generate() expects an int here.
            "max_new_tokens": int(max_length),
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature and temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=float(temperature))
        else:
            # Greedy decoding: passing a temperature here would only trigger
            # an "unused generation flag" warning.
            gen_kwargs["do_sample"] = False

        outputs = model.generate(**inputs, **gen_kwargs)
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # split() returns the whole string when the marker is absent, so this
        # covers both the "marker found" and "marker missing" cases.
        return generated.split(marker)[-1].strip()
    except Exception as e:
        # The result is shown directly in the UI, so report instead of raising.
        return f"Error: {str(e)}"
64
+
65
+ # ---------------- Example Queries ----------------
66
+ examples = [
67
+ [
68
+ "BM1_CS1_Syn (33M)",
69
+ "Show me the name and salary from employees",
70
+ "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
71
+ ],
72
+ [
73
+ "BM2_CS2_Syn (0.5B)",
74
+ "List worker earnings from highest to lowest",
75
+ "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
76
+ ],
77
+ [
78
+ "BM3_CS3_Syn (1B)",
79
+ "Count how many employees in each department",
80
+ "CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
81
+ ],
82
+ ]
83
+
84
+ # ---------------- Model Demo Function ----------------
85
+ def model_demo():
86
+ custom_css = """
87
+ :root {
88
+ --martian-orange: #FF6B4A;
89
+ --martian-bg: #0E0E0E; /* deep black background */
90
+ --martian-gray-dark: #3A3A3A;
91
+ --martian-gray-medium: #4A4A4A;
92
+ --martian-gray-light: #5A5A5A;
93
+ }
94
+
95
+ .gradio-container {
96
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
97
+ background-color: var(--martian-bg) !important;
98
+ }
99
+
100
+ .header-section {
101
+ text-align: center;
102
+ padding: 3rem 2rem;
103
+ background: linear-gradient(135deg, var(--martian-gray-dark) 0%, var(--martian-gray-medium) 100%);
104
+ border-radius: 16px;
105
+ margin-bottom: 2rem;
106
+ color: white;
107
+ box-shadow: 0 4px 6px rgba(0,0,0,0.3);
108
+ }
109
+
110
+ .header-section h1 { font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; color: white; }
111
+ .header-section .subtitle { font-size: 1.2rem; opacity: 0.9; line-height: 1.6; color: white; }
112
+ .orange-accent { color: var(--martian-orange); font-weight: 600; }
113
+
114
+ .info-box { background: var(--martian-gray-dark); border-radius: 12px; padding: 1.5rem; margin: 1.5rem 0; border-left: 4px solid var(--martian-orange); color: #E0E0E0; }
115
+ .model-guide { background: var(--martian-gray-dark); border-radius: 8px; padding: 1rem; margin-top: 1rem; font-size: 0.9rem; color: #D0D0D0; }
116
+
117
+ button.primary { background: var(--martian-orange) !important; border: none !important; color: white !important; }
118
+ button.primary:hover { background: #FF5733 !important; }
119
+
120
+ label { color: #D0D0D0 !important; }
121
+ .label-wrap span { color: var(--martian-orange) !important; }
122
+
123
+ input, textarea, select { background: var(--martian-gray-medium) !important; border-color: var(--martian-gray-light) !important; color: #E0E0E0 !important; }
124
+ textarea::placeholder, input::placeholder { color: #888 !important; }
125
+
126
+ .code { background: var(--martian-gray-dark) !important; color: #E0E0E0 !important; }
127
+ """
128
+
129
+ with gr.Blocks(css=custom_css, title="TinySQL Model Demo") as demo:
130
+
131
+ # Header
132
+ gr.HTML("""
133
+ <div class="header-section">
134
+ <h1>TinySQL Interactive Demo</h1>
135
+ <p class="subtitle">
136
+ Transform natural language into SQL queries using <span class="orange-accent">mechanistically interpretable</span> models
137
+ </p>
138
+ </div>
139
+ """)
140
+
141
+ # Info box
142
+ gr.HTML("""
143
+ <div class="info-box">
144
+ <strong>How it works:</strong> Select a model, describe your query in plain English, and watch the model generate SQL.
145
+ </div>
146
+ """)
147
+
148
+ with gr.Row():
149
+ with gr.Column(scale=1):
150
+ gr.Markdown("### Configuration")
151
+ model_dropdown = gr.Dropdown(
152
+ choices=list(MODELS.keys()),
153
+ value="BM2_CS2_Syn (0.5B)",
154
+ label="Model Selection",
155
+ info="Larger models = better accuracy, slower inference"
156
+ )
157
+ gr.HTML("""
158
+ <div class="model-guide">
159
+ <strong>BM1 (33M)</strong> - Lightning fast, simple queries<br>
160
+ <strong>BM2 (0.5B)</strong> - Balanced performance<br>
161
+ <strong>BM3 (1B)</strong> - Most accurate, complex queries<br><br>
162
+ <strong>Dataset Complexity:</strong><br>
163
+ CS1: Basic SELECT-FROM<br>
164
+ CS2: Adds ORDER BY<br>
165
+ CS3: Aggregations<br>
166
+ CS4: Adds WHERE filters<br>
167
+ CS5: Multi-table JOINs
168
+ </div>
169
+ """)
170
+
171
+ with gr.Column(scale=2):
172
+ gr.Markdown("### Your Query")
173
+ instruction = gr.Textbox(
174
+ label="What do you want to know?",
175
+ placeholder="e.g., Find all employees earning more than $50,000 sorted by name",
176
+ lines=2
177
+ )
178
+ schema = gr.Textbox(
179
+ label="Database Schema",
180
+ placeholder="CREATE TABLE employees (name VARCHAR, salary INT, department VARCHAR)",
181
+ lines=3,
182
+ value="CREATE TABLE employees (name VARCHAR(100), salary INT, department VARCHAR(100))"
183
+ )
184
+ with gr.Row():
185
+ max_length = gr.Slider(64, 512, value=256, step=32, label="Max Length")
186
+ temperature = gr.Slider(0.0, 1.0, value=0.1, step=0.1, label="Temperature")
187
+
188
+ generate_btn = gr.Button("Generate SQL", variant="primary", size="lg")
189
+ output = gr.Code(label="Generated SQL Query", language="sql", lines=8)
190
+
191
+ gr.Examples(examples=examples, inputs=[model_dropdown, instruction, schema])
192
+
193
+ generate_btn.click(
194
+ fn=generate_sql,
195
+ inputs=[model_dropdown, instruction, schema, max_length, temperature],
196
+ outputs=output
197
+ )
198
+
199
+ return demo