import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import logging
import requests
from typing import List, Tuple, Dict, Optional
from datasets import load_dataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
# --- CONFIG & LOGGING ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
REPO_CONFIG = {
"Core (Clean)": {
"repo": "QSBench/QSBench-Core-v1.0.0-demo",
"meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Core-v1.0.0-demo/raw/metadata/meta/meta.json",
"report_url": "https://huggingface.co/datasets/QSBench/QSBench-Core-v1.0.0-demo/raw/metadata/meta/report.json"
},
"Depolarizing Noise": {
"repo": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
"meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Depolarizing-Demo-v1.0.0/raw/meta/meta/meta.json",
"report_url": "https://huggingface.co/datasets/QSBench/QSBench-Depolarizing-Demo-v1.0.0/raw/meta/meta/report.json"
},
"Amplitude Damping": {
"repo": "QSBench/QSBench-Amplitude-v1.0.0-demo",
"meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Amplitude-v1.0.0-demo/raw/meta/meta/meta.json",
"report_url": "https://huggingface.co/datasets/QSBench/QSBench-Amplitude-v1.0.0-demo/raw/meta/meta/report.json"
},
"Transpilation (10q)": {
"repo": "QSBench/QSBench-Transpilation-v1.0.0-demo",
"meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Transpilation-v1.0.0-demo/raw/meta/meta/meta.json",
"report_url": "https://huggingface.co/datasets/QSBench/QSBench-Transpilation-v1.0.0-demo/raw/meta/meta/report.json"
}
}
# Columns that are NOT features (system, categorical, or targets)
NON_FEATURE_COLS = {
"sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
"qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
"noise_type", "noise_prob", "observable_bases", "observable_mode", "backend_device",
"precision_mode", "circuit_signature", "entanglement", "shots", "gpu_requested", "gpu_available"
}
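# In-memory cache keyed by config name, so each dataset and its metadata are fetched only once per session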
_ASSET_CACHE = {}
def load_all_assets(key: str) -> Dict:
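    """Fetch the dataset plus its meta.json and report.json for a config key, caching the bundle in memory."""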
    if key not in _ASSET_CACHE:
        logger.info(f"Fetching {key}...")
        ds = load_dataset(REPO_CONFIG[key]["repo"])
        meta = requests.get(REPO_CONFIG[key]["meta_url"]).json()
        report = requests.get(REPO_CONFIG[key]["report_url"]).json()
        _ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
    return _ASSET_CACHE[key]
# --- UI LOGIC ---
def get_methodology_content(ds_name: str):
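    """Build a Markdown methodology summary from the dataset's meta and report JSON."""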
    assets = load_all_assets(ds_name)
    meta, report = assets["meta"], assets["report"]
    params = meta.get("parameters", {})
    families = report.get("families", {})
    fam_table = "| Family | Samples | Description |\n|:---|:---|:---|\n"
    for f, count in families.items():
        fam_table += f"| {f.upper()} | {count} | Synthetic {f} circuits |\n"
    return f"""
## 📖 Methodology: {meta.get('dataset_version')}
**Generator:** QSBench v{meta.get('generator_version')}
**Config:** {params.get('n_qubits')} Qubits | Depth {params.get('depth')} | Noise `{params.get('noise')}` (p={params.get('noise_prob')})
### Circuit Family Coverage
{fam_table}
"""
def sync_ml_metrics(ds_name: str):
"""Dynamically finds all available numerical metrics (features) from CSV/Dataset"""
assets = load_all_assets(ds_name)
df = assets["df"]
# Extract all numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
# Filter: remove system IDs and targets (anything starting with ideal/noisy/error/sign)
valid_features = [
c for c in numeric_cols
if c not in NON_FEATURE_COLS
and not any(prefix in c for prefix in ["ideal_", "noisy_", "error_", "sign_"])
]
# Priority metrics for "default" selection
top_tier = ["gate_entropy", "meyer_wallach", "adjacency", "depth", "total_gates", "cx_count"]
defaults = [f for f in top_tier if f in valid_features]
return gr.update(choices=valid_features, value=defaults or valid_features[:5])
def train_model(ds_name: str, features: List[str]):
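    """Train a RandomForest baseline on the selected metrics; returns a diagnostics figure and the test-set MAE."""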
    if not features:
        return None, "### ❌ Error: No metrics selected."
    assets = load_all_assets(ds_name)
    df = assets["df"]
    # Use global Z value as target
    target = "ideal_expval_Z_global"
    train_df = df.dropna(subset=features + [target])
    X, y = train_df[features], train_df[target]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
    preds = model.predict(X_test)
    sns.set_theme(style="whitegrid", context="talk")
    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    # 1. Prediction vs Reality
    axes[0].scatter(y_test, preds, alpha=0.3, color='#2c3e50')
    axes[0].plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
    axes[0].set_title(f"Accuracy (R²: {r2_score(y_test, preds):.3f})")
    axes[0].set_xlabel("Ideal ExpVal")
    axes[0].set_ylabel("Predicted")
    # 2. Feature Importance
    imp = model.feature_importances_
    # Take top 10 if there are many, or all if few
    top_n = min(len(features), 10)
    idx = np.argsort(imp)[-top_n:]
    axes[1].barh([features[i] for i in idx], imp[idx], color='#27ae60')
    axes[1].set_title(f"Top {top_n} Metrics Importance")
    # 3. Residuals
    sns.histplot(y_test - preds, kde=True, ax=axes[2], color='#d35400')
    axes[2].set_title("Residuals (Error Distribution)")
    plt.tight_layout(pad=3.0)
    return fig, f"**Mean Absolute Error (MAE):** {mean_absolute_error(y_test, preds):.4f}"
def update_explorer(ds_name: str, split_name: str):
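    """Refresh the Explorer tab: split choices, a 10-row preview of the chosen split, and the first sample's raw/transpiled QASM."""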
    assets = load_all_assets(ds_name)
    df = assets["df"]
    splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
    display_df = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
    raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns else "// N/A"
    tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns else "// N/A"
    return gr.update(choices=splits), display_df, raw, tr, f"### 📋 {ds_name} Explorer"
# --- INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
gr.Markdown("# 🌌 QSBench: Quantum Analytics Hub")
with gr.Tabs():
with gr.TabItem("πŸ”Ž Explorer"):
meta_txt = gr.Markdown("### Loading...")
with gr.Row():
ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
sp_sel = gr.Dropdown(["train"], value="train", label="Split")
data_view = gr.Dataframe(interactive=False)
with gr.Row():
c_raw = gr.Code(label="Source QASM", language="python")
c_tr = gr.Code(label="Transpiled QASM", language="python")
with gr.TabItem("πŸ€– ML Training"):
with gr.Row():
with gr.Column(scale=1):
ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
# Dynamic metrics list extracted from CSV
ml_feat_sel = gr.CheckboxGroup(label="Available Metrics (extracted from CSV)", choices=[])
train_btn = gr.Button("Execute Baseline", variant="primary")
with gr.Column(scale=2):
p_out = gr.Plot()
t_out = gr.Markdown()
with gr.TabItem("πŸ“– Methodology"):
meth_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset Details")
meth_md = gr.Markdown()
gr.Markdown(f"""
---
### πŸ”— Project Links
[**🌐 Website**](https://qsbench.github.io) | [**πŸ€— Hugging Face**](https://huggingface.co/QSBench) | [**πŸ’» GitHub**](https://github.com/QSBench)
""")
    # --- EVENTS ---
    # Explorer
    ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
    # ML Tab: Dynamic metrics update
    ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
    train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
    # Methodology
    meth_ds_sel.change(get_methodology_content, [meth_ds_sel], [meth_md])
    # Initial Load
    demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
    demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
    demo.load(get_methodology_content, [meth_ds_sel], [meth_md])
if __name__ == "__main__":
    demo.launch()