QSBench commited on
Commit
6ad662b
Β·
verified Β·
1 Parent(s): 3844728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -253
app.py CHANGED
@@ -1,6 +1,3 @@
1
- import json
2
- from pathlib import Path
3
-
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
@@ -9,34 +6,29 @@ from datasets import load_dataset
9
  from sklearn.ensemble import RandomForestRegressor
10
  from sklearn.metrics import mean_absolute_error, r2_score
11
  from sklearn.model_selection import train_test_split
12
-
13
 
14
  # =========================================================
15
- # CONFIG
16
  # =========================================================
17
- HF_DATASET_NAME = "QSBench/QSBench-Core-v1.0.0-demo"
18
- LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
 
 
 
 
19
 
 
20
  TARGET_COL = "ideal_expval_Z_global"
21
 
22
  EXCLUDE_COLS = {
23
- "sample_id",
24
- "sample_seed",
25
- "ideal_expval_Z_global",
26
- "ideal_expval_X_global",
27
- "ideal_expval_Y_global",
28
- "noisy_expval_Z_global",
29
- "noisy_expval_X_global",
30
- "noisy_expval_Y_global",
31
- "error_Z_global",
32
- "error_X_global",
33
- "error_Y_global",
34
- "sign_ideal_Z_global",
35
- "sign_noisy_Z_global",
36
- "sign_ideal_X_global",
37
- "sign_noisy_X_global",
38
- "sign_ideal_Y_global",
39
- "sign_noisy_Y_global",
40
  }
41
 
42
  MODEL_PARAMS = dict(
@@ -47,89 +39,50 @@ MODEL_PARAMS = dict(
47
  n_jobs=-1,
48
  )
49
 
 
 
50
 
51
  # =========================================================
52
- # DATA LOADING
53
  # =========================================================
54
- def load_demo_dataset() -> pd.DataFrame:
55
- ds_all = load_dataset(HF_DATASET_NAME)
56
- df_all = pd.DataFrame(ds_all["train"])
57
- return df_all
58
-
59
-
60
- def split_by_split_column(df: pd.DataFrame) -> dict:
61
- if "split" not in df.columns:
62
- return {"all": df.reset_index(drop=True)}
63
-
64
- splits = {}
65
- for split_name in df["split"].dropna().astype(str).unique():
66
- splits[split_name] = df[df["split"].astype(str) == split_name].reset_index(drop=True)
67
- return splits
68
-
69
 
70
  def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
71
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
72
- feature_cols = [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_")]
73
- return feature_cols
74
-
75
-
76
- def load_benchmark_results() -> pd.DataFrame:
77
- path = Path(LOCAL_BENCHMARK_CSV)
78
- if not path.exists():
79
- return pd.DataFrame(
80
- [
81
- {
82
- "dataset": "noise_benchmark_results.csv not found",
83
- "split_used": "",
84
- "n_samples": 0,
85
- "r2": np.nan,
86
- "mae": np.nan,
87
- "avg_noise_prob": np.nan,
88
- "status": "missing_file",
89
- }
90
- ]
91
- )
92
-
93
- df = pd.read_csv(path)
94
- return df
95
-
96
-
97
- # =========================================================
98
- # DATA EXPLORER TAB
99
- # =========================================================
100
- def show_data(split_name, splits_cache):
101
- if not splits_cache:
102
- return pd.DataFrame([{"message": "Dataset not loaded"}])
103
-
104
- if split_name in splits_cache:
105
- return splits_cache[split_name].head(10)
106
-
107
- first_key = next(iter(splits_cache.keys()))
108
- return splits_cache[first_key].head(10)
109
-
110
 
111
  # =========================================================
112
- # MODEL DEMO TAB
113
  # =========================================================
114
- def train_model_demo(df: pd.DataFrame):
115
- if TARGET_COL not in df.columns:
116
- return None, "Target column not found."
117
-
 
 
 
 
 
 
 
 
 
118
  feature_cols = get_numeric_feature_cols(df)
119
- if not feature_cols:
120
- return None, "No numeric feature columns found."
121
-
122
- work_df = df.dropna(subset=feature_cols + [TARGET_COL]).reset_index(drop=True)
123
-
124
  X = work_df[feature_cols]
125
- y = work_df[TARGET_COL]
126
-
127
- if len(work_df) < 20:
128
- return None, "Not enough rows for a stable demo."
129
 
130
- X_train, X_test, y_train, y_test = train_test_split(
131
- X, y, test_size=0.2, random_state=42
132
- )
133
 
134
  model = RandomForestRegressor(**MODEL_PARAMS)
135
  model.fit(X_train, y_train)
@@ -138,175 +91,121 @@ def train_model_demo(df: pd.DataFrame):
138
  r2 = r2_score(y_test, preds)
139
  mae = mean_absolute_error(y_test, preds)
140
 
141
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
142
 
143
- ax1.scatter(y_test, preds, alpha=0.6)
144
- min_v = min(float(y_test.min()), float(np.min(preds)))
145
- max_v = max(float(y_test.max()), float(np.max(preds)))
146
- ax1.plot([min_v, max_v], [min_v, max_v], linestyle="--")
147
- ax1.set_xlabel("True value")
148
- ax1.set_ylabel("Predicted value")
149
- ax1.set_title(f"Predictions vs Truth\nRΒ² = {r2:.4f}, MAE = {mae:.4f}")
150
 
 
151
  importances = model.feature_importances_
152
- top_idx = np.argsort(importances)[-10:]
153
- ax2.barh(range(len(top_idx)), importances[top_idx])
154
- ax2.set_yticks(range(len(top_idx)))
155
- ax2.set_yticklabels([feature_cols[i] for i in top_idx])
156
- ax2.set_xlabel("Feature importance")
157
- ax2.set_title("Top 10 features")
158
-
159
  plt.tight_layout()
160
-
161
- explanation = f"""
162
- **RΒ² score:** {r2:.4f}
163
- **MAE:** {mae:.4f}
164
-
165
- This is a lightweight baseline on the demo dataset. The point is not to get a perfect score, but to show that the dataset contains real structure and can support quantum ML experiments.
166
- """
167
-
168
- return fig, explanation
169
-
170
-
171
- # =========================================================
172
- # BENCHMARK TAB
173
- # =========================================================
174
- def make_bar_plot(df: pd.DataFrame, value_col: str, title: str, ylabel: str):
175
- fig, ax = plt.subplots(figsize=(9, 4.8))
176
- if df.empty or value_col not in df.columns or "dataset" not in df.columns:
177
- ax.text(0.5, 0.5, "No benchmark data available", ha="center", va="center")
178
- ax.axis("off")
179
- return fig
180
-
181
- plot_df = df.copy()
182
- plot_df = plot_df.dropna(subset=[value_col])
183
-
184
- ax.bar(plot_df["dataset"].astype(str), plot_df[value_col].astype(float))
185
- ax.set_title(title)
186
- ax.set_xlabel("Dataset")
187
- ax.set_ylabel(ylabel)
188
- ax.tick_params(axis="x", rotation=20)
189
- ax.axhline(0, linewidth=1)
190
  plt.tight_layout()
191
- return fig
192
-
193
-
194
- def build_benchmark_dashboard():
195
- df = load_benchmark_results()
196
-
197
- explanation = """
198
- ### Noise robustness benchmark
199
-
200
- This dashboard shows how a model trained on clean circuits behaves on:
201
- - **core_clean**
202
- - **depolarizing**
203
- - **amplitude_damping**
204
- - **transpilation**
205
-
206
- A sharp drop in RΒ² indicates strong distribution shift. That is exactly the value of the larger QSBench packs.
207
- """
208
-
209
- r2_fig = make_bar_plot(df, "r2", "Noise Robustness Benchmark β€” RΒ²", "RΒ²")
210
- mae_fig = make_bar_plot(df, "mae", "Noise Robustness Benchmark β€” MAE", "MAE")
211
 
212
- return df, r2_fig, mae_fig, explanation
 
 
 
 
 
 
213
 
 
214
 
215
  # =========================================================
216
- # APP
217
  # =========================================================
218
- def main():
219
- print("Loading demo dataset...")
220
- df_all = load_demo_dataset()
221
- splits_cache = split_by_split_column(df_all)
222
- split_choices = list(splits_cache.keys())
223
-
224
- default_split = split_choices[0] if split_choices else None
225
-
226
- with gr.Blocks(title="QSBench Demo Explorer") as demo:
227
- gr.Markdown(
228
- """
229
- # QSBench Demo Explorer
230
-
231
- Interactive demo for the QSBench Core demo dataset and precomputed noise robustness benchmark.
232
- """
233
- )
234
-
235
- with gr.Tabs():
236
- with gr.TabItem("Data Explorer"):
237
- gr.Markdown("Inspect the demo dataset split by split.")
238
- split_selector = gr.Dropdown(
239
- choices=split_choices,
240
- value=default_split,
241
- label="Choose a split",
242
- )
243
- data_table = gr.Dataframe(label="First 10 rows", interactive=False)
244
-
245
- split_selector.change(
246
- fn=lambda s: show_data(s, splits_cache),
247
- inputs=split_selector,
248
- outputs=data_table,
249
- )
250
-
251
- demo.load(
252
- fn=lambda: show_data(default_split, splits_cache),
253
- inputs=[],
254
- outputs=data_table,
255
- )
256
-
257
- with gr.TabItem("Model Demo"):
258
- gr.Markdown(
259
- """
260
- Train a lightweight Random Forest baseline on the demo data and inspect predictions.
261
- """
262
- )
263
- train_button = gr.Button("Train model")
264
- plot_output = gr.Plot()
265
- text_output = gr.Markdown()
266
-
267
- train_button.click(
268
- fn=lambda: train_model_demo(df_all),
269
- inputs=[],
270
- outputs=[plot_output, text_output],
271
- )
272
-
273
- with gr.TabItem("Noise Robustness Benchmark"):
274
- gr.Markdown(
275
- """
276
- This tab loads the precomputed local benchmark results from `noise_benchmark_results.csv`.
277
- """
278
- )
279
- refresh_button = gr.Button("Load benchmark results")
280
- benchmark_table = gr.Dataframe(label="Benchmark results", interactive=False)
281
- r2_plot = gr.Plot(label="RΒ² plot")
282
- mae_plot = gr.Plot(label="MAE plot")
283
- benchmark_text = gr.Markdown()
284
-
285
- refresh_button.click(
286
- fn=build_benchmark_dashboard,
287
- inputs=[],
288
- outputs=[benchmark_table, r2_plot, mae_plot, benchmark_text],
289
- )
290
-
291
- demo.load(
292
- fn=build_benchmark_dashboard,
293
- inputs=[],
294
- outputs=[benchmark_table, r2_plot, mae_plot, benchmark_text],
295
- )
296
-
297
- gr.Markdown("---")
298
- gr.Markdown(
299
- """
300
- ### What this demo shows
301
-
302
- - Data Explorer: inspect the dataset splits
303
- - Model Demo: quick baseline on the demo data
304
- - Noise Robustness Benchmark: precomputed results that show how performance changes across clean, noisy, and transpiled datasets
305
- """
306
- )
307
 
308
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
 
 
310
 
311
  if __name__ == "__main__":
312
- main()
 
 
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import numpy as np
 
6
  from sklearn.ensemble import RandomForestRegressor
7
  from sklearn.metrics import mean_absolute_error, r2_score
8
  from sklearn.model_selection import train_test_split
9
+ from pathlib import Path
10
 
11
  # =========================================================
12
+ # CONFIG & REPOSITORIES
13
  # =========================================================
14
+ DATASET_MAP = {
15
+ "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
16
+ "Depolarizing Noise": "QSBench/QSBench-Depolarizing-v1.0.0-demo",
17
+ "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo",
18
+ "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo"
19
+ }
20
 
21
+ LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
22
  TARGET_COL = "ideal_expval_Z_global"
23
 
24
  EXCLUDE_COLS = {
25
+ "sample_id", "sample_seed", "split",
26
+ "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global",
27
+ "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global",
28
+ "error_Z_global", "error_X_global", "error_Y_global",
29
+ "sign_ideal_Z_global", "sign_noisy_Z_global",
30
+ "sign_ideal_X_global", "sign_noisy_X_global",
31
+ "sign_ideal_Y_global", "sign_noisy_Y_global",
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
34
  MODEL_PARAMS = dict(
 
39
  n_jobs=-1,
40
  )
41
 
42
+ # Global cache to avoid redundant downloads
43
+ dataset_cache = {}
44
 
45
  # =========================================================
46
+ # DATA UTILS
47
  # =========================================================
48
+ def get_df(dataset_key):
49
+ if dataset_key not in dataset_cache:
50
+ repo_id = DATASET_MAP[dataset_key]
51
+ print(f"Downloading {repo_id}...")
52
+ ds = load_dataset(repo_id)
53
+ dataset_cache[dataset_key] = pd.DataFrame(ds["train"])
54
+ return dataset_cache[dataset_key]
 
 
 
 
 
 
 
 
55
 
56
  def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
57
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
58
+ return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # =========================================================
61
+ # TAB FUNCTIONS
62
  # =========================================================
63
+ def update_explorer(dataset_name):
64
+ df = get_df(dataset_name)
65
+ splits = df["split"].unique().tolist() if "split" in df.columns else ["all"]
66
+ return gr.update(choices=splits, value=splits[0]), df.head(10)
67
+
68
+ def filter_explorer_by_split(dataset_name, split_name):
69
+ df = get_df(dataset_name)
70
+ if "split" in df.columns:
71
+ return df[df["split"] == split_name].head(10)
72
+ return df.head(10)
73
+
74
+ def run_model_demo(dataset_name):
75
+ df = get_df(dataset_name)
76
  feature_cols = get_numeric_feature_cols(df)
77
+
78
+ # Ensure target exists, fallback to noisy if clean is missing (though unlikely in your schema)
79
+ target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
80
+
81
+ work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True)
82
  X = work_df[feature_cols]
83
+ y = work_df[target]
 
 
 
84
 
85
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
 
86
 
87
  model = RandomForestRegressor(**MODEL_PARAMS)
88
  model.fit(X_train, y_train)
 
91
  r2 = r2_score(y_test, preds)
92
  mae = mean_absolute_error(y_test, preds)
93
 
94
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
95
 
96
+ # Parity Plot
97
+ ax1.scatter(y_test, preds, alpha=0.5, color='#636EFA')
98
+ lims = [min(y_test.min(), preds.min()), max(y_test.max(), preds.max())]
99
+ ax1.plot(lims, lims, 'r--', alpha=0.75, zorder=3)
100
+ ax1.set_xlabel("Ground Truth")
101
+ ax1.set_ylabel("Predictions")
102
+ ax1.set_title(f"Prediction Accuracy\nRΒ² = {r2:.4f}")
103
 
104
+ # Feature Importance
105
  importances = model.feature_importances_
106
+ indices = np.argsort(importances)[-10:]
107
+ ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
108
+ ax2.set_yticks(range(len(indices)))
109
+ ax2.set_yticklabels([feature_cols[i] for i in indices])
110
+ ax2.set_title("Top 10 Structural Features")
111
+
 
112
  plt.tight_layout()
113
+
114
+ summary = f"""
115
+ ### Model Performance: {dataset_name}
116
+ - **RΒ² Score:** {r2:.4f}
117
+ - **Mean Absolute Error (MAE):** {mae:.4f}
118
+
119
+ *This baseline demonstrates that structural circuit metrics (entropy, gate counts, etc.) hold predictive power for quantum expectation values.*
120
+ """
121
+ return fig, summary
122
+
123
+ def load_benchmark():
124
+ path = Path(LOCAL_BENCHMARK_CSV)
125
+ if not path.exists():
126
+ return pd.DataFrame([{"info": "Benchmark file not found"}]), None, None
127
+
128
+ df = pd.read_csv(path)
129
+
130
+ # R2 Plot
131
+ fig_r2, ax = plt.subplots(figsize=(8, 4))
132
+ ax.bar(df["dataset"], df["r2"], color='skyblue')
133
+ ax.set_title("Cross-Dataset Robustness (RΒ² Score)")
134
+ ax.set_ylabel("RΒ²")
135
+ plt.xticks(rotation=15)
 
 
 
 
 
 
 
136
  plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # MAE Plot
139
+ fig_mae, ax = plt.subplots(figsize=(8, 4))
140
+ ax.bar(df["dataset"], df["mae"], color='salmon')
141
+ ax.set_title("Cross-Dataset Error (MAE)")
142
+ ax.set_ylabel("MAE")
143
+ plt.xticks(rotation=15)
144
+ plt.tight_layout()
145
 
146
+ return df, fig_r2, fig_mae
147
 
148
  # =========================================================
149
+ # INTERFACE
150
  # =========================================================
151
+ with gr.Blocks(title="QSBench Unified Explorer", theme=gr.themes.Soft()) as demo:
152
+ gr.Markdown(
153
+ """
154
+ # 🌌 QSBench: Quantum Synthetic Benchmark Explorer
155
+ **Unified interface for Core, Noise-Affected, and Hardware-Transpiled Quantum Datasets.**
156
+
157
+ Browse the demo datasets from the QSBench family, run baseline ML models, and analyze noise robustness across different distributions.
158
+ """
159
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ with gr.Tabs():
162
+ # TAB 1: DATA EXPLORER
163
+ with gr.TabItem("πŸ”Ž Dataset Explorer"):
164
+ with gr.Row():
165
+ ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Dataset Pack")
166
+ split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split")
167
+
168
+ data_table = gr.Dataframe(label="Sample Data (First 10 rows)", interactive=False)
169
+
170
+ ds_selector.change(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
171
+ split_selector.change(filter_explorer_by_split, inputs=[ds_selector, split_selector], outputs=[data_table])
172
+
173
+ # TAB 2: ML BASELINE
174
+ with gr.TabItem("πŸ€– ML Baseline Demo"):
175
+ gr.Markdown("Select a dataset and train a Random Forest regressor to predict expectation values from circuit metadata.")
176
+ model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Target Dataset")
177
+ train_btn = gr.Button("Train Baseline Model", variant="primary")
178
+
179
+ with gr.Row():
180
+ plot_output = gr.Plot(label="Model Metrics")
181
+ text_output = gr.Markdown(label="Stats")
182
+
183
+ train_btn.click(run_model_demo, inputs=[model_ds_selector], outputs=[plot_output, text_output])
184
+
185
+ # TAB 3: BENCHMARKING
186
+ with gr.TabItem("πŸ“Š Noise Robustness Benchmark"):
187
+ gr.Markdown("Analysis of model performance degradation under distribution shifts (Clean β†’ Noisy β†’ Hardware).")
188
+ bench_btn = gr.Button("Load Precomputed Benchmark Results")
189
+ bench_table = gr.Dataframe(interactive=False)
190
+ with gr.Row():
191
+ r2_plot = gr.Plot()
192
+ mae_plot = gr.Plot()
193
+
194
+ bench_btn.click(load_benchmark, outputs=[bench_table, r2_plot, mae_plot])
195
+
196
+ gr.Markdown(
197
+ """
198
+ ---
199
+ ### About QSBench
200
+ QSBench is a collection of high-quality synthetic datasets designed for **Quantum Machine Learning** research.
201
+ It provides paired ideal/noisy data, structural circuit metrics, and transpilation metadata.
202
+
203
+ πŸ”— [Website](https://qsbench.github.io) | πŸ€— [Hugging Face](https://huggingface.co/QSBench) | πŸ› οΈ [GitHub](https://github.com/QSBench)
204
+ """
205
+ )
206
 
207
+ # Initial load
208
+ demo.load(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
209
 
210
  if __name__ == "__main__":
211
+ demo.launch()