# QSBench — commit 1971b4a (verified): "Update app.py"
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import logging
import requests
from typing import List, Tuple, Dict, Optional
from datasets import load_dataset
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
# --- CONFIG & LOGGING ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dataset label shown in the UI -> Hugging Face dataset repo id plus the raw URLs of
# that repo's meta.json and report.json.
# NOTE(review): the Core entry uses a "raw/metadata/..." path while the other three use
# "raw/meta/..."; confirm the Core repo really publishes its JSON under that path.
REPO_CONFIG: Dict[str, Dict[str, str]] = {
    "Core (Clean)": {
        "repo": "QSBench/QSBench-Core-v1.0.0-demo",
        "meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Core-v1.0.0-demo/raw/metadata/meta/meta.json",
        "report_url": "https://huggingface.co/datasets/QSBench/QSBench-Core-v1.0.0-demo/raw/metadata/meta/report.json",
    },
    "Depolarizing Noise": {
        "repo": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
        "meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Depolarizing-Demo-v1.0.0/raw/meta/meta/meta.json",
        "report_url": "https://huggingface.co/datasets/QSBench/QSBench-Depolarizing-Demo-v1.0.0/raw/meta/meta/report.json",
    },
    "Amplitude Damping": {
        "repo": "QSBench/QSBench-Amplitude-v1.0.0-demo",
        "meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Amplitude-v1.0.0-demo/raw/meta/meta/meta.json",
        "report_url": "https://huggingface.co/datasets/QSBench/QSBench-Amplitude-v1.0.0-demo/raw/meta/meta/report.json",
    },
    "Transpilation (10q)": {
        "repo": "QSBench/QSBench-Transpilation-v1.0.0-demo",
        "meta_url": "https://huggingface.co/datasets/QSBench/QSBench-Transpilation-v1.0.0-demo/raw/meta/meta/meta.json",
        "report_url": "https://huggingface.co/datasets/QSBench/QSBench-Transpilation-v1.0.0-demo/raw/meta/meta/report.json",
    },
}

# Column names that must never be offered as ML input features (identifiers, source text,
# generation/run settings). Target-like columns are excluded separately by name marker.
NON_FEATURE_COLS = {
    "sample_id", "sample_seed", "circuit_hash", "circuit_signature", "split",
    "circuit_qasm", "qasm_raw", "qasm_transpiled",
    "circuit_type_resolved", "circuit_type_requested", "entanglement", "noise_type", "noise_prob",
    "observable_bases", "observable_mode", "backend_device", "precision_mode", "shots",
    "gpu_requested", "gpu_available",
}

# Dataset label -> {"df", "meta", "report"}; populated lazily by load_all_assets.
_ASSET_CACHE: Dict[str, Dict] = {}
# Seconds to wait for a metadata HTTP response before giving up, so a stalled request
# cannot block the Gradio worker indefinitely.
_HTTP_TIMEOUT_S = 30


def _fetch_json(url: str) -> Dict:
    """Fetch *url* and decode it as JSON; log a warning and return {} on any failure."""
    try:
        response = requests.get(url, timeout=_HTTP_TIMEOUT_S)
        # Without this, a 404 HTML page would only fail later as a JSON decode error.
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers JSON decode failures across requests versions.
        logger.warning("Could not load JSON from %s: %s", url, err)
        return {}


def load_all_assets(key: str) -> Dict:
    """Return the cached assets for dataset *key*, loading them on first use.

    The train split of ``REPO_CONFIG[key]["repo"]`` is loaded with ``load_dataset`` and
    converted to a DataFrame. The meta/report JSON files are fetched best-effort: a
    failure there is logged and stored as ``{}`` instead of aborting the whole load.

    Args:
        key: A key of ``REPO_CONFIG``.

    Returns:
        Dict with keys ``"df"`` (pandas DataFrame), ``"meta"`` and ``"report"``.

    Raises:
        KeyError: if *key* is not in ``REPO_CONFIG``.
    """
    cached = _ASSET_CACHE.get(key)
    if cached is not None:
        return cached

    config = REPO_CONFIG[key]
    logger.info("Fetching %s...", key)
    dataset = load_dataset(config["repo"])
    assets = {
        "df": pd.DataFrame(dataset["train"]),
        "meta": _fetch_json(config["meta_url"]),
        "report": _fetch_json(config["report_url"]),
    }
    _ASSET_CACHE[key] = assets
    return assets
# --- UI LOGIC ---
def load_guide_content():
    """Return the text of GUIDE.md from the working directory.

    If the file does not exist, a Markdown warning string is returned instead so the
    Methodology tab still renders.
    """
    guide_path = "GUIDE.md"
    try:
        handle = open(guide_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return "### ⚠️ Error: GUIDE.md not found. Please ensure it is in the root directory."
    with handle:
        contents = handle.read()
    return contents
def sync_ml_metrics(ds_name: str):
    """Rebuild the metric checkbox choices for the chosen dataset.

    Candidates are the numeric columns of the dataset's DataFrame that are not in
    NON_FEATURE_COLS and whose names do not contain any of "ideal_", "noisy_",
    "error_" or "sign_" (substring match, which keeps target-like columns out).
    A curated list of structural metrics is pre-selected when present; otherwise
    the first five candidates are.

    Returns:
        A ``gr.update`` with the new ``choices`` and ``value``.
    """
    frame = load_all_assets(ds_name)["df"]
    excluded_markers = ("ideal_", "noisy_", "error_", "sign_")

    candidates = []
    for column in frame.select_dtypes(include=[np.number]).columns.tolist():
        if column in NON_FEATURE_COLS:
            continue
        if any(marker in column for marker in excluded_markers):
            continue
        candidates.append(column)

    preferred = ("gate_entropy", "meyer_wallach", "adjacency", "depth", "total_gates", "cx_count")
    preselected = [name for name in preferred if name in candidates]
    if not preselected:
        preselected = candidates[:5]
    return gr.update(choices=candidates, value=preselected)
def train_model(ds_name: str, features: List[str]):
    """Train a multi-output Random Forest on structural metrics and plot its accuracy.

    The regressor predicts whichever of ``ideal_expval_{X,Y,Z}_global`` exist in the
    dataset from the selected ``features``, evaluated on a seeded 80/20 split.

    Args:
        ds_name: Key into ``REPO_CONFIG`` selecting the dataset.
        features: Column names used as model inputs.

    Returns:
        ``(figure, markdown)``: one predicted-vs-true scatter per target plus an
        MAE / R² summary. Returns ``(None, error_markdown)`` when training is impossible.
    """
    if not features:
        return None, "### ❌ Error: No metrics selected."

    df = load_all_assets(ds_name)["df"]
    targets = ["ideal_expval_X_global", "ideal_expval_Y_global", "ideal_expval_Z_global"]
    available_targets = [t for t in targets if t in df.columns]
    if not available_targets:
        return None, "### ❌ Error: Target columns not found in dataset."

    missing = [f for f in features if f not in df.columns]
    if missing:
        return None, f"### ❌ Error: Selected metrics not found in dataset: {', '.join(missing)}"

    train_df = df.dropna(subset=features + available_targets)
    # train_test_split needs at least one sample on each side of the split.
    if len(train_df) < 2:
        return None, "### ❌ Error: Not enough complete rows to train a model."

    n_targets = len(available_targets)
    X, y = train_df[features], train_df[available_targets]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # A one-column 2-D y triggers sklearn's column-vector warning; fit on a 1-D Series instead.
    y_fit = y_train.iloc[:, 0] if n_targets == 1 else y_train
    model = RandomForestRegressor(
        n_estimators=100, max_depth=10, n_jobs=-1, random_state=42
    ).fit(X_train, y_fit)
    # predict() is 1-D for a single output and 2-D otherwise; normalise to
    # (n_samples, n_targets) so preds[:, i] works for any number of targets.
    preds = np.asarray(model.predict(X_test)).reshape(len(X_test), n_targets)

    sns.set_theme(style="whitegrid", context="talk")
    # squeeze=False always yields a 2-D array of Axes, even for a single target.
    fig, axes = plt.subplots(1, n_targets, figsize=(8 * n_targets, 7), squeeze=False)
    colors = ["#2980b9", "#8e44ad", "#2c3e50"]
    summary_lines = ["### 📊 Multi-Target Performance Summary"]

    for i, (ax, target_col) in enumerate(zip(axes[0], available_targets)):
        y_true_axis = y_test.iloc[:, i]
        y_pred_axis = preds[:, i]
        r2 = r2_score(y_true_axis, y_pred_axis)
        mae = mean_absolute_error(y_true_axis, y_pred_axis)

        ax.scatter(y_true_axis, y_pred_axis, alpha=0.3, color=colors[i % len(colors)])
        ax.plot([-1, 1], [-1, 1], "r--", lw=2)  # y = x reference line
        ax.set_title(f"Target: {target_col}\n(R²: {r2:.3f})")
        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Prediction")
        ax.set_xlim([-1.1, 1.1])
        ax.set_ylim([-1.1, 1.1])

        axis_name = target_col.split("_")[2]  # "ideal_expval_X_global" -> "X"
        summary_lines.append(f"- **{axis_name}-Axis:** MAE = {mae:.4f} | R² = {r2:.3f}")

    fig.tight_layout(pad=3.0)
    # Drop the figure from pyplot's registry so repeated clicks don't accumulate open
    # figures; the Figure object itself stays valid for Gradio to render.
    plt.close(fig)
    return fig, "\n".join(summary_lines) + "\n"
def _first_cell(frame: pd.DataFrame, column: str):
    """Return the first value of *column* in *frame*, or "// N/A" if missing or empty."""
    if column in frame.columns and not frame.empty:
        return frame[column].iloc[0]
    return "// N/A"


def update_explorer(ds_name: str, split_name: str):
    """Refresh the Explorer tab for a dataset/split selection.

    Args:
        ds_name: Key into ``REPO_CONFIG`` selecting the dataset.
        split_name: Value of the ``split`` column to display. If no row matches (e.g.
            after switching datasets), the first split present in the data is used.

    Returns:
        A 5-tuple for ``(sp_sel, data_view, c_raw, c_tr, meta_txt)``:
        - a dropdown update listing the available splits;
        - the first 10 matching rows;
        - the first row's raw QASM and its transpiled QASM, each ``"// N/A"`` when
          unavailable;
        - a Markdown heading.
    """
    df = load_all_assets(ds_name)["df"]

    if "split" in df.columns:
        unique_splits = df["split"].unique().tolist()
        filtered_df = df[df["split"] == split_name]
        # Fall back to the first available split; guarded so an empty frame
        # doesn't raise IndexError on unique_splits[0].
        if filtered_df.empty and unique_splits:
            split_name = unique_splits[0]
            filtered_df = df[df["split"] == split_name]
    else:
        # Without a split column the whole frame (loaded from the HF "train" split)
        # is shown, so report "train" rather than a stale selection.
        unique_splits = ["train"]
        split_name = "train"
        filtered_df = df

    display_df = filtered_df.head(10)
    return (
        gr.update(choices=unique_splits, value=split_name),
        display_df,
        _first_cell(display_df, "qasm_raw"),
        _first_cell(display_df, "qasm_transpiled"),
        f"### 📋 {ds_name} Explorer",
    )
# --- INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
    gr.Markdown("# 🌌 QSBench: Quantum Analytics Hub")

    with gr.Tabs():
        with gr.TabItem("🔎 Explorer"):
            meta_txt = gr.Markdown("### Loading...")
            with gr.Row():
                ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
                sp_sel = gr.Dropdown(["train"], value="train", label="Split")
            data_view = gr.Dataframe(interactive=False)
            with gr.Row():
                # NOTE(review): "python" appears to be a stand-in highlighter for QASM text.
                c_raw = gr.Code(label="Source QASM", language="python")
                c_tr = gr.Code(label="Transpiled QASM", language="python")

        with gr.TabItem("🤖 ML Training"):
            gr.Markdown("Multi-target regression: predicting X, Y, and Z components simultaneously.")
            with gr.Row():
                with gr.Column(scale=1):
                    ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
                    ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics", choices=[])
                    train_btn = gr.Button("Train Multi-Output Model", variant="primary")
                with gr.Column(scale=2):
                    p_out = gr.Plot()
                    t_out = gr.Markdown()

        with gr.TabItem("📖 Methodology"):
            meth_md = gr.Markdown(value=load_guide_content())

    gr.Markdown(
        "---\n"
        "### 🔗 Project Links\n"
        "[**🌐 Website**](https://qsbench.github.io) | "
        "[**🤗 Hugging Face**](https://huggingface.co/QSBench) | "
        "[**💻 GitHub**](https://github.com/QSBench)\n"
    )

    # --- EVENTS ---
    explorer_outputs = [sp_sel, data_view, c_raw, c_tr, meta_txt]
    # Explorer refreshes on either dropdown. update_explorer also writes sp_sel, so a
    # split fallback may re-trigger sp_sel.change once with the corrected value.
    ds_sel.change(update_explorer, [ds_sel, sp_sel], explorer_outputs)
    sp_sel.change(update_explorer, [ds_sel, sp_sel], explorer_outputs)

    # ML tab: refresh metric choices when the dataset changes; train on click.
    ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
    train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])

    # Initial population of both tabs when the page loads.
    demo.load(update_explorer, [ds_sel, sp_sel], explorer_outputs)
    demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])

if __name__ == "__main__":
    demo.launch()