Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -75,7 +75,7 @@ def sync_ml_metrics(ds_name: str):
|
|
| 75 |
# Extract all numeric columns
|
| 76 |
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 77 |
|
| 78 |
-
# Filter: remove system IDs and
|
| 79 |
valid_features = [
|
| 80 |
c for c in numeric_cols
|
| 81 |
if c not in NON_FEATURE_COLS
|
|
@@ -89,43 +89,56 @@ def sync_ml_metrics(ds_name: str):
|
|
| 89 |
return gr.update(choices=valid_features, value=defaults or valid_features[:5])
|
| 90 |
|
| 91 |
def train_model(ds_name: str, features: List[str]):
|
|
|
|
| 92 |
if not features: return None, "### ❌ Error: No metrics selected."
|
| 93 |
assets = load_all_assets(ds_name)
|
| 94 |
df = assets["df"]
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
| 102 |
|
|
|
|
| 103 |
model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
|
| 104 |
preds = model.predict(X_test)
|
| 105 |
|
| 106 |
sns.set_theme(style="whitegrid", context="talk")
|
| 107 |
-
fig, axes = plt.subplots(1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
plt.tight_layout(pad=3.0)
|
| 128 |
-
return fig,
|
| 129 |
|
| 130 |
def update_explorer(ds_name: str, split_name: str):
|
| 131 |
assets = load_all_assets(ds_name)
|
|
@@ -153,19 +166,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
|
|
| 153 |
c_raw = gr.Code(label="Source QASM", language="python")
|
| 154 |
c_tr = gr.Code(label="Transpiled QASM", language="python")
|
| 155 |
|
| 156 |
-
with gr.TabItem("🤖 ML Training"):
|
|
|
|
| 157 |
with gr.Row():
|
| 158 |
with gr.Column(scale=1):
|
| 159 |
ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
train_btn = gr.Button("Execute Baseline", variant="primary")
|
| 163 |
with gr.Column(scale=2):
|
| 164 |
p_out = gr.Plot()
|
| 165 |
t_out = gr.Markdown()
|
| 166 |
|
| 167 |
with gr.TabItem("📖 Methodology"):
|
| 168 |
-
# Automatically loads content from GUIDE.md
|
| 169 |
meth_md = gr.Markdown(value=load_guide_content())
|
| 170 |
|
| 171 |
gr.Markdown(f"""
|
|
@@ -175,10 +187,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
|
|
| 175 |
""")
|
| 176 |
|
| 177 |
# --- EVENTS ---
|
| 178 |
-
# Explorer
|
| 179 |
ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 180 |
-
|
| 181 |
-
# ML Tab: Dynamic metrics update
|
| 182 |
ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 183 |
train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
|
| 184 |
|
|
|
|
| 75 |
# Extract all numeric columns
|
| 76 |
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 77 |
|
| 78 |
+
# Filter: remove system IDs and ALL target components (X, Y, Z, global, local, error)
|
| 79 |
valid_features = [
|
| 80 |
c for c in numeric_cols
|
| 81 |
if c not in NON_FEATURE_COLS
|
|
|
|
| 89 |
return gr.update(choices=valid_features, value=defaults or valid_features[:5])
|
| 90 |
|
| 91 |
def train_model(ds_name: str, features: List[str]):
|
| 92 |
+
"""Trains a Multi-Target Regressor to predict X, Y, and Z expectation values."""
|
| 93 |
if not features: return None, "### ❌ Error: No metrics selected."
|
| 94 |
assets = load_all_assets(ds_name)
|
| 95 |
df = assets["df"]
|
| 96 |
|
| 97 |
+
# Multi-Target: Prediction of all global expectation values
|
| 98 |
+
targets = ["ideal_expval_X_global", "ideal_expval_Y_global", "ideal_expval_Z_global"]
|
| 99 |
|
| 100 |
+
# Filter targets that actually exist in the dataframe (handle cases where some might be missing)
|
| 101 |
+
available_targets = [t for t in targets if t in df.columns]
|
| 102 |
+
if not available_targets:
|
| 103 |
+
return None, "### ❌ Error: Target columns not found in dataset."
|
| 104 |
+
|
| 105 |
+
train_df = df.dropna(subset=features + available_targets)
|
| 106 |
+
X, y = train_df[features], train_df[available_targets]
|
| 107 |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
| 108 |
|
| 109 |
+
# RandomForestRegressor supports multi-output regression out of the box
|
| 110 |
model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
|
| 111 |
preds = model.predict(X_test)
|
| 112 |
|
| 113 |
sns.set_theme(style="whitegrid", context="talk")
|
| 114 |
+
fig, axes = plt.subplots(1, len(available_targets), figsize=(8 * len(available_targets), 7))
|
| 115 |
+
|
| 116 |
+
# If only one target available, axes is not an array
|
| 117 |
+
if len(available_targets) == 1: axes = [axes]
|
| 118 |
+
|
| 119 |
+
summary_text = "### 📊 Multi-Target Performance Summary\n"
|
| 120 |
|
| 121 |
+
colors = ['#2980b9', '#8e44ad', '#2c3e50'] # Blue for X, Purple for Y, Dark for Z
|
| 122 |
+
|
| 123 |
+
for i, target_col in enumerate(available_targets):
|
| 124 |
+
y_true_axis = y_test.iloc[:, i]
|
| 125 |
+
y_pred_axis = preds[:, i]
|
| 126 |
+
|
| 127 |
+
r2 = r2_score(y_true_axis, y_pred_axis)
|
| 128 |
+
mae = mean_absolute_error(y_true_axis, y_pred_axis)
|
| 129 |
+
|
| 130 |
+
# Parity Plot for each basis
|
| 131 |
+
axes[i].scatter(y_true_axis, y_pred_axis, alpha=0.3, color=colors[i % len(colors)])
|
| 132 |
+
axes[i].plot([-1, 1], [-1, 1], 'r--', lw=2) # Theoretical range of expectation values is [-1, 1]
|
| 133 |
+
axes[i].set_title(f"Target: {target_col}\n(R²: {r2:.3f})")
|
| 134 |
+
axes[i].set_xlabel("Ground Truth (Ideal)"); axes[i].set_ylabel("Model Prediction")
|
| 135 |
+
axes[i].set_xlim([-1.1, 1.1]); axes[i].set_ylim([-1.1, 1.1])
|
| 136 |
+
|
| 137 |
+
axis_name = target_col.split('_')[2] # Extracts X, Y, or Z
|
| 138 |
+
summary_text += f"- **{axis_name}-Axis:** MAE = {mae:.4f} | R² = {r2:.3f}\n"
|
| 139 |
|
| 140 |
plt.tight_layout(pad=3.0)
|
| 141 |
+
return fig, summary_text
|
| 142 |
|
| 143 |
def update_explorer(ds_name: str, split_name: str):
|
| 144 |
assets = load_all_assets(ds_name)
|
|
|
|
| 166 |
c_raw = gr.Code(label="Source QASM", language="python")
|
| 167 |
c_tr = gr.Code(label="Transpiled QASM", language="python")
|
| 168 |
|
| 169 |
+
with gr.TabItem("🤖 ML Training (Multi-Target)"):
|
| 170 |
+
gr.Markdown("Training models to predict the full Bloch vector expectation values (X, Y, Z) simultaneously.")
|
| 171 |
with gr.Row():
|
| 172 |
with gr.Column(scale=1):
|
| 173 |
ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
|
| 174 |
+
ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics (Features)", choices=[])
|
| 175 |
+
train_btn = gr.Button("Train Multi-Output Model", variant="primary")
|
|
|
|
| 176 |
with gr.Column(scale=2):
|
| 177 |
p_out = gr.Plot()
|
| 178 |
t_out = gr.Markdown()
|
| 179 |
|
| 180 |
with gr.TabItem("📖 Methodology"):
|
|
|
|
| 181 |
meth_md = gr.Markdown(value=load_guide_content())
|
| 182 |
|
| 183 |
gr.Markdown(f"""
|
|
|
|
| 187 |
""")
|
| 188 |
|
| 189 |
# --- EVENTS ---
|
|
|
|
| 190 |
ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
|
|
|
|
|
|
| 191 |
ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 192 |
train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
|
| 193 |
|