# tinysql-demo / tinysql_dataset_viewer.py
# Author: abir-hr196 - commit f35a40c ("updates"), 6.09 kB
import gradio as gr
from datasets import load_dataset
import pandas as pd
# Selectable datasets: dropdown label -> Hugging Face Hub repository id.
# Insertion order is the order in which the labels appear in the dropdown.
DATASETS = {
    "CS1": "withmartian/cs1_dataset",
    "CS2": "withmartian/cs2_dataset",
    "CS3": "withmartian/cs3_dataset",
    "CS2 Synonyms": "withmartian/cs2_dataset_synonyms",
    "CS3 Synonyms": "withmartian/cs3_dataset_synonyms",
    "CS4 Synonyms": "withmartian/cs4_dataset_synonyms",
}

# Columns kept (in this display order) when a dataset provides all of them.
COLUMNS = ["create_statement", "english_prompt", "sql_statement"]
def load_preview(dataset_name, max_rows=500):
    """Load the first *max_rows* rows of a dataset listed in DATASETS.

    Args:
        dataset_name: Key into DATASETS (the dropdown label, e.g. "CS1").
        max_rows: Maximum number of rows to return. The default of 500 matches
            the "First 500 Rows" wording used in the UI.

    Returns:
        A DataFrame restricted to COLUMNS when the dataset has all of them,
        otherwise with every column. On any failure, a one-row DataFrame with
        a single "Error" column holding the exception message; callers test
        for that column rather than catching exceptions.
    """
    try:
        ds = load_dataset(DATASETS[dataset_name], split="train")
        # Slice the Arrow-backed Dataset *before* converting, so only the
        # preview rows are materialised as a DataFrame instead of the whole
        # split.
        ds = ds.select(range(min(max(max_rows, 0), len(ds))))
        df = ds.to_pandas()
        # Filter to only the columns we want.
        if all(col in df.columns for col in COLUMNS):
            df = df[COLUMNS]
        return df
    except Exception as e:
        # Deliberate UI boundary: show the failure in the results table
        # instead of raising into the Gradio event handler.
        return pd.DataFrame({"Error": [str(e)]})
def filter_dataframe(df, search_query):
    """Return the rows of *df* in which any cell contains *search_query*.

    Matching is a case-insensitive, *literal* substring test on each cell's
    string form.

    Args:
        df: DataFrame to search.
        search_query: Text to look for. Falsy values disable filtering.

    Returns:
        *df* itself when the query is empty, *df* is empty, or *df* is an
        error frame (has an "Error" column). Otherwise returns the matching
        rows with their original index labels.
    """
    if not search_query or df.empty or "Error" in df.columns:
        return df
    # regex=False: users type SQL fragments such as "COUNT(*)" or "t.a";
    # as regexes those would raise re.error or match unintended rows.
    # Testing column-by-column keeps the work vectorised instead of calling a
    # Python lambda once per row.
    hits = df.astype(str).apply(
        lambda col: col.str.contains(
            search_query, case=False, regex=False, na=False
        )
    )
    return df[hits.any(axis=1)]
def dataset_viewer(shared_instruction, shared_schema):
    """Build the dataset-viewer UI and wire its events.

    Creates components only (it never opens a ``gr.Blocks``), so it must be
    called inside an active Blocks/tab context. Lets the user load a preview
    of one of the DATASETS, search it, and send a single displayed row to
    components owned by the model demo.

    Args:
        shared_instruction: Component that receives the selected row's
            ``english_prompt`` value.
        shared_schema: Component that receives the selected row's
            ``create_statement`` value.

    Returns:
        dict with the ``df_state`` (full loaded preview) and ``df_display``
        (results table) components.
    """
    gr.HTML("""
<div class="header-section" style="text-align: center; padding: 2.5rem 1.5rem; background: linear-gradient(135deg, #1A1A1A 0%, #2A2A2A 100%); border-radius: 16px; margin-bottom: 2rem; color: white;">
<h1 style="font-size: 2.2rem; font-weight: 700; margin-bottom: 0.75rem;">TinySQL Dataset Viewer</h1>
<p style="font-size: 1.1rem; opacity: 0.9; line-height: 1.6;">
Browse dataset previews, search, and filter queries with <span style="color: #FF6B4A; font-weight: 600;">ease</span>
</p>
</div>
""")
    gr.HTML("""
<div class="info-box" style="background: #3A3A3A; border-radius: 12px; padding: 1.5rem; margin: 1.5rem 0; border-left: 4px solid #FF6B4A; color: #E0E0E0;">
<strong>Preview Mode:</strong> Showing first 500 rows of each dataset. Use search to filter results in real-time.
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Dataset Selection")
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASETS.keys()),
                value="CS1",
                label="Choose Dataset",
                info="Select a dataset to preview",
            )
            gr.HTML("""
<div style="background: #3A3A3A; border-radius: 8px; padding: 1rem; margin-top: 1rem; font-size: 0.9rem; color: #D0D0D0;">
<strong>Complexity Levels:</strong><br><br>
<strong>CS1:</strong> Basic SELECT-FROM<br>
<strong>CS2:</strong> Adds ORDER BY<br>
<strong>CS3:</strong> Aggregations<br>
<strong>CS4:</strong> Adds WHERE filters<br><br>
<strong>Synonyms:</strong> Natural language variations
</div>
""")
            load_btn = gr.Button("Load Dataset", variant="primary", size="lg")
            row_selector = gr.Number(
                label="Select Row to Test",
                value=0,
                minimum=0,
                precision=0,
                info="Enter row number to send to Model Demo",
            )
            send_to_model_btn = gr.Button(
                "🚀 Run This Example in Model Demo", variant="primary"
            )
        with gr.Column(scale=3):
            gr.Markdown("### Dataset Preview (First 500 Rows)")
            search_box = gr.Textbox(
                label="Search",
                placeholder="Search across all columns...",
                lines=1,
            )
            df_display = gr.Dataframe(
                headers=COLUMNS,
                datatype=["str"] * len(COLUMNS),
                interactive=False,
                wrap=True,
                label="Results",
            )
            stats_display = gr.Markdown("Click 'Load Dataset' to begin")

    # Full preview of the loaded dataset; every search starts from this.
    df_state = gr.State(value=pd.DataFrame())
    # Rows currently shown in the results table. Row numbers typed by the user
    # refer to positions in this frame, not in df_state.
    filtered_state = gr.State(value=pd.DataFrame())

    def load_and_display(dataset_name, query=""):
        """Load a preview and display it, honouring any active search."""
        df = load_preview(dataset_name)
        if "Error" in df.columns:
            return df, df, df, "❌ Error loading dataset"
        filtered_df = filter_dataframe(df, query)
        columns = ", ".join(map(str, df.columns))
        stats = f"**Loaded:** {len(df)} rows | **Columns:** {columns}"
        if query:
            stats += (
                f" | **Showing:** {len(filtered_df)} rows | **Search:** '{query}'"
            )
        return df, filtered_df, filtered_df, stats

    load_btn.click(
        fn=load_and_display,
        inputs=[dataset_dropdown, search_box],
        outputs=[df_state, filtered_state, df_display, stats_display],
    )

    def search_and_display(df, query):
        """Filter the loaded preview and report how many rows matched."""
        if df is None or df.empty:
            return df, df, "Load a dataset first"
        filtered_df = filter_dataframe(df, query)
        stats = f"**Showing:** {len(filtered_df)} of {len(df)} rows"
        if query:
            stats += f" | **Search:** '{query}'"
        return filtered_df, filtered_df, stats

    search_box.change(
        fn=search_and_display,
        inputs=[df_state, search_box],
        outputs=[df_display, filtered_state, stats_display],
    )

    def send_to_model(df, row_num):
        """Return (instruction, schema, status) for the chosen displayed row."""
        invalid = ("", "", "⚠️ Invalid row number or no data loaded")
        # gr.Number yields None when the box is cleared; an error frame has no
        # prompt/schema columns worth sending.
        if row_num is None or df is None or df.empty or "Error" in df.columns:
            return invalid
        row_idx = int(row_num)
        # Reject negatives explicitly: iloc would otherwise count from the end.
        if not 0 <= row_idx < len(df):
            return invalid
        row = df.iloc[row_idx]
        instruction = row.get("english_prompt", "")
        schema = row.get("create_statement", "")
        return (
            instruction,
            schema,
            f"✅ Row {row_idx} loaded! Switch to Model Demo tab.",
        )

    send_to_model_btn.click(
        fn=send_to_model,
        inputs=[filtered_state, row_selector],
        outputs=[shared_instruction, shared_schema, stats_display],
    )

    return {
        "df_state": df_state,
        "df_display": df_display,
    }