QSBench committed on
Commit
3ba9c8c
·
verified ·
1 Parent(s): 2cc591e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -59
app.py CHANGED
@@ -22,14 +22,12 @@ DATASET_MAP = {
22
  LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
23
  TARGET_COL = "ideal_expval_Z_global"
24
 
 
25
  EXCLUDE_COLS = {
26
- "sample_id", "sample_seed", "split", "circuit_qasm", "circuit_qasm_transpiled",
27
- "ideal_expval_Z_global", "ideal_expval_X_global", "ideal_expval_Y_global",
28
- "noisy_expval_Z_global", "noisy_expval_X_global", "noisy_expval_Y_global",
29
- "error_Z_global", "error_X_global", "error_Y_global",
30
- "sign_ideal_Z_global", "sign_noisy_Z_global",
31
- "sign_ideal_X_global", "sign_noisy_X_global",
32
- "sign_ideal_Y_global", "sign_noisy_Y_global",
33
  }
34
 
35
  dataset_cache = {}
@@ -46,7 +44,7 @@ def get_df(dataset_key):
46
 
47
  def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
48
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
49
- return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_")]
50
 
51
  # =========================================================
52
  # TAB FUNCTIONS
@@ -54,103 +52,100 @@ def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
54
  def update_explorer(dataset_name, split_name):
55
  df = get_df(dataset_name)
56
  splits = df["split"].unique().tolist() if "split" in df.columns else ["all"]
57
-
58
  filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
59
 
60
- qasm_col = "circuit_qasm" if "circuit_qasm" in df.columns else None
61
- qasm_sample = filtered[qasm_col].iloc[0] if qasm_col and not filtered.empty else "// QASM not available in this pack"
 
 
 
 
62
 
63
- return gr.update(choices=splits), filtered, qasm_sample
64
 
65
- def run_model_demo(dataset_name):
 
 
 
66
  df = get_df(dataset_name)
67
- feature_cols = get_numeric_feature_cols(df)
68
  target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
69
 
70
- work_df = df.dropna(subset=feature_cols + [target]).reset_index(drop=True)
71
- X, y = work_df[feature_cols], work_df[target]
72
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
73
 
74
  model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
75
  model.fit(X_train, y_train)
76
  preds = model.predict(X_test)
77
 
 
78
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
79
 
80
  # 1. Parity Plot
81
  ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
82
  ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
 
 
83
  ax1.set_title(f"Parity Plot (R²={r2_score(y_test, preds):.3f})")
84
 
85
  # 2. Feature Importance
86
  importances = model.feature_importances_
87
- indices = np.argsort(importances)[-10:]
88
- ax2.barh(range(10), importances[indices], color='#EF553B')
89
- ax2.set_yticks(range(10))
90
- ax2.set_yticklabels([feature_cols[i] for i in indices])
91
- ax2.set_title("Top 10 Features")
92
 
93
- # 3. Residuals (Error Distribution)
94
  sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
95
- ax3.set_title("Residuals Distribution")
96
 
97
  plt.tight_layout()
98
- return fig, f"### Baseline Analysis for {dataset_name}\nMAE: {mean_absolute_error(y_test, preds):.4f}"
99
-
100
- def load_benchmark():
101
- path = Path(LOCAL_BENCHMARK_CSV)
102
- if not path.exists(): return None, None, "File noise_benchmark_results.csv not found."
103
- df = pd.read_csv(path)
104
- fig_r2, ax = plt.subplots(figsize=(8, 4))
105
- ax.bar(df["dataset"], df["r2"], color=['#636EFA', '#EF553B', '#00CC96', '#AB63FA'])
106
- plt.xticks(rotation=15)
107
- plt.tight_layout()
108
- return df, fig_r2, "Benchmark comparison completed."
109
 
110
  # =========================================================
111
  # INTERFACE
112
  # =========================================================
113
  with gr.Blocks(title="QSBench Unified Explorer") as demo:
114
- gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark Explorer\n**Professional-grade datasets for Noise-Aware QML and Hardware Optimization.**")
115
-
116
  with gr.Tabs():
117
  with gr.TabItem("🔎 Dataset Explorer"):
118
  with gr.Row():
119
  ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset Pack")
120
- split_selector = gr.Dropdown(choices=["train", "test", "validation"], value="train", label="Split")
121
 
122
- data_table = gr.Dataframe(interactive=False)
123
- qasm_view = gr.Code(label="Circuit QASM Preview (First row of selection)", language="python")
124
 
125
- ds_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
126
- split_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
 
127
 
128
  with gr.TabItem("🤖 ML Baseline Demo"):
129
- gr.Markdown("Train a Random Forest regressor to evaluate how well structural circuit features predict expectation values.")
130
- model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Target Pack")
131
- train_btn = gr.Button("Train Baseline Model", variant="primary")
132
- plot_output = gr.Plot()
133
- text_output = gr.Markdown()
134
- train_btn.click(run_model_demo, [model_ds_selector], [plot_output, text_output])
135
-
136
- with gr.TabItem("📊 Cross-Dataset Benchmark"):
137
- gr.Markdown("Comparison of model performance across different noise environments and hardware transpilation stages.")
138
- bench_btn = gr.Button("Analyze Robustness Across All Packs")
139
- bench_table = gr.Dataframe()
140
- bench_plot = gr.Plot()
141
- bench_btn.click(load_benchmark, outputs=[bench_table, bench_plot, text_output])
142
 
143
  gr.Markdown("""
144
  ---
145
- ### 🔬 Research Resources
146
- This interface provides a structural overview of the QSBench dataset family. These datasets are designed to support reproducible research in quantum error mitigation and machine learning.
147
-
148
  - **GitHub**: [QSBench/QSBench-Demo](https://github.com/QSBench/QSBench-Demo)
149
  - **Website**: [qsbench.github.io](https://qsbench.github.io)
150
- - **Hugging Face**: [Explore all datasets](https://huggingface.co/QSBench)
151
  """)
152
 
153
- demo.load(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_view])
 
 
 
 
 
 
154
 
155
  if __name__ == "__main__":
156
  demo.launch(theme=gr.themes.Soft())
 
22
  LOCAL_BENCHMARK_CSV = "noise_benchmark_results.csv"
23
  TARGET_COL = "ideal_expval_Z_global"
24
 
25
# Exclude non-numeric metadata and target columns from the feature set.
EXCLUDE_COLS = {
    "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
    "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
    "noise_type", "observable_bases", "observable_mode", "backend_device",
    "precision_mode", "circuit_signature", "ideal_expval_Z_global", "noisy_expval_Z_global",
}
32
 
33
  dataset_cache = {}
 
44
 
45
def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
    """Return the numeric columns of *df* usable as ML features.

    Drops identifier/target columns listed in EXCLUDE_COLS, plus any
    leakage-prone derived columns whose names start with "error_" or "sign_".
    """
    leaky_prefixes = ("error_", "sign_")
    candidates = df.select_dtypes(include=[np.number]).columns
    return [
        col
        for col in candidates
        if col not in EXCLUDE_COLS and not col.startswith(leaky_prefixes)
    ]
48
 
49
  # =========================================================
50
  # TAB FUNCTIONS
 
52
def update_explorer(dataset_name, split_name):
    """Refresh the Dataset Explorer tab for the chosen pack and split.

    Returns updates for: the split dropdown choices, the preview table,
    the raw QASM preview, the transpiled QASM preview, and the feature
    selector on the ML tab (defaulting to the first five features).
    """
    df = get_df(dataset_name)
    splits = df["split"].unique().tolist() if "split" in df.columns else ["all"]
    filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)

    # Guard against an empty selection: .iloc[0] on an empty frame raises
    # IndexError (e.g. when the chosen split has no rows in this pack).
    if "qasm_raw" in filtered.columns and not filtered.empty:
        qasm_raw = filtered["qasm_raw"].iloc[0]
    else:
        qasm_raw = "// No raw QASM"
    if "qasm_transpiled" in filtered.columns and not filtered.empty:
        qasm_tr = filtered["qasm_transpiled"].iloc[0]
    else:
        qasm_tr = "// No transpiled QASM"

    # Refresh the feature list offered on the ML tab.
    features = get_numeric_feature_cols(df)

    return gr.update(choices=splits), filtered, qasm_raw, qasm_tr, gr.update(choices=features, value=features[:5])
65
 
66
def run_model_demo(dataset_name, selected_features):
    """Train a Random Forest baseline on the selected features and plot diagnostics.

    Returns a (matplotlib figure, markdown summary) pair on success, or
    (None, warning markdown) when the inputs cannot support training.
    """
    if not selected_features:
        return None, "### ⚠️ Please select at least one feature."

    df = get_df(dataset_name)

    # The feature list may have been populated from a different dataset pack;
    # keep only columns that actually exist here so dropna() cannot KeyError.
    selected_features = [c for c in selected_features if c in df.columns]
    if not selected_features:
        return None, "### ⚠️ Selected features are not present in this dataset."

    # Pick the target: preferred column if present, otherwise the first
    # expectation-value column — guarded so an empty match cannot IndexError.
    if TARGET_COL in df.columns:
        target = TARGET_COL
    else:
        expval_cols = df.filter(like="expval").columns
        if len(expval_cols) == 0:
            return None, "### ⚠️ No expectation-value column found to use as target."
        target = expval_cols[0]

    work_df = df.dropna(subset=selected_features + [target]).reset_index(drop=True)
    if len(work_df) < 10:
        return None, "### ⚠️ Not enough complete rows to train a model."
    X, y = work_df[selected_features], work_df[target]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)

    sns.set_theme(style="whitegrid")
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

    # 1. Parity plot: predicted vs. actual with the ideal y = x line.
    ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
    ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
    ax1.set_xlabel("Actual")
    ax1.set_ylabel("Predicted")
    ax1.set_title(f"Parity Plot (R²={r2_score(y_test, preds):.3f})")

    # 2. Feature importance, sorted ascending so the largest bar is on top.
    importances = model.feature_importances_
    indices = np.argsort(importances)
    ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
    ax2.set_yticks(range(len(indices)))
    ax2.set_yticklabels([selected_features[i] for i in indices])
    ax2.set_title("Feature Importance")

    # 3. Residuals
    sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
    ax3.set_title("Residuals (Error Distribution)")

    plt.tight_layout()
    return fig, f"### Results for {dataset_name}\n**MAE:** {mean_absolute_error(y_test, preds):.4f} | **Features used:** {len(selected_features)}"
 
 
 
 
 
 
 
 
 
 
105
 
106
# =========================================================
# INTERFACE
# =========================================================
with gr.Blocks(title="QSBench Unified Explorer") as demo:
    gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark Explorer")

    with gr.Tabs():
        # --- Tab 1: browse raw dataset rows and circuit QASM ---
        with gr.TabItem("🔎 Dataset Explorer"):
            with gr.Row():
                ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset Pack")
                split_selector = gr.Dropdown(choices=["train"], value="train", label="Split")

            data_table = gr.Dataframe(interactive=False, overflow_row_behaviour="paginate")

            with gr.Row():
                qasm_raw_view = gr.Code(label="Raw QASM (Source)", language="python", lines=10)
                qasm_tr_view = gr.Code(label="Transpiled QASM (Hardware-ready)", language="python", lines=10)

        # --- Tab 2: train the Random Forest baseline on chosen features ---
        with gr.TabItem("🤖 ML Baseline Demo"):
            with gr.Row():
                with gr.Column(scale=1):
                    model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset")
                    feature_selector = gr.Checkboxgroup(label="Select Features for Training", choices=[])
                    train_btn = gr.Button("Train Model", variant="primary")
                with gr.Column(scale=2):
                    plot_output = gr.Plot()
                    text_output = gr.Markdown()

    gr.Markdown("""
    ---
    ### 🔬 Research & Data
    This Space provides structural validation of the **QSBench** dataset family.
    - **GitHub**: [QSBench/QSBench-Demo](https://github.com/QSBench/QSBench-Demo)
    - **Website**: [qsbench.github.io](https://qsbench.github.io)
    """)

    # Event Linking
    ds_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])
    split_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])
    train_btn.click(run_model_demo, [model_ds_selector, feature_selector], [plot_output, text_output])

    # Initial load
    demo.load(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])

if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())