QSBench committed on
Commit
5569d74
·
verified ·
1 Parent(s): 60fb523

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -108
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  import pandas as pd
 
5
  from datasets import load_dataset
6
  from sklearn.ensemble import RandomForestRegressor
7
  from sklearn.metrics import mean_absolute_error, r2_score
@@ -22,7 +23,7 @@ LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
22
  TARGET_COL = "ideal_expval_Z_global"
23
 
24
  EXCLUDE_COLS = {
25
- "sample_id", "sample_seed", "split",
26
  "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global",
27
  "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global",
28
  "error_Z_global", "error_X_global", "error_Y_global",
@@ -31,15 +32,6 @@ EXCLUDE_COLS = {
31
  "sign_ideal_Y_global", "sign_noisy_Y_global",
32
  }
33
 
34
- MODEL_PARAMS = dict(
35
- n_estimators=80,
36
- max_depth=10,
37
- min_samples_leaf=2,
38
- random_state=42,
39
- n_jobs=-1,
40
- )
41
-
42
- # Global cache to avoid redundant downloads
43
  dataset_cache = {}
44
 
45
  # =========================================================
@@ -48,7 +40,6 @@ dataset_cache = {}
48
  def get_df(dataset_key):
49
  if dataset_key not in dataset_cache:
50
  repo_id = DATASET_MAP[dataset_key]
51
- print(f"Downloading {repo_id}...")
52
  ds = load_dataset(repo_id)
53
  dataset_cache[dataset_key] = pd.DataFrame(ds["train"])
54
  return dataset_cache[dataset_key]
@@ -60,152 +51,105 @@ def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
60
  # =========================================================
61
  # TAB FUNCTIONS
62
  # =========================================================
63
- def update_explorer(dataset_name):
64
  df = get_df(dataset_name)
65
  splits = df["split"].unique().tolist() if "split" in df.columns else ["all"]
66
- return gr.update(choices=splits, value=splits[0]), df.head(10)
67
-
68
- def filter_explorer_by_split(dataset_name, split_name):
69
- df = get_df(dataset_name)
70
- if "split" in df.columns:
71
- return df[df["split"] == split_name].head(10)
72
- return df.head(10)
73
 
74
  def run_model_demo(dataset_name):
75
  df = get_df(dataset_name)
76
  feature_cols = get_numeric_feature_cols(df)
77
-
78
- # Ensure target exists, fallback to noisy if clean is missing (though unlikely in your schema)
79
  target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
80
 
81
  work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True)
82
- X = work_df[feature_cols]
83
- y = work_df[target]
84
-
85
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
86
 
87
- model = RandomForestRegressor(**MODEL_PARAMS)
88
  model.fit(X_train, y_train)
89
  preds = model.predict(X_test)
90
 
91
- r2 = r2_score(y_test, preds)
92
- mae = mean_absolute_error(y_test, preds)
93
-
94
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
95
 
96
- # Parity Plot
97
- ax1.scatter(y_test, preds, alpha=0.5, color='#636EFA')
98
- lims = [min(y_test.min(), preds.min()), max(y_test.max(), preds.max())]
99
- ax1.plot(lims, lims, 'r--', alpha=0.75, zorder=3)
100
- ax1.set_xlabel("Ground Truth")
101
- ax1.set_ylabel("Predictions")
102
- ax1.set_title(f"Prediction Accuracy\nRΒ² = {r2:.4f}")
103
 
104
- # Feature Importance
105
  importances = model.feature_importances_
106
  indices = np.argsort(importances)[-10:]
107
- ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
108
- ax2.set_yticks(range(len(indices)))
109
  ax2.set_yticklabels([feature_cols[i] for i in indices])
110
- ax2.set_title("Top 10 Structural Features")
 
 
 
 
111
 
112
  plt.tight_layout()
113
-
114
- summary = f"""
115
- ### Model Performance: {dataset_name}
116
- - **RΒ² Score:** {r2:.4f}
117
- - **Mean Absolute Error (MAE):** {mae:.4f}
118
-
119
- *This baseline demonstrates that structural circuit metrics (entropy, gate counts, etc.) hold predictive power for quantum expectation values.*
120
- """
121
- return fig, summary
122
 
123
  def load_benchmark():
124
  path = Path(LOCAL_BENCHMARK_CSV)
125
- if not path.exists():
126
- return pd.DataFrame([{"info": "Benchmark file not found"}]), None, None
127
-
128
  df = pd.read_csv(path)
129
-
130
- # R2 Plot
131
  fig_r2, ax = plt.subplots(figsize=(8, 4))
132
- ax.bar(df["dataset"], df["r2"], color='skyblue')
133
- ax.set_title("Cross-Dataset Robustness (RΒ² Score)")
134
- ax.set_ylabel("RΒ²")
135
- plt.xticks(rotation=15)
136
- plt.tight_layout()
137
-
138
- # MAE Plot
139
- fig_mae, ax = plt.subplots(figsize=(8, 4))
140
- ax.bar(df["dataset"], df["mae"], color='salmon')
141
- ax.set_title("Cross-Dataset Error (MAE)")
142
- ax.set_ylabel("MAE")
143
  plt.xticks(rotation=15)
144
  plt.tight_layout()
145
-
146
- return df, fig_r2, fig_mae
147
 
148
  # =========================================================
149
  # INTERFACE
150
  # =========================================================
151
  with gr.Blocks(title="QSBench Unified Explorer", theme=gr.themes.Soft()) as demo:
152
- gr.Markdown(
153
- """
154
- # 🌌 QSBench: Quantum Synthetic Benchmark Explorer
155
- **Unified interface for Core, Noise-Affected, and Hardware-Transpiled Quantum Datasets.**
156
-
157
- Browse the demo datasets from the QSBench family, run baseline ML models, and analyze noise robustness across different distributions.
158
- """
159
- )
160
 
161
  with gr.Tabs():
162
- # TAB 1: DATA EXPLORER
163
  with gr.TabItem("πŸ”Ž Dataset Explorer"):
164
  with gr.Row():
165
- ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Dataset Pack")
166
  split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split")
167
 
168
- data_table = gr.Dataframe(label="Sample Data (First 10 rows)", interactive=False)
 
169
 
170
- ds_selector.change(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
171
- split_selector.change(filter_explorer_by_split, inputs=[ds_selector, split_selector], outputs=[data_table])
172
 
173
- # TAB 2: ML BASELINE
174
  with gr.TabItem("πŸ€– ML Baseline Demo"):
175
- gr.Markdown("Select a dataset and train a Random Forest regressor to predict expectation values from circuit metadata.")
176
- model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Target Dataset")
177
  train_btn = gr.Button("Train Baseline Model", variant="primary")
178
-
179
- with gr.Row():
180
- plot_output = gr.Plot(label="Model Metrics")
181
- text_output = gr.Markdown(label="Stats")
182
-
183
- train_btn.click(run_model_demo, inputs=[model_ds_selector], outputs=[plot_output, text_output])
184
-
185
- # TAB 3: BENCHMARKING
186
- with gr.TabItem("πŸ“Š Noise Robustness Benchmark"):
187
- gr.Markdown("Analysis of model performance degradation under distribution shifts (Clean β†’ Noisy β†’ Hardware).")
188
- bench_btn = gr.Button("Load Precomputed Benchmark Results")
189
- bench_table = gr.Dataframe(interactive=False)
190
- with gr.Row():
191
- r2_plot = gr.Plot()
192
- mae_plot = gr.Plot()
193
-
194
- bench_btn.click(load_benchmark, outputs=[bench_table, r2_plot, mae_plot])
195
-
196
- gr.Markdown(
197
- """
198
- ---
199
- ### About QSBench
200
- QSBench is a collection of high-quality synthetic datasets designed for **Quantum Machine Learning** research.
201
- It provides paired ideal/noisy data, structural circuit metrics, and transpilation metadata.
202
 
203
  πŸ”— [Website](https://qsbench.github.io) | πŸ€— [Hugging Face](https://huggingface.co/QSBench) | πŸ› οΈ [GitHub](https://github.com/QSBench/QSBench-Demo)
204
  """
205
  )
206
 
207
  # Initial load
208
- demo.load(update_explorer, inputs=[ds_selector], outputs=[split_selector, data_table])
209
 
210
  if __name__ == "__main__":
211
  demo.launch()
 
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  import pandas as pd
5
+ import seaborn as sns
6
  from datasets import load_dataset
7
  from sklearn.ensemble import RandomForestRegressor
8
  from sklearn.metrics import mean_absolute_error, r2_score
 
23
  TARGET_COL = "ideal_expval_Z_global"
24
 
25
  EXCLUDE_COLS = {
26
+ "sample_id", "sample_seed", "split", "circuit_qasm", "circuit_qasm_transpiled",
27
  "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global",
28
  "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global",
29
  "error_Z_global", "error_X_global", "error_Y_global",
 
32
  "sign_ideal_Y_global", "sign_noisy_Y_global",
33
  }
34
 
 
 
 
 
 
 
 
 
 
35
  dataset_cache = {}
36
 
37
  # =========================================================
 
40
  def get_df(dataset_key):
41
  if dataset_key not in dataset_cache:
42
  repo_id = DATASET_MAP[dataset_key]
 
43
  ds = load_dataset(repo_id)
44
  dataset_cache[dataset_key] = pd.DataFrame(ds["train"])
45
  return dataset_cache[dataset_key]
 
51
  # =========================================================
52
  # TAB FUNCTIONS
53
  # =========================================================
54
+ def update_explorer(dataset_name, split_name):
55
  df = get_df(dataset_name)
56
  splits = df["split"].unique().tolist() if "split" in df.columns else ["all"]
57
+
58
+ filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
59
+ qasm_sample = filtered["circuit_qasm"].iloc[0] if "circuit_qasm" in filtered.columns else "// QASM not found"
60
+
61
+ return gr.update(choices=splits), filtered, qasm_sample
 
 
62
 
63
  def run_model_demo(dataset_name):
64
  df = get_df(dataset_name)
65
  feature_cols = get_numeric_feature_cols(df)
 
 
66
  target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
67
 
68
  work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True)
69
+ X, y = work_df[feature_cols], work_df[target]
 
 
70
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
71
 
72
+ model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
73
  model.fit(X_train, y_train)
74
  preds = model.predict(X_test)
75
 
76
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
 
 
 
77
 
78
+ # 1. Parity Plot
79
+ ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
80
+ ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
81
+ ax1.set_title(f"Parity Plot (RΒ²={r2_score(y_test, preds):.3f})")
 
 
 
82
 
83
+ # 2. Feature Importance
84
  importances = model.feature_importances_
85
  indices = np.argsort(importances)[-10:]
86
+ ax2.barh(range(10), importances[indices], color='#EF553B')
 
87
  ax2.set_yticklabels([feature_cols[i] for i in indices])
88
+ ax2.set_title("Top 10 Features")
89
+
90
+ # 3. Residuals (Error Distribution)
91
+ sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
92
+ ax3.set_title("Residuals Distribution")
93
 
94
  plt.tight_layout()
95
+ return fig, f"### Baseline Analysis for {dataset_name}\nMAE: {mean_absolute_error(y_test, preds):.4f}"
 
 
 
 
 
 
 
 
96
 
97
  def load_benchmark():
98
  path = Path(LOCAL_BENCHMARK_CSV)
99
+ if not path.exists(): return None, None, None
 
 
100
  df = pd.read_csv(path)
 
 
101
  fig_r2, ax = plt.subplots(figsize=(8, 4))
102
+ ax.bar(df["dataset"], df["r2"], color=['#636EFA', '#EF553B', '#00CC96', '#AB63FA'])
 
 
 
 
 
 
 
 
 
 
103
  plt.xticks(rotation=15)
104
  plt.tight_layout()
105
+ return df, fig_r2, "Benchmark comparison completed."
 
106
 
107
  # =========================================================
108
  # INTERFACE
109
  # =========================================================
110
  with gr.Blocks(title="QSBench Unified Explorer", theme=gr.themes.Soft()) as demo:
111
+ gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark Explorer\n**Professional-grade datasets for Noise-Aware QML and Hardware Optimization.**")
 
 
 
 
 
 
 
112
 
113
  with gr.Tabs():
 
114
  with gr.TabItem("πŸ”Ž Dataset Explorer"):
115
  with gr.Row():
116
+ ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset Pack")
117
  split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split")
118
 
119
+ data_table = gr.Dataframe(interactive=False)
120
+ qasm_view = gr.Code(label="Circuit QASM Preview (First row of selection)", language="wasm")
121
 
122
+ ds_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
123
+ split_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
124
 
 
125
  with gr.TabItem("πŸ€– ML Baseline Demo"):
126
+ model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Target Pack")
 
127
  train_btn = gr.Button("Train Baseline Model", variant="primary")
128
+ plot_output = gr.Plot()
129
+ text_output = gr.Markdown()
130
+ train_btn.click(run_model_demo, [model_ds_selector], [plot_output, text_output])
131
+
132
+ with gr.TabItem("πŸ“Š Cross-Dataset Benchmark"):
133
+ bench_btn = gr.Button("Analyze Robustness Across All Packs")
134
+ bench_table = gr.Dataframe()
135
+ bench_plot = gr.Plot()
136
+ bench_btn.click(load_benchmark, outputs=[bench_table, bench_plot, text_output])
137
+
138
+ gr.Markdown("""
139
+ ---
140
+ ### πŸ”¬ Data Integrity & Research Value
141
+ The demo files provided here serve as a **structural validation** for researchers.
142
+ - **Demographic**: 8-10 Qubits, Depth 6-8.
143
+ - **Features**: Includes gate entropy, Meyer-Wallach entanglement, and transpilation metrics.
144
+
145
+ To achieve state-of-the-art results in error mitigation or noise modeling, access to the full dataset family (up to 200,000 samples) is recommended to ensure statistical significance and model generalization.
 
 
 
 
 
 
146
 
147
  πŸ”— [Website](https://qsbench.github.io) | πŸ€— [Hugging Face](https://huggingface.co/QSBench) | πŸ› οΈ [GitHub](https://github.com/QSBench/QSBench-Demo)
148
  """
149
  )
150
 
151
  # Initial load
152
+ demo.load(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
153
 
154
  if __name__ == "__main__":
155
  demo.launch()