# tinysql-demo / tinysql_dataset_viewer.py
# Author: abir-hr196 · commit b4215f7 ("upd")
import gradio as gr
from datasets import load_dataset
import pandas as pd
# UI label -> Hugging Face Hub dataset id. Insertion order is the order the
# labels appear in the viewer's dropdown. Labels ending in "Synonyms" point
# at the corresponding *_synonyms repositories.
DATASETS = {
    "CS1": "withmartian/cs1_dataset",
    "CS2": "withmartian/cs2_dataset",
    "CS3": "withmartian/cs3_dataset",
    "CS2 Synonyms": "withmartian/cs2_dataset_synonyms",
    "CS3 Synonyms": "withmartian/cs3_dataset_synonyms",
    "CS4 Synonyms": "withmartian/cs4_dataset_synonyms",
}

# Columns kept from each dataset, in display order.
COLUMNS = [
    "create_statement",
    "english_prompt",
    "sql_statement",
]

# UI label -> preview DataFrame; populated by preload_datasets().
dataset_cache = dict()
def preload_datasets(max_rows=500):
    """Download each dataset in ``DATASETS`` and cache a preview in ``dataset_cache``.

    For every label/path pair, the first *max_rows* rows of the ``train``
    split are kept. They are trimmed to ``COLUMNS`` and prefixed with a 0-based
    ``index`` column, which is the row number the viewer's "Row Number" field
    refers to.

    This is best-effort. A dataset that fails to load, or that lacks any of
    ``COLUMNS``, is reported with ``print`` and skipped so the others still
    load. A skipped label later shows up as "not found in cache".

    Args:
        max_rows: Maximum number of rows kept per dataset (default 500).
    """
    for name, path in DATASETS.items():
        try:
            ds = load_dataset(path, split="train")
            # Convert only the preview rows to pandas instead of
            # materialising the whole split and discarding all but the head.
            df = ds.select(range(min(max_rows, len(ds)))).to_pandas()
        except Exception as e:
            print(f"✗ Failed to cache {name}: {e}")
            continue

        missing = [col for col in COLUMNS if col not in df.columns]
        if missing:
            # The viewer's table headers, per-column search and "Run in Model
            # Demo" all assume exactly these columns, so don't cache a frame
            # that lacks them.
            print(f"✗ Failed to cache {name}: missing columns {missing}")
            continue

        # .copy() so the insert below operates on an independent frame.
        df = df[COLUMNS].copy()
        df.insert(0, "index", range(len(df)))
        dataset_cache[name] = df
        print(f"✓ Cached {name}")


preload_datasets()
def load_preview(dataset_name):
    """Look up the cached preview frame for *dataset_name*.

    Returns the frame stored in ``dataset_cache`` itself (not a copy). When the
    label was never cached (unknown label, or its preload failed), returns a
    one-row placeholder frame with a single ``Error`` column instead.
    """
    try:
        return dataset_cache[dataset_name]
    except KeyError:
        placeholder = {"Error": ["Dataset not found in cache"]}
        return pd.DataFrame(placeholder)
def filter_dataframe(df, search_query, search_column):
    """Return the rows of *df* whose text contains *search_query*.

    Matching is a case-insensitive, literal substring test on each cell's
    ``str()`` form. The query is not treated as a regular expression, so SQL
    fragments such as ``COUNT(*)`` can be searched directly.

    Args:
        df: Frame to filter.
        search_query: Text to look for; a falsy value disables filtering.
        search_column: Name of the column to search, or ``"All Columns"`` to
            keep a row when any of its cells contains the query.

    Returns:
        *df* itself when there is nothing to filter (no query, an empty frame,
        or a frame carrying an ``"Error"`` column); otherwise the matching rows.
    """
    if not search_query or df.empty or "Error" in df.columns:
        return df

    def _matches(values):
        # regex=False: the query is user-typed text. As a regex, characters
        # like "(" or "*" raise re.error and "." matches any character.
        return values.astype(str).str.contains(
            search_query, case=False, regex=False, na=False
        )

    if search_column == "All Columns":
        # One vectorised test per column, OR-ed across columns, rather than a
        # Python callback per row.
        mask = df.apply(_matches).any(axis=1)
    else:
        mask = _matches(df[search_column])
    return df[mask]
# Banner shown at the top of the Dataset Explorer.
_HEADER_HTML = """
<div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #2A2A2A 0%, #3A3A3A 100%); border-radius: 16px; margin-bottom: 2rem;">
<h2 style="font-size: 2rem; font-weight: 700; margin-bottom: 0.5rem; color: #FF6B4A;">Dataset Explorer</h2>
<p style="font-size: 1rem; color: #D0D0D0; margin: 0;">Browse, search, and explore TinySQL datasets</p>
</div>
"""

# Legend describing the dataset levels, shown in the controls column.
_LEVELS_HTML = """
<div style="background: #2A2A2A; border-radius: 12px; padding: 1.5rem; margin: 1.5rem 0;">
<h4 style="color: #FF6B4A; font-size: 1rem; margin: 0 0 1rem 0; font-weight: 700;">Dataset Levels</h4>
<div style="font-size: 0.9rem; line-height: 2;">
<div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS1:</strong> <span style="color: #FFFFFF;">Basic SELECT-FROM</span></div>
<div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS2:</strong> <span style="color: #FFFFFF;">Adds ORDER BY</span></div>
<div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS3:</strong> <span style="color: #FFFFFF;">Aggregations</span></div>
<div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS4:</strong> <span style="color: #FFFFFF;">WHERE filters</span></div>
<div style="color: #FFFFFF;"><strong style="color: #FFFFFF;">CS5:</strong> <span style="color: #FFFFFF;">Multi-table JOINs</span></div>
</div>
</div>
"""


def _load_and_display(dataset_name, query="", column="All Columns"):
    """Fetch a cached preview and return (state frame, shown frame, stats markdown)."""
    df = load_preview(dataset_name)
    if "Error" in df.columns:
        return df, df, "Error loading dataset"
    if query:
        # A search term is already in the box: show the filtered view so the
        # table agrees with the search controls.
        shown, stats = _search_and_display(df, query, column)
        return df, shown, stats
    stats = f"**Loaded {len(df)} rows** • Columns: {', '.join(COLUMNS)}"
    return df, df, stats


def _search_and_display(df, query, column):
    """Filter the loaded frame and return (shown frame, stats markdown)."""
    if df.empty:
        return df, "Load a dataset first"
    if "Error" in df.columns:
        # The placeholder frame holds no data to search.
        return df, "Error loading dataset"
    filtered_df = filter_dataframe(df, query, column)
    stats = f"**Showing {len(filtered_df)} of {len(df)} rows**"
    if query:
        stats += f" • Search: '{query}' in {column}"
    return filtered_df, stats


def _send_to_model(df, row_num):
    """Return (instruction, schema, status) for row *row_num* of the loaded frame.

    *row_num* is a position in the full, unfiltered loaded frame (the value
    held in the viewer's state), not in the currently filtered table.
    """
    invalid = ("", "", "Invalid row or no data loaded")
    # row_num is None when the Number field is cleared; the placeholder
    # "Error" frame has no example rows to send.
    if df.empty or "Error" in df.columns or row_num is None:
        return invalid
    row_idx = int(row_num)
    # Reject negatives explicitly; iloc would otherwise count from the end.
    if not 0 <= row_idx < len(df):
        return invalid
    row = df.iloc[row_idx]
    instruction = row.get("english_prompt", "")
    schema = row.get("create_statement", "")
    return instruction, schema, f"**Row {row_idx} loaded!** Switch to Model Demo tab"


def dataset_viewer(shared_instruction, shared_schema):
    """Build the "Dataset Explorer" UI and wire its event handlers.

    Layout: a controls column (dataset picker, level legend, row picker and a
    "Run in Model Demo" button) beside a preview column (search controls, the
    preview table and a stats line).

    NOTE(review): registers event listeners, so presumably it must be called
    inside the app's ``gr.Blocks`` / tab context. Confirm with the caller.

    Args:
        shared_instruction: Output component that receives the chosen row's
            ``english_prompt``. Presumably a textbox on the Model Demo tab;
            confirm with the caller.
        shared_schema: Output component that receives the chosen row's
            ``create_statement`` (same assumption as above).

    Returns:
        dict with ``"df_state"`` (the ``gr.State`` holding the loaded,
        unfiltered frame) and ``"df_display"`` (the preview table component).
    """
    gr.HTML(_HEADER_HTML)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Controls")
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASETS.keys()),
                value="CS1",
                label="Choose Dataset",
            )
            load_btn = gr.Button("Load Dataset", variant="primary", size="lg")
            gr.HTML(_LEVELS_HTML)
            gr.Markdown("### Test Example")
            row_selector = gr.Number(
                label="Row Number",
                value=0,
                minimum=0,
                precision=0,
            )
            send_to_model_btn = gr.Button("Run in Model Demo", variant="primary")
        with gr.Column(scale=3):
            gr.Markdown("### Dataset Preview")
            with gr.Row():
                search_box = gr.Textbox(
                    label="Search",
                    placeholder="Enter search term...",
                    lines=1,
                    scale=3,
                )
                search_column = gr.Dropdown(
                    choices=["All Columns"] + COLUMNS,
                    value="All Columns",
                    label="Search In",
                    scale=1,
                )
            df_display = gr.Dataframe(
                headers=["index"] + COLUMNS,
                datatype=["number", "str", "str", "str"],
                interactive=False,
                wrap=True,
            )
            stats_display = gr.Markdown("Click **Load Dataset** to begin")

    # Unfiltered frame of the loaded dataset. Searches only update the
    # displayed table, never this state.
    df_state = gr.State(value=pd.DataFrame())
    search_inputs = [df_state, search_box, search_column]
    search_outputs = [df_display, stats_display]

    load_btn.click(
        fn=_load_and_display,
        inputs=[dataset_dropdown, search_box, search_column],
        outputs=[df_state, df_display, stats_display],
    )
    search_box.change(fn=_search_and_display, inputs=search_inputs, outputs=search_outputs)
    search_column.change(fn=_search_and_display, inputs=search_inputs, outputs=search_outputs)
    send_to_model_btn.click(
        fn=_send_to_model,
        inputs=[df_state, row_selector],
        outputs=[shared_instruction, shared_schema, stats_display],
    )
    return {"df_state": df_state, "df_display": df_display}