# QSBench Unified Explorer — Gradio app (Hugging Face Space entry point).
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
| # ========================================================= | |
| # CONFIG & REPOSITORIES | |
| # ========================================================= | |
| DATASET_MAP = { | |
| "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo", | |
| "Depolarizing Noise": "QSBench/QSBench-Depolarizing-v1.0.0-demo", | |
| "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo", | |
| "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo" | |
| } | |
| LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv" | |
| TARGET_COL = "ideal_expval_Z_global" | |
| EXCLUDE_COLS = { | |
| "sample_id", "sample_seed", "split", | |
| "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global", | |
| "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global", | |
| "error_Z_global", "error_X_global", "error_Y_global", | |
| "sign_ideal_Z_global", "sign_noisy_Z_global", | |
| "sign_ideal_X_global", "sign_noisy_X_global", | |
| "sign_ideal_Y_global", "sign_noisy_Y_global", | |
| } | |
| MODEL_PARAMS = dict( | |
| n_estimators=80, | |
| max_depth=10, | |
| min_samples_leaf=2, | |
| random_state=42, | |
| n_jobs=-1, | |
| ) | |
| # Global cache to avoid redundant downloads | |
| dataset_cache = {} | |
| # ========================================================= | |
| # DATA UTILS | |
| # ========================================================= | |
| def get_df(dataset_key): | |
| if dataset_key not in dataset_cache: | |
| repo_id = DATASET_MAP[dataset_key] | |
| print(f"Downloading {repo_id}...") | |
| ds = load_dataset(repo_id) | |
| dataset_cache[dataset_key] = pd.DataFrame(ds["train"]) | |
| return dataset_cache[dataset_key] | |
| def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]: | |
| numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() | |
| return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_")] | |
| # ========================================================= | |
| # TAB FUNCTIONS | |
| # ========================================================= | |
| def update_explorer(dataset_name): | |
| df = get_df(dataset_name) | |
| splits = df["split"].unique().tolist() if "split" in df.columns else ["all"] | |
| return gr.update(choices=splits, value=splits[0]), df.head(10) | |
| def filter_explorer_by_split(dataset_name, split_name): | |
| df = get_df(dataset_name) | |
| if "split" in df.columns: | |
| return df[df["split"] == split_name].head(10) | |
| return df.head(10) | |
| def run_model_demo(dataset_name): | |
| df = get_df(dataset_name) | |
| feature_cols = get_numeric_feature_cols(df) | |
| # Ensure target exists, fallback to noisy if clean is missing (though unlikely in your schema) | |
| target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0] | |
| work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True) | |
| X = work_df[feature_cols] | |
| y = work_df[target] | |
| X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) | |
| model = RandomForestRegressor(**MODEL_PARAMS) | |
| model.fit(X_train, y_train) | |
| preds = model.predict(X_test) | |
| r2 = r2_score(y_test, preds) | |
| mae = mean_absolute_error(y_test, preds) | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) | |
| # Parity Plot | |
| ax1.scatter(y_test, preds, alpha=0.5, color='#636EFA') | |
| lims = [min(y_test.min(), preds.min()), max(y_test.max(), preds.max())] | |
| ax1.plot(lims, lims, 'r--', alpha=0.75, zorder=3) | |
| ax1.set_xlabel("Ground Truth") | |
| ax1.set_ylabel("Predictions") | |
| ax1.set_title(f"Prediction Accuracy\nR² = {r2:.4f}") | |
| # Feature Importance | |
| importances = model.feature_importances_ | |
| indices = np.argsort(importances)[-10:] | |
| ax2.barh(range(len(indices)), importances[indices], color='#EF553B') | |
| ax2.set_yticks(range(len(indices))) | |
| ax2.set_yticklabels([feature_cols[i] for i in indices]) | |
| ax2.set_title("Top 10 Structural Features") | |
| plt.tight_layout() | |
| summary = f""" | |
| ### Model Performance: {dataset_name} | |
| - **R² Score:** {r2:.4f} | |
| - **Mean Absolute Error (MAE):** {mae:.4f} | |
| *This baseline demonstrates that structural circuit metrics (entropy, gate counts, etc.) hold predictive power for quantum expectation values.* | |
| """ | |
| return fig, summary | |
| def load_benchmark(): | |
| path = Path(LOCAL_BENCHMARK_CSV) | |
| if not path.exists(): | |
| return pd.DataFrame([{"info": "Benchmark file not found"}]), None, None | |
| df = pd.read_csv(path) | |
| # R2 Plot | |
| fig_r2, ax = plt.subplots(figsize=(8, 4)) | |
| ax.bar(df["dataset"], df["r2"], color='skyblue') | |
| ax.set_title("Cross-Dataset Robustness (R² Score)") | |
| ax.set_ylabel("R²") | |
| plt.xticks(rotation=15) | |
| plt.tight_layout() | |
| # MAE Plot | |
| fig_mae, ax = plt.subplots(figsize=(8, 4)) | |
| ax.bar(df["dataset"], df["mae"], color='salmon') | |
| ax.set_title("Cross-Dataset Error (MAE)") | |
| ax.set_ylabel("MAE") | |
| plt.xticks(rotation=15) | |
| plt.tight_layout() | |
| return df, fig_r2, fig_mae | |
| # ========================================================= | |
| # INTERFACE | |
| # ========================================================= | |
| with gr.Blocks(title="QSBench Unified Explorer", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| """ | |
| # 🌌 QSBench: Quantum Synthetic Benchmark Explorer | |
| **Unified interface for Core, Noise-Affected, and Hardware-Transpiled Quantum Datasets.** | |
| Browse the demo datasets from the QSBench family, run baseline ML models, and analyze noise robustness across different distributions. | |
| """ | |
| ) | |
| with gr.Tabs(): | |
| # TAB 1: DATA EXPLORER | |
| with gr.TabItem("🔎 Dataset Explorer"): | |
| with gr.Row(): | |
| ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Dataset Pack") | |
| split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split") | |
| data_table = gr.Dataframe(label="Sample Data (First 10 rows)", interactive=False) | |
| ds_selector.change(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table]) | |
| split_selector.change(filter_explorer_by_split, inputs=[ds_selector, split_selector], outputs=[data_table]) | |
| # TAB 2: ML BASELINE | |
| with gr.TabItem("🤖 ML Baseline Demo"): | |
| gr.Markdown("Select a dataset and train a Random Forest regressor to predict expectation values from circuit metadata.") | |
| model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Target Dataset") | |
| train_btn = gr.Button("Train Baseline Model", variant="primary") | |
| with gr.Row(): | |
| plot_output = gr.Plot(label="Model Metrics") | |
| text_output = gr.Markdown(label="Stats") | |
| train_btn.click(run_model_demo, inputs=[model_ds_selector], outputs=[plot_output, text_output]) | |
| # TAB 3: BENCHMARKING | |
| with gr.TabItem("📊 Noise Robustness Benchmark"): | |
| gr.Markdown("Analysis of model performance degradation under distribution shifts (Clean → Noisy → Hardware).") | |
| bench_btn = gr.Button("Load Precomputed Benchmark Results") | |
| bench_table = gr.Dataframe(interactive=False) | |
| with gr.Row(): | |
| r2_plot = gr.Plot() | |
| mae_plot = gr.Plot() | |
| bench_btn.click(load_benchmark, outputs=[bench_table, r2_plot, mae_plot]) | |
| gr.Markdown( | |
| """ | |
| --- | |
| ### About QSBench | |
| QSBench is a collection of high-quality synthetic datasets designed for **Quantum Machine Learning** research. | |
| It provides paired ideal/noisy data, structural circuit metrics, and transpilation metadata. | |
| 🔗 [Website](https://qsbench.github.io) | 🤗 [Hugging Face](https://huggingface.co/QSBench) | 🛠️ [GitHub](https://github.com/QSBench) | |
| """ | |
| ) | |
| # Initial load | |
| demo.load(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table]) | |
| if __name__ == "__main__": | |
| demo.launch() |