QSBench commited on
Commit
1971b4a
·
verified ·
1 Parent(s): 82c3c62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -71,18 +71,14 @@ def sync_ml_metrics(ds_name: str):
71
  """Dynamically finds all available numerical metrics (features) from CSV/Dataset"""
72
  assets = load_all_assets(ds_name)
73
  df = assets["df"]
74
-
75
- # Extract all numeric columns
76
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
77
 
78
- # Filter: remove system IDs and ALL target components (X, Y, Z, global, local, error)
79
  valid_features = [
80
  c for c in numeric_cols
81
  if c not in NON_FEATURE_COLS
82
  and not any(prefix in c for prefix in ["ideal_", "noisy_", "error_", "sign_"])
83
  ]
84
 
85
- # Priority metrics for "default" selection
86
  top_tier = ["gate_entropy", "meyer_wallach", "adjacency", "depth", "total_gates", "cx_count"]
87
  defaults = [f for f in top_tier if f in valid_features]
88
 
@@ -94,11 +90,9 @@ def train_model(ds_name: str, features: List[str]):
94
  assets = load_all_assets(ds_name)
95
  df = assets["df"]
96
 
97
- # Multi-Target: Prediction of all global expectation values
98
  targets = ["ideal_expval_X_global", "ideal_expval_Y_global", "ideal_expval_Z_global"]
99
-
100
- # Filter targets that actually exist in the dataframe (handle cases where some might be missing)
101
  available_targets = [t for t in targets if t in df.columns]
 
102
  if not available_targets:
103
  return None, "### ❌ Error: Target columns not found in dataset."
104
 
@@ -106,50 +100,65 @@ def train_model(ds_name: str, features: List[str]):
106
  X, y = train_df[features], train_df[available_targets]
107
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
108
 
109
- # RandomForestRegressor supports multi-output regression out of the box
110
  model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
111
  preds = model.predict(X_test)
112
 
113
  sns.set_theme(style="whitegrid", context="talk")
114
  fig, axes = plt.subplots(1, len(available_targets), figsize=(8 * len(available_targets), 7))
115
-
116
- # If only one target available, axes is not an array
117
  if len(available_targets) == 1: axes = [axes]
118
 
119
  summary_text = "### 📊 Multi-Target Performance Summary\n"
120
-
121
- colors = ['#2980b9', '#8e44ad', '#2c3e50'] # Blue for X, Purple for Y, Dark for Z
122
 
123
  for i, target_col in enumerate(available_targets):
124
  y_true_axis = y_test.iloc[:, i]
125
  y_pred_axis = preds[:, i]
126
-
127
  r2 = r2_score(y_true_axis, y_pred_axis)
128
  mae = mean_absolute_error(y_true_axis, y_pred_axis)
129
 
130
- # Parity Plot for each basis
131
  axes[i].scatter(y_true_axis, y_pred_axis, alpha=0.3, color=colors[i % len(colors)])
132
- axes[i].plot([-1, 1], [-1, 1], 'r--', lw=2) # Theoretical range of expectation values is [-1, 1]
133
  axes[i].set_title(f"Target: {target_col}\n(R²: {r2:.3f})")
134
- axes[i].set_xlabel("Ground Truth (Ideal)"); axes[i].set_ylabel("Model Prediction")
135
  axes[i].set_xlim([-1.1, 1.1]); axes[i].set_ylim([-1.1, 1.1])
136
 
137
- axis_name = target_col.split('_')[2] # Extracts X, Y, or Z
138
  summary_text += f"- **{axis_name}-Axis:** MAE = {mae:.4f} | R² = {r2:.3f}\n"
139
 
140
  plt.tight_layout(pad=3.0)
141
  return fig, summary_text
142
 
143
  def update_explorer(ds_name: str, split_name: str):
 
144
  assets = load_all_assets(ds_name)
145
  df = assets["df"]
146
- splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
147
- display_df = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
148
 
149
- raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns else "// N/A"
150
- tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns else "// N/A"
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- return gr.update(choices=splits), display_df, raw, tr, f"### 📋 {ds_name} Explorer"
 
 
 
 
 
 
 
 
 
 
153
 
154
  # --- INTERFACE ---
155
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
@@ -166,12 +175,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
166
  c_raw = gr.Code(label="Source QASM", language="python")
167
  c_tr = gr.Code(label="Transpiled QASM", language="python")
168
 
169
- with gr.TabItem("🤖 ML Training (Multi-Target)"):
170
- gr.Markdown("Training models to predict the full Bloch vector expectation values (X, Y, Z) simultaneously.")
171
  with gr.Row():
172
  with gr.Column(scale=1):
173
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
174
- ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics (Features)", choices=[])
175
  train_btn = gr.Button("Train Multi-Output Model", variant="primary")
176
  with gr.Column(scale=2):
177
  p_out = gr.Plot()
@@ -187,7 +196,11 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
187
  """)
188
 
189
  # --- EVENTS ---
 
190
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
 
 
 
191
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
192
  train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
193
 
 
71
  """Dynamically finds all available numerical metrics (features) from CSV/Dataset"""
72
  assets = load_all_assets(ds_name)
73
  df = assets["df"]
 
 
74
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
75
 
 
76
  valid_features = [
77
  c for c in numeric_cols
78
  if c not in NON_FEATURE_COLS
79
  and not any(prefix in c for prefix in ["ideal_", "noisy_", "error_", "sign_"])
80
  ]
81
 
 
82
  top_tier = ["gate_entropy", "meyer_wallach", "adjacency", "depth", "total_gates", "cx_count"]
83
  defaults = [f for f in top_tier if f in valid_features]
84
 
 
90
  assets = load_all_assets(ds_name)
91
  df = assets["df"]
92
 
 
93
  targets = ["ideal_expval_X_global", "ideal_expval_Y_global", "ideal_expval_Z_global"]
 
 
94
  available_targets = [t for t in targets if t in df.columns]
95
+
96
  if not available_targets:
97
  return None, "### ❌ Error: Target columns not found in dataset."
98
 
 
100
  X, y = train_df[features], train_df[available_targets]
101
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
102
 
 
103
  model = RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1).fit(X_train, y_train)
104
  preds = model.predict(X_test)
105
 
106
  sns.set_theme(style="whitegrid", context="talk")
107
  fig, axes = plt.subplots(1, len(available_targets), figsize=(8 * len(available_targets), 7))
 
 
108
  if len(available_targets) == 1: axes = [axes]
109
 
110
  summary_text = "### 📊 Multi-Target Performance Summary\n"
111
+ colors = ['#2980b9', '#8e44ad', '#2c3e50']
 
112
 
113
  for i, target_col in enumerate(available_targets):
114
  y_true_axis = y_test.iloc[:, i]
115
  y_pred_axis = preds[:, i]
 
116
  r2 = r2_score(y_true_axis, y_pred_axis)
117
  mae = mean_absolute_error(y_true_axis, y_pred_axis)
118
 
 
119
  axes[i].scatter(y_true_axis, y_pred_axis, alpha=0.3, color=colors[i % len(colors)])
120
+ axes[i].plot([-1, 1], [-1, 1], 'r--', lw=2)
121
  axes[i].set_title(f"Target: {target_col}\n(R²: {r2:.3f})")
122
+ axes[i].set_xlabel("Ground Truth"); axes[i].set_ylabel("Prediction")
123
  axes[i].set_xlim([-1.1, 1.1]); axes[i].set_ylim([-1.1, 1.1])
124
 
125
+ axis_name = target_col.split('_')[2]
126
  summary_text += f"- **{axis_name}-Axis:** MAE = {mae:.4f} | R² = {r2:.3f}\n"
127
 
128
  plt.tight_layout(pad=3.0)
129
  return fig, summary_text
130
 
131
  def update_explorer(ds_name: str, split_name: str):
132
+ """Updates the data view based on dataset and split selection."""
133
  assets = load_all_assets(ds_name)
134
  df = assets["df"]
 
 
135
 
136
+ # Get unique splits for the dropdown update
137
+ unique_splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
138
+
139
+ # Filter dataframe by selected split
140
+ if "split" in df.columns:
141
+ filtered_df = df[df["split"] == split_name]
142
+ # If the split_name is not found in the new dataset, fallback to first available
143
+ if filtered_df.empty:
144
+ split_name = unique_splits[0]
145
+ filtered_df = df[df["split"] == split_name]
146
+ else:
147
+ filtered_df = df
148
+
149
+ display_df = filtered_df.head(10)
150
 
151
+ # Extract QASM samples
152
+ raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
153
+ tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
154
+
155
+ return (
156
+ gr.update(choices=unique_splits, value=split_name),
157
+ display_df,
158
+ raw,
159
+ tr,
160
+ f"### 📋 {ds_name} Explorer"
161
+ )
162
 
163
  # --- INTERFACE ---
164
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Hub") as demo:
 
175
  c_raw = gr.Code(label="Source QASM", language="python")
176
  c_tr = gr.Code(label="Transpiled QASM", language="python")
177
 
178
+ with gr.TabItem("🤖 ML Training"):
179
+ gr.Markdown("Multi-target regression: predicting X, Y, and Z components simultaneously.")
180
  with gr.Row():
181
  with gr.Column(scale=1):
182
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Select Dataset")
183
+ ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics", choices=[])
184
  train_btn = gr.Button("Train Multi-Output Model", variant="primary")
185
  with gr.Column(scale=2):
186
  p_out = gr.Plot()
 
196
  """)
197
 
198
  # --- EVENTS ---
199
+ # Explorer: Fixed by adding sp_sel.change
200
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
201
+ sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
202
+
203
+ # ML Tab
204
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
205
  train_btn.click(train_model, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
206