QSBench committed on
Commit
c8aba73
·
verified ·
1 Parent(s): 382832c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -32
app.py CHANGED
@@ -75,7 +75,7 @@ def sync_ml_metrics(ds_name: str):
75
  # Extract all numeric columns
76
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
77
 
78
- # Filter: remove system IDs and targets (anything starting with ideal/noisy/error/sign)
79
  valid_features = [
80
  c for c in numeric_cols
81
  if c not in NON_FEATURE_COLS
@@ -89,43 +89,56 @@ def sync_ml_metrics(ds_name: str):
89
  return gr.update(choices=valid_features, value=defaults or valid_features[:5])
90
 
91
  def train_model(ds_name: str, features: List[str]):
 
92
  if not features: return None, "### ❌ Error: No metrics selected."
93
  assets = load_all_assets(ds_name)
94
  df = assets["df"]
95
 
96
- # Use global Z value as target
97
- target = "ideal_expval_Z_global"
98
 
99
- train_df = df.dropna(subset=features + [target])
100
- X, y = train_df[features], train_df[target]
 
 
 
 
 
101
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
102
 
 
103
  model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
104
  preds = model.predict(X_test)
105
 
106
  sns.set_theme(style="whitegrid", context="talk")
107
- fig, axes = plt.subplots(1, 3, figsize=(24, 8))
 
 
 
 
 
108
 
109
- # 1. Prediction vs Reality
110
- axes[0].scatter(y_test, preds, alpha=0.3, color='#2c3e50')
111
- axes[0].plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
112
- axes[0].set_title(f"Accuracy (R²: {r2_score(y_test, preds):.3f})")
113
- axes[0].set_xlabel("Ideal ExpVal"); axes[0].set_ylabel("Predicted")
114
-
115
- # 2. Feature Importance
116
- imp = model.feature_importances_
117
- # Take top 10 if there are many, or all if few
118
- top_n = min(len(features), 10)
119
- idx = np.argsort(imp)[-top_n:]
120
- axes[1].barh([features[i] for i in idx], imp[idx], color='#27ae60')
121
- axes[1].set_title(f"Top {top_n} Metrics Importance")
122
-
123
- # 3. Residuals
124
- sns.histplot(y_test - preds, kde=True, ax=axes[2], color='#d35400')
125
- axes[2].set_title("Residuals (Error Distribution)")
 
126
 
127
  plt.tight_layout(pad=3.0)
128
- return fig, f"**Mean Absolute Error (MAE):** {mean_absolute_error(y_test, preds):.4f}"
129
 
130
  def update_explorer(ds_name: str, split_name: str):
131
  assets = load_all_assets(ds_name)
@@ -153,19 +166,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
153
  c_raw = gr.Code(label="Source QASM", language="python")
154
  c_tr = gr.Code(label="Transpiled QASM", language="python")
155
 
156
- with gr.TabItem("🤖 ML Training"):
 
157
  with gr.Row():
158
  with gr.Column(scale=1):
159
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
160
- # Dynamic metrics list extracted from CSV
161
- ml_feat_sel = gr.CheckboxGroup(label="Available Metrics (extracted from CSV)", choices=[])
162
- train_btn = gr.Button("Execute Baseline", variant="primary")
163
  with gr.Column(scale=2):
164
  p_out = gr.Plot()
165
  t_out = gr.Markdown()
166
 
167
  with gr.TabItem("📖 Methodology"):
168
- # Automatically loads content from GUIDE.md
169
  meth_md = gr.Markdown(value=load_guide_content())
170
 
171
  gr.Markdown(f"""
@@ -175,10 +187,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
175
  """)
176
 
177
  # --- EVENTS ---
178
- # Explorer
179
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
180
-
181
- # ML Tab: Dynamic metrics update
182
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
183
  train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
184
 
 
75
  # Extract all numeric columns
76
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
77
 
78
+ # Filter: remove system IDs and ALL target components (X, Y, Z, global, local, error)
79
  valid_features = [
80
  c for c in numeric_cols
81
  if c not in NON_FEATURE_COLS
 
89
  return gr.update(choices=valid_features, value=defaults or valid_features[:5])
90
 
91
  def train_model(ds_name: str, features: List[str]):
92
+ """Trains a Multi-Target Regressor to predict X, Y, and Z expectation values."""
93
  if not features: return None, "### ❌ Error: No metrics selected."
94
  assets = load_all_assets(ds_name)
95
  df = assets["df"]
96
 
97
+ # Multi-Target: Prediction of all global expectation values
98
+ targets = ["ideal_expval_X_global", "ideal_expval_Y_global", "ideal_expval_Z_global"]
99
 
100
+ # Filter targets that actually exist in the dataframe (handle cases where some might be missing)
101
+ available_targets = [t for t in targets if t in df.columns]
102
+ if not available_targets:
103
+ return None, "### ❌ Error: Target columns not found in dataset."
104
+
105
+ train_df = df.dropna(subset=features + available_targets)
106
+ X, y = train_df[features], train_df[available_targets]
107
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
108
 
109
+ # RandomForestRegressor supports multi-output regression out of the box
110
  model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
111
  preds = model.predict(X_test)
112
 
113
  sns.set_theme(style="whitegrid", context="talk")
114
+ fig, axes = plt.subplots(1, len(available_targets), figsize=(8 * len(available_targets), 7))
115
+
116
+ # If only one target available, axes is not an array
117
+ if len(available_targets) == 1: axes = [axes]
118
+
119
+ summary_text = "### 📊 Multi-Target Performance Summary\n"
120
 
121
+ colors = ['#2980b9', '#8e44ad', '#2c3e50'] # Blue for X, Purple for Y, Dark for Z
122
+
123
+ for i, target_col in enumerate(available_targets):
124
+ y_true_axis = y_test.iloc[:, i]
125
+ y_pred_axis = preds[:, i]
126
+
127
+ r2 = r2_score(y_true_axis, y_pred_axis)
128
+ mae = mean_absolute_error(y_true_axis, y_pred_axis)
129
+
130
+ # Parity Plot for each basis
131
+ axes[i].scatter(y_true_axis, y_pred_axis, alpha=0.3, color=colors[i % len(colors)])
132
+ axes[i].plot([-1, 1], [-1, 1], 'r--', lw=2) # Theoretical range of expectation values is [-1, 1]
133
+ axes[i].set_title(f"Target: {target_col}\n(R²: {r2:.3f})")
134
+ axes[i].set_xlabel("Ground Truth (Ideal)"); axes[i].set_ylabel("Model Prediction")
135
+ axes[i].set_xlim([-1.1, 1.1]); axes[i].set_ylim([-1.1, 1.1])
136
+
137
+ axis_name = target_col.split('_')[2] # Extracts X, Y, or Z
138
+ summary_text += f"- **{axis_name}-Axis:** MAE = {mae:.4f} | R² = {r2:.3f}\n"
139
 
140
  plt.tight_layout(pad=3.0)
141
+ return fig, summary_text
142
 
143
  def update_explorer(ds_name: str, split_name: str):
144
  assets = load_all_assets(ds_name)
 
166
  c_raw = gr.Code(label="Source QASM", language="python")
167
  c_tr = gr.Code(label="Transpiled QASM", language="python")
168
 
169
+ with gr.TabItem("🤖 ML Training (Multi-Target)"):
170
+ gr.Markdown("Training models to predict the full Bloch vector expectation values (X, Y, Z) simultaneously.")
171
  with gr.Row():
172
  with gr.Column(scale=1):
173
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
174
+ ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics (Features)", choices=[])
175
+ train_btn = gr.Button("Train Multi-Output Model", variant="primary")
 
176
  with gr.Column(scale=2):
177
  p_out = gr.Plot()
178
  t_out = gr.Markdown()
179
 
180
  with gr.TabItem("📖 Methodology"):
 
181
  meth_md = gr.Markdown(value=load_guide_content())
182
 
183
  gr.Markdown(f"""
 
187
  """)
188
 
189
  # --- EVENTS ---
 
190
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
 
 
191
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
192
  train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
193