# QSBench demo Space — app.py (commit 6ad662b)
# (Hugging Face file-viewer chrome removed: "QSBench's picture / Update app.py /
#  6ad662b verified / raw / history blame / 8.34 kB" is page residue, not code.)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from pathlib import Path
# =========================================================
# CONFIG & REPOSITORIES
# =========================================================
# Human-readable pack name -> Hugging Face Hub dataset repo id.
DATASET_MAP = {
    "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
    "Depolarizing Noise": "QSBench/QSBench-Depolarizing-v1.0.0-demo",
    "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo",
    "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo"
}
# Precomputed cross-dataset benchmark results, read by the Benchmark tab.
LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
# Regression target: the ideal (noise-free) global Z expectation value.
TARGET_COL = "ideal_expval_Z_global"
# Columns never used as model inputs: identifiers/split labels plus every
# expectation-value, error, and sign column — all of these either are the
# target or would leak it.
EXCLUDE_COLS = {
    "sample_id", "sample_seed", "split",
    "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global",
    "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global",
    "error_Z_global", "error_X_global", "error_Y_global",
    "sign_ideal_Z_global", "sign_noisy_Z_global",
    "sign_ideal_X_global", "sign_noisy_X_global",
    "sign_ideal_Y_global", "sign_noisy_Y_global",
}
# Hyperparameters for the Random Forest baseline regressor.
MODEL_PARAMS = dict(
    n_estimators=80,
    max_depth=10,
    min_samples_leaf=2,
    random_state=42,
    n_jobs=-1,
)
# Global cache to avoid redundant downloads (dataset name -> DataFrame).
dataset_cache = {}
# =========================================================
# DATA UTILS
# =========================================================
def get_df(dataset_key):
    """Return the train split of the selected dataset pack as a DataFrame.

    Results are memoized in the module-level ``dataset_cache`` so each Hub
    repo is downloaded at most once per process.
    """
    cached = dataset_cache.get(dataset_key)
    if cached is None:
        repo_id = DATASET_MAP[dataset_key]
        print(f"Downloading {repo_id}...")
        loaded = load_dataset(repo_id)
        cached = pd.DataFrame(loaded["train"])
        dataset_cache[dataset_key] = cached
    return cached
def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
    """Return numeric columns usable as model inputs.

    Excludes identifier/label/leakage columns listed in ``EXCLUDE_COLS``
    as well as any per-sample ``error_*`` column.
    """
    candidates = df.select_dtypes(include=[np.number]).columns
    return [
        col
        for col in candidates
        if col not in EXCLUDE_COLS and not col.startswith("error_")
    ]
# =========================================================
# TAB FUNCTIONS
# =========================================================
def update_explorer(dataset_name):
    """Refresh the split dropdown and the preview table for one dataset.

    Returns a Gradio dropdown update (split choices, first split selected)
    and the first 10 rows of the dataset.
    """
    df = get_df(dataset_name)
    if "split" in df.columns:
        splits = list(df["split"].unique())
    else:
        splits = ["all"]
    return gr.update(choices=splits, value=splits[0]), df.head(10)
def filter_explorer_by_split(dataset_name, split_name):
    """Preview the first 10 rows of the chosen split.

    Falls back to the whole dataset when no ``split`` column exists.
    """
    df = get_df(dataset_name)
    if "split" not in df.columns:
        return df.head(10)
    return df.loc[df["split"] == split_name].head(10)
def _make_demo_figure(y_true, preds, importances, feature_cols, r2):
    """Build the two-panel diagnostics figure: parity plot + top importances."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    # Parity plot: perfect predictions fall on the dashed diagonal.
    ax1.scatter(y_true, preds, alpha=0.5, color='#636EFA')
    lims = [min(y_true.min(), preds.min()), max(y_true.max(), preds.max())]
    ax1.plot(lims, lims, 'r--', alpha=0.75, zorder=3)
    ax1.set_xlabel("Ground Truth")
    ax1.set_ylabel("Predictions")
    ax1.set_title(f"Prediction Accuracy\nR² = {r2:.4f}")
    # Feature importance: the ten most influential structural features.
    indices = np.argsort(importances)[-10:]
    ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
    ax2.set_yticks(range(len(indices)))
    ax2.set_yticklabels([feature_cols[i] for i in indices])
    ax2.set_title("Top 10 Structural Features")
    fig.tight_layout()
    return fig


def run_model_demo(dataset_name):
    """Train the Random Forest baseline on one dataset and report metrics.

    Returns a (matplotlib figure, markdown summary) pair for the Gradio
    Plot and Markdown outputs of the ML Baseline tab.
    """
    df = get_df(dataset_name)
    feature_cols = get_numeric_feature_cols(df)
    # Ensure target exists; fall back to any expval column if the canonical
    # clean target is missing (unlikely in this schema).
    target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
    work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True)

    X_train, X_test, y_train, y_test = train_test_split(
        work_df[feature_cols], work_df[target], test_size=0.2, random_state=42
    )
    model = RandomForestRegressor(**MODEL_PARAMS)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    r2 = r2_score(y_test, preds)
    mae = mean_absolute_error(y_test, preds)

    fig = _make_demo_figure(y_test, preds, model.feature_importances_, feature_cols, r2)
    summary = f"""
### Model Performance: {dataset_name}
- **R² Score:** {r2:.4f}
- **Mean Absolute Error (MAE):** {mae:.4f}
*This baseline demonstrates that structural circuit metrics (entropy, gate counts, etc.) hold predictive power for quantum expectation values.*
"""
    return fig, summary
def _bar_figure(df, column, title, ylabel, color):
    """One bar chart (one bar per dataset row) via the OO matplotlib API.

    Uses axes/figure methods instead of pyplot global state (``plt.xticks``,
    ``plt.tight_layout``), which implicitly targets "the current figure" and
    is fragile in a long-running server handling multiple requests.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(df["dataset"], df[column], color=color)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis="x", labelrotation=15)
    fig.tight_layout()
    return fig


def load_benchmark():
    """Load the precomputed benchmark CSV and render R²/MAE comparison charts.

    Returns (dataframe, r2_figure, mae_figure). When the CSV is absent, a
    one-row placeholder table and two ``None`` plots are returned so the UI
    still renders instead of crashing.
    """
    path = Path(LOCAL_BENCHMARK_CSV)
    if not path.exists():
        return pd.DataFrame([{"info": "Benchmark file not found"}]), None, None
    df = pd.read_csv(path)
    fig_r2 = _bar_figure(df, "r2", "Cross-Dataset Robustness (R² Score)", "R²", 'skyblue')
    fig_mae = _bar_figure(df, "mae", "Cross-Dataset Error (MAE)", "MAE", 'salmon')
    return df, fig_r2, fig_mae
# =========================================================
# INTERFACE
# =========================================================
# Top-level Gradio UI: three tabs wired to the callbacks defined above.
with gr.Blocks(title="QSBench Unified Explorer", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
# 🌌 QSBench: Quantum Synthetic Benchmark Explorer
**Unified interface for Core, Noise-Affected, and Hardware-Transpiled Quantum Datasets.**
Browse the demo datasets from the QSBench family, run baseline ML models, and analyze noise robustness across different distributions.
"""
    )
    with gr.Tabs():
        # TAB 1: DATA EXPLORER — browse raw rows per dataset pack and split.
        with gr.TabItem("🔎 Dataset Explorer"):
            with gr.Row():
                ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Dataset Pack")
                split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split")
            data_table = gr.Dataframe(label="Sample Data (First 10 rows)", interactive=False)
            # Changing the dataset repopulates the split choices and the preview;
            # changing the split only refilters the preview table.
            ds_selector.change(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
            split_selector.change(filter_explorer_by_split, inputs=[ds_selector, split_selector], outputs=[data_table])
        # TAB 2: ML BASELINE — train the Random Forest on the chosen dataset.
        with gr.TabItem("🤖 ML Baseline Demo"):
            gr.Markdown("Select a dataset and train a Random Forest regressor to predict expectation values from circuit metadata.")
            model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Target Dataset")
            train_btn = gr.Button("Train Baseline Model", variant="primary")
            with gr.Row():
                plot_output = gr.Plot(label="Model Metrics")
                text_output = gr.Markdown(label="Stats")
            train_btn.click(run_model_demo, inputs=[model_ds_selector], outputs=[plot_output, text_output])
        # TAB 3: BENCHMARKING — display precomputed cross-dataset results.
        with gr.TabItem("📊 Noise Robustness Benchmark"):
            gr.Markdown("Analysis of model performance degradation under distribution shifts (Clean → Noisy → Hardware).")
            bench_btn = gr.Button("Load Precomputed Benchmark Results")
            bench_table = gr.Dataframe(interactive=False)
            with gr.Row():
                r2_plot = gr.Plot()
                mae_plot = gr.Plot()
            # load_benchmark takes no inputs; it reads the local CSV.
            bench_btn.click(load_benchmark, outputs=[bench_table, r2_plot, mae_plot])
    # Page footer.
    gr.Markdown(
        """
---
### About QSBench
QSBench is a collection of high-quality synthetic datasets designed for **Quantum Machine Learning** research.
It provides paired ideal/noisy data, structural circuit metrics, and transpilation metadata.
🔗 [Website](https://qsbench.github.io) | 🤗 [Hugging Face](https://huggingface.co/QSBench) | 🛠️ [GitHub](https://github.com/QSBench)
"""
    )
    # Initial load: populate the explorer tab when the page first opens.
    demo.load(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
# Launch the Gradio server only when executed directly (not on import).
if __name__ == "__main__":
    demo.launch()