Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from datasets import load_dataset | |
| import pandas as pd | |
# (dropdown label, Hugging Face Hub repo id) pairs. The tuple order is the
# order in which the datasets are preloaded and offered in the dropdown.
_DATASET_SOURCES = (
    ("CS1", "withmartian/cs1_dataset"),
    ("CS2", "withmartian/cs2_dataset"),
    ("CS3", "withmartian/cs3_dataset"),
    ("CS2 Synonyms", "withmartian/cs2_dataset_synonyms"),
    ("CS3 Synonyms", "withmartian/cs3_dataset_synonyms"),
    ("CS4 Synonyms", "withmartian/cs4_dataset_synonyms"),
)

# Dropdown label -> Hub repo id (dict preserves the insertion order above).
DATASETS = dict(_DATASET_SOURCES)

# Columns kept in each cached preview, in display order.
COLUMNS = [
    "create_statement",
    "english_prompt",
    "sql_statement",
]

# Dropdown label -> preview DataFrame; filled by preload_datasets().
dataset_cache = {}
def preload_datasets(max_rows=500):
    """Populate ``dataset_cache`` with a preview of every entry in ``DATASETS``.

    For each dataset, the ``train`` split is loaded and truncated to its first
    ``max_rows`` rows. It is reduced to ``COLUMNS`` when all of them are
    present, and a leading positional ``index`` column is added. Loading is
    best-effort: a failure for one dataset is printed and the remaining
    datasets are still attempted.

    Args:
        max_rows: Maximum number of rows kept per dataset (default 500).
    """
    for name, path in DATASETS.items():
        try:
            ds = load_dataset(path, split="train")
            # Slice before converting so only the preview rows become a
            # DataFrame, instead of materializing the whole split in pandas.
            df = ds.select(range(min(max_rows, len(ds)))).to_pandas()
            missing = [col for col in COLUMNS if col not in df.columns]
            if missing:
                # Previously this case was silent; the UI assumes COLUMNS exist.
                print(f"⚠ {name} lacks expected columns {missing}; caching all columns")
            else:
                # Explicit copy so the insert below mutates an owned frame.
                df = df[COLUMNS].copy()
            # Positional row number shown in the table and used by "Row Number".
            df.insert(0, 'index', range(len(df)))
            dataset_cache[name] = df
            print(f"✓ Cached {name}")
        except Exception as e:  # best-effort: one bad dataset must not block startup
            print(f"✗ Failed to cache {name}: {e}")


# Warm the cache once at import time; the UI callbacks only read dataset_cache.
preload_datasets()
def load_preview(dataset_name):
    """Return the cached preview DataFrame for ``dataset_name``.

    If the name was never cached, a one-row DataFrame with a single ``Error``
    column is returned instead, which downstream callbacks check for.
    """
    cached = dataset_cache.get(dataset_name)
    if cached is None:
        return pd.DataFrame({"Error": ["Dataset not found in cache"]})
    return cached
def filter_dataframe(df, search_query, search_column):
    """Return the rows of ``df`` whose text contains ``search_query``.

    Matching is a case-insensitive *literal* substring test on the string form
    of each cell. User input is never interpreted as a regular expression.

    Args:
        df: Preview DataFrame to filter.
        search_query: Text to look for; empty/None disables filtering.
        search_column: A column name, or ``"All Columns"`` to match a row
            when any of its cells contains the query.

    Returns:
        ``df`` itself when there is nothing to filter: empty query, empty
        frame, or the ``Error`` placeholder frame. Otherwise the matching
        rows. If ``search_column`` is not a column of ``df``, no rows match.
    """
    if not search_query or df.empty or "Error" in df.columns:
        return df
    if search_column == "All Columns":
        # Column-wise vectorized test, then OR across each row; avoids a
        # Python-level call per row.
        mask = df.astype(str).apply(
            lambda col: col.str.contains(
                search_query, case=False, regex=False, na=False
            )
        ).any(axis=1)
    elif search_column in df.columns:
        # regex=False: characters like "(" or "*" in the query would otherwise
        # raise re.error, and "." would match every row.
        mask = df[search_column].astype(str).str.contains(
            search_query, case=False, regex=False, na=False
        )
    else:
        # Requested column is absent (e.g. dataset cached without the expected
        # schema): report no matches instead of raising KeyError.
        return df.iloc[0:0]
    return df[mask]
def dataset_viewer(shared_instruction, shared_schema):
    """Build the Dataset Explorer UI and wire its event handlers.

    Components are created directly here, so this is presumably called inside
    an enclosing ``gr.Blocks``/tab context owned by the caller (TODO confirm).

    Args:
        shared_instruction: Output component that receives the selected row's
            ``english_prompt`` when "Run in Model Demo" is clicked.
        shared_schema: Output component that receives that row's
            ``create_statement``.

    Returns:
        dict with ``'df_state'`` (the full, unfiltered preview held in
        ``gr.State``) and ``'df_display'`` (the table currently shown,
        possibly filtered).
    """
    gr.HTML("""
    <div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #2A2A2A 0%, #3A3A3A 100%); border-radius: 16px; margin-bottom: 2rem;">
    <h2 style="font-size: 2rem; font-weight: 700; margin-bottom: 0.5rem; color: #FF6B4A;">Dataset Explorer</h2>
    <p style="font-size: 1rem; color: #D0D0D0; margin: 0;">Browse, search, and explore TinySQL datasets</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Controls")
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASETS.keys()),
                value="CS1",
                label="Choose Dataset"
            )
            load_btn = gr.Button("Load Dataset", variant="primary", size="lg")
            gr.HTML("""
            <div style="background: #2A2A2A; border-radius: 12px; padding: 1.5rem; margin: 1.5rem 0;">
            <h4 style="color: #FF6B4A; font-size: 1rem; margin: 0 0 1rem 0; font-weight: 700;">Dataset Levels</h4>
            <div style="font-size: 0.9rem; line-height: 2;">
            <div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS1:</strong> <span style="color: #FFFFFF;">Basic SELECT-FROM</span></div>
            <div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS2:</strong> <span style="color: #FFFFFF;">Adds ORDER BY</span></div>
            <div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS3:</strong> <span style="color: #FFFFFF;">Aggregations</span></div>
            <div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS4:</strong> <span style="color: #FFFFFF;">WHERE filters</span></div>
            <div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS5:</strong> <span style="color: #FFFFFF;">Multi-table JOINs</span></div>
            </div>
            </div>
            """)
            gr.Markdown("### Test Example")
            row_selector = gr.Number(
                label="Row Number",
                value=0,
                minimum=0,
                precision=0
            )
            send_to_model_btn = gr.Button("Run in Model Demo", variant="primary")
        with gr.Column(scale=3):
            gr.Markdown("### Dataset Preview")
            with gr.Row():
                search_box = gr.Textbox(
                    label="Search",
                    placeholder="Enter search term...",
                    lines=1,
                    scale=3
                )
                # Derived from COLUMNS so the choices cannot drift from the data.
                search_column = gr.Dropdown(
                    choices=["All Columns"] + COLUMNS,
                    value="All Columns",
                    label="Search In",
                    scale=1
                )
            df_display = gr.Dataframe(
                headers=["index"] + COLUMNS,
                datatype=["number"] + ["str"] * len(COLUMNS),
                interactive=False,
                wrap=True
            )
            stats_display = gr.Markdown("Click **Load Dataset** to begin")
    # Full (unfiltered) preview of the loaded dataset; searches read from here.
    df_state = gr.State(value=pd.DataFrame())

    def search_and_display(df, query, column):
        """Filter the stored preview and describe the result."""
        if df.empty:
            return df, "Load a dataset first"
        filtered_df = filter_dataframe(df, query, column)
        stats = f"**Showing {len(filtered_df)} of {len(df)} rows**"
        if query:
            stats += f" • Search: '{query}' in {column}"
        return filtered_df, stats

    def load_and_display(dataset_name, query, column):
        """Load a cached preview; re-apply any active search so the table
        stays consistent with the search box."""
        df = load_preview(dataset_name)
        if "Error" in df.columns:
            return df, df, "Error loading dataset"
        if query:
            filtered_df, stats = search_and_display(df, query, column)
            return df, filtered_df, stats
        # List the columns actually present (may differ from COLUMNS when a
        # dataset was cached without the expected schema).
        shown_columns = [str(col) for col in df.columns if col != "index"]
        stats = f"**Loaded {len(df)} rows** • Columns: {', '.join(shown_columns)}"
        return df, df, stats

    def send_to_model(df, row_num):
        """Copy the prompt and schema of a preview row into the shared inputs."""
        invalid = ("", "", "Invalid row or no data loaded")
        if df.empty or "Error" in df.columns:
            return invalid
        try:
            # gr.Number yields None when the field is cleared.
            row_idx = int(row_num)
        except (TypeError, ValueError):
            return invalid
        # Reject negatives explicitly: iloc would otherwise wrap to the end.
        if not 0 <= row_idx < len(df):
            return invalid
        row = df.iloc[row_idx]
        instruction = row.get("english_prompt", "")
        schema = row.get("create_statement", "")
        return instruction, schema, f"**Row {row_idx} loaded!** Switch to Model Demo tab"

    load_btn.click(
        fn=load_and_display,
        inputs=[dataset_dropdown, search_box, search_column],
        outputs=[df_state, df_display, stats_display]
    )
    # Editing either the query or the target column re-runs the search.
    for search_trigger in (search_box, search_column):
        search_trigger.change(
            fn=search_and_display,
            inputs=[df_state, search_box, search_column],
            outputs=[df_display, stats_display]
        )
    send_to_model_btn.click(
        fn=send_to_model,
        inputs=[df_state, row_selector],
        outputs=[shared_instruction, shared_schema, stats_display]
    )
    return {'df_state': df_state, 'df_display': df_display}