QSBench committed on
Commit
30d5809
·
verified ·
1 Parent(s): 09506ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -118
app.py CHANGED
@@ -3,166 +3,222 @@ import matplotlib.pyplot as plt
3
  import numpy as np
4
  import pandas as pd
5
  import seaborn as sns
 
 
 
6
  from datasets import load_dataset
7
  from sklearn.ensemble import RandomForestRegressor
8
  from sklearn.metrics import mean_absolute_error, r2_score
9
  from sklearn.model_selection import train_test_split
10
 
11
- # =========================================================
12
- # CONFIG
13
- # =========================================================
14
- DATASET_MAP = {
15
- "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
16
- "Depolarizing Noise": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
17
- "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo",
18
- "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
20
 
21
- TARGET_COL = "ideal_expval_Z_global"
22
-
23
- # Колонки, которые никогда не должны быть признаками (фичами)
24
- EXCLUDE_COLS = {
25
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
26
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
27
- "noise_type", "observable_bases", "observable_mode", "backend_device",
28
  "precision_mode", "circuit_signature", "ideal_expval_Z_global", "noisy_expval_Z_global"
29
  }
30
 
31
- dataset_cache = {}
32
-
33
- # =========================================================
34
- # UTILS
35
- # =========================================================
36
- def get_df(dataset_key):
37
- if dataset_key not in dataset_cache:
38
- repo_id = DATASET_MAP[dataset_key]
39
- ds = load_dataset(repo_id)
40
- dataset_cache[dataset_key] = pd.DataFrame(ds["train"])
41
- return dataset_cache[dataset_key]
42
-
43
- def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
44
- numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
45
- # Оставляем только структурные метрики, убираем таргеты и ошибки
46
- return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_") and "expval" not in c]
47
-
48
- # =========================================================
49
- # LOGIC
50
- # =========================================================
51
-
52
- # Функция для обновления первой вкладки (Explorer)
53
- def update_explorer_tab(dataset_name, split_name):
54
- df = get_df(dataset_name)
55
- splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
56
- filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- qasm_raw = filtered["qasm_raw"].iloc[0] if "qasm_raw" in filtered.columns else "// N/A"
59
- qasm_tr = filtered["qasm_transpiled"].iloc[0] if "qasm_transpiled" in filtered.columns else "// N/A"
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- return gr.update(choices=splits), filtered, qasm_raw, qasm_tr
62
-
63
- # Функция для обновления списка фичей во второй вкладке (ML Demo)
64
- def update_ml_features(dataset_name):
65
- df = get_df(dataset_name)
66
- features = get_numeric_feature_cols(df)
67
- # По умолчанию выбираем первые несколько важных метрик
68
- default_selection = [f for f in ["n_qubits", "depth", "total_gates", "gate_entropy", "meyer_wallach"] if f in features]
69
- if not default_selection: default_selection = features[:5]
70
 
71
- return gr.update(choices=features, value=default_selection)
72
 
73
- def run_model_demo(dataset_name, selected_features):
74
- df = get_df(dataset_name)
 
 
75
 
76
- # Защита от несуществующих колонок (KeyError)
77
- valid_features = [f for f in selected_features if f in df.columns]
78
 
79
- if not valid_features:
80
- return None, "### ⚠️ Ошибка: Выбранные признаки не найдены в этом датасете."
81
 
82
- target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
 
 
 
 
 
83
 
84
- work_df = df.dropna(subset=valid_features + [target]).reset_index(drop=True)
85
- X, y = work_df[valid_features], work_df[target]
 
86
 
87
- if len(work_df) < 20:
88
- return None, "### ⚠️ Недостаточно данных для обучения."
 
 
89
 
 
90
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
91
 
92
- model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
 
93
  model.fit(X_train, y_train)
94
  preds = model.predict(X_test)
95
 
96
- sns.set_theme(style="whitegrid")
97
- fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # График предсказаний
100
- ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
101
- ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
102
- ax1.set_title(f"R² Score: {r2_score(y_test, preds):.3f}")
103
- ax1.set_xlabel("Actual")
104
- ax1.set_ylabel("Predicted")
105
 
106
- # Важность признаков (топ-10)
107
- importances = model.feature_importances_
108
- indices = np.argsort(importances)[-10:]
109
- ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
110
- ax2.set_yticks(range(len(indices)))
111
- ax2.set_yticklabels([valid_features[i] for i in indices])
112
- ax2.set_title("Feature Importance")
113
-
114
- # Распределение ошибок
115
- sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
116
- ax3.set_title("Residuals")
117
-
118
  plt.tight_layout()
119
- return fig, f"### Отчет по датасету: {dataset_name}\n**MAE:** {mean_absolute_error(y_test, preds):.4f}"
 
 
120
 
121
- # =========================================================
122
- # UI
123
- # =========================================================
124
- with gr.Blocks(title="QSBench Explorer") as demo:
125
- gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark")
126
 
127
  with gr.Tabs():
128
- # ВКЛАДКА 1: ПРОСМОТР ДАННЫХ
129
- with gr.TabItem("🔎 Explorer"):
 
 
130
  with gr.Row():
131
- ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset")
132
- split_selector = gr.Dropdown(choices=["train"], value="train", label="Split")
133
 
134
  data_table = gr.Dataframe(interactive=False)
135
 
136
  with gr.Row():
137
- qasm_raw_view = gr.Code(label="Raw QASM", language="python", lines=10)
138
- qasm_tr_view = gr.Code(label="Transpiled QASM", language="python", lines=10)
139
 
140
- # ВКЛАДКА 2: МАШИННОЕ ОБУЧЕНИЕ
141
- with gr.TabItem("🤖 ML Demo"):
142
  with gr.Row():
143
  with gr.Column(scale=1):
144
- gr.Markdown("### Настройка обучения")
145
- m_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset for ML")
146
- f_selector = gr.CheckboxGroup(label="Признаки (Features)", choices=[])
147
- train_btn = gr.Button("Запустить обучение", variant="primary")
148
  with gr.Column(scale=2):
149
- plot_out = gr.Plot()
150
- text_out = gr.Markdown()
151
-
152
- # --- ЛОГИКА СОБЫТИЙ ---
153
-
154
- # При изменении датасета в Explorer — обновляем таблицу и QASM
155
- ds_selector.change(update_explorer_tab, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view])
156
-
157
- # ПРИНЦИПИАЛЬНО: При изменении датасета в ML Demo обновляем список чекбоксов
158
- m_ds_selector.change(update_ml_features, inputs=[m_ds_selector], outputs=[f_selector])
159
-
160
- # Кнопка обучения
161
- train_btn.click(run_model_demo, [m_ds_selector, f_selector], [plot_out, text_out])
 
 
162
 
163
- # Инициализация при старте
164
- demo.load(update_explorer_tab, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view])
165
- demo.load(update_ml_features, [m_ds_selector], [f_selector])
166
 
167
  if __name__ == "__main__":
168
- demo.launch(theme=gr.themes.Soft())
 
3
  import numpy as np
4
  import pandas as pd
5
  import seaborn as sns
6
+ import logging
7
+ import requests
8
+ from typing import List, Tuple, Dict, Optional
9
  from datasets import load_dataset
10
  from sklearn.ensemble import RandomForestRegressor
11
  from sklearn.metrics import mean_absolute_error, r2_score
12
  from sklearn.model_selection import train_test_split
13
 
14
# --- Logging setup (production-style, module-level logger) ---------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _hf_raw_url(repo: str, branch: str, filename: str) -> str:
    """Build the raw-file URL for a JSON asset stored on a HF dataset branch."""
    return f"https://huggingface.co/datasets/{repo}/raw/{branch}/meta/{filename}"


def _dataset_entry(repo: str, branch: str) -> dict:
    """Describe one dataset pack: repo id plus its meta/report JSON URLs."""
    return {
        "repo": repo,
        "meta_url": _hf_raw_url(repo, branch, "meta.json"),
        "report_url": _hf_raw_url(repo, branch, "report.json"),
    }


# Configuration for datasets and their specific metadata branches.
# NOTE(review): the Core pack keeps its JSON on a branch named "metadata",
# the other packs on "meta" — these branch names match the original
# hard-coded URLs; confirm if the repos are ever re-published.
REPO_CONFIG = {
    "Core (Clean)": _dataset_entry("QSBench/QSBench-Core-v1.0.0-demo", "metadata"),
    "Depolarizing Noise": _dataset_entry("QSBench/QSBench-Depolarizing-Demo-v1.0.0", "meta"),
    "Amplitude Damping": _dataset_entry("QSBench/QSBench-Amplitude-v1.0.0-demo", "meta"),
    "Transpilation (10q)": _dataset_entry("QSBench/QSBench-Transpilation-v1.0.0-demo", "meta"),
}
41
 
42
# Columns that must never be fed to the model as training inputs:
# sample identity/bookkeeping, raw circuit text, experiment configuration,
# and the expectation-value targets themselves.
_ID_COLS = ("sample_id", "sample_seed", "circuit_hash", "split", "circuit_signature")
_QASM_COLS = ("circuit_qasm", "qasm_raw", "qasm_transpiled")
_CONFIG_COLS = (
    "circuit_type_resolved", "circuit_type_requested",
    "noise_type", "noise_prob", "observable_bases", "observable_mode",
    "backend_device", "precision_mode",
)
_TARGET_COLS = ("ideal_expval_Z_global", "noisy_expval_Z_global")

NON_FEATURE_COLS = set(_ID_COLS + _QASM_COLS + _CONFIG_COLS + _TARGET_COLS)
49
 
50
# In-memory cache of loaded dataframes + metadata, keyed by dataset name.
_ASSET_CACHE = {}

def fetch_remote_json(url: str) -> Optional[dict]:
    """Fetch a JSON file from a raw Hugging Face branch.

    Args:
        url: Direct raw-file URL (see REPO_CONFIG entries).

    Returns:
        The decoded JSON payload, or None on any network/decode failure so
        callers can degrade gracefully instead of crashing the UI.
    """
    try:
        response = requests.get(url, timeout=5)
        return response.json() if response.status_code == 200 else None
    # Narrowed from a bare `except Exception`: only transport failures and
    # malformed JSON (requests raises a ValueError subclass) are expected
    # here; anything else should surface as a real bug.
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Error fetching metadata from {url}: {e}")
        return None
60
+
61
def load_all_assets(key: str) -> Dict:
    """Return the cached asset bundle for dataset *key*, loading it on first use.

    The bundle holds:
      - "df": train split as a pandas DataFrame,
      - "meta"/"report": JSON fetched from the dataset's metadata branch
        (may be None when the remote fetch failed).
    """
    cached = _ASSET_CACHE.get(key)
    if cached is not None:
        return cached

    logger.info(f"Loading assets for dataset: {key}")
    cfg = REPO_CONFIG[key]

    # Main parquet/csv data plus the two sidecar JSON files.
    ds = load_dataset(cfg["repo"])
    bundle = {
        "df": pd.DataFrame(ds["train"]),
        "meta": fetch_remote_json(cfg["meta_url"]),
        "report": fetch_remote_json(cfg["report_url"]),
    }

    _ASSET_CACHE[key] = bundle
    return bundle
77
+
78
def generate_meta_markdown(assets: Dict) -> str:
    """Parse JSON metadata into a human-readable research summary.

    Args:
        assets: Bundle produced by ``load_all_assets``; its "meta" and
            "report" entries may be ``None`` when the remote fetch failed.

    Returns:
        A markdown summary of the release, or a warning line when no
        metadata is available.
    """
    # BUGFIX: ``dict.get(key, {})`` still returns None when the key exists
    # with a None value (a failed fetch_remote_json stored in the cache),
    # which previously raised AttributeError on ``.get`` below. Coalesce
    # with ``or {}`` instead.
    meta = assets.get("meta") or {}
    report = assets.get("report") or {}

    if not meta:
        return "⚠️ *Metadata currently unavailable for this dataset branch.*"

    params = meta.get("parameters") or {}

    # Format family distribution from report.json
    families = report.get("families") or {}
    fam_info = ", ".join(f"{k.upper()}: {v}" for k, v in families.items())

    md = (
        f"### 📋 Dataset Release: {meta.get('dataset_version', '1.0.0')}\n"
        f"**Hardware Config:** {params.get('n_qubits')} Qubits | Depth: {params.get('depth')} | "
        f"Shots: {params.get('shots')} | Device: {meta.get('backend_device', 'GPU')}\n\n"
        f"**Noise Model:** `{params.get('noise', 'Clean')}` (p={params.get('noise_prob', 0.0)}) | "
        f"**Circuit Coverage:** {fam_info}"
    )
    return md
99
+
100
def update_explorer_view(ds_name: str, split_name: str):
    """Main callback for the Explorer tab.

    Returns updates for: split dropdown, preview table, raw QASM viewer,
    transpiled QASM viewer, and the metadata markdown header.
    """
    assets = load_all_assets(ds_name)
    df = assets["df"]

    splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]

    # BUGFIX: when the dataset changes, the previously selected split may
    # not exist in the new dataset; previously this produced an empty frame
    # and ``iloc[0]`` raised IndexError. Fall back to the first split.
    if "split" in df.columns:
        if split_name not in splits:
            split_name = splits[0]
        display_df = df[df["split"] == split_name].head(10)
    else:
        display_df = df.head(10)

    def _first_or_na(col: str) -> str:
        # Guard against both a missing column and an empty preview frame.
        if col in display_df.columns and len(display_df):
            return display_df[col].iloc[0]
        return "// No data"

    # QASM Sample Extraction
    raw_qasm = _first_or_na("qasm_raw")
    tr_qasm = _first_or_na("qasm_transpiled")

    return gr.update(choices=splits, value=split_name), display_df, raw_qasm, tr_qasm, generate_meta_markdown(assets)
113
 
114
def sync_ml_inputs(ds_name: str):
    """Refresh the feature-selection checkboxes when the dataset changes."""
    df = load_all_assets(ds_name)["df"]

    # Keep numeric structural metrics only: drop bookkeeping/target columns
    # and anything derived from errors or expectation values.
    banned_prefixes = ("error_", "sign_", "ideal_", "noisy_")
    valid_features = [
        col
        for col in df.select_dtypes(include=[np.number]).columns.tolist()
        if col not in NON_FEATURE_COLS and not col.startswith(banned_prefixes)
    ]

    # Default selection of core structural metrics, falling back to the
    # first five features when none of the preferred ones exist.
    preferred = ["gate_entropy", "meyer_wallach", "n_qubits", "depth", "total_gates"]
    top_picks = [name for name in preferred if name in valid_features]

    return gr.update(choices=valid_features, value=top_picks or valid_features[:5])
126
+
127
def train_baseline_model(ds_name: str, selected_features: List[str]):
    """Train a Random Forest regressor and generate analytics plots.

    Args:
        ds_name: Key into REPO_CONFIG selecting the dataset pack.
        selected_features: Column names ticked in the UI checkboxes.

    Returns:
        (matplotlib figure, markdown summary) on success, or
        (None, markdown error message) when training is impossible.
    """
    if not selected_features:
        return None, "### ❌ Error: Please select at least one feature."

    assets = load_all_assets(ds_name)
    df = assets["df"]

    # BUGFIX (regression vs. previous version): the checkbox list can lag
    # behind a dataset switch, so features not present in this dataset must
    # be dropped instead of raising KeyError on column selection below.
    features = [f for f in selected_features if f in df.columns]
    if not features:
        return None, "### ❌ Error: Please select at least one feature."

    # Resolve the regression target; guard the fallback against datasets
    # with no expectation-value column at all (previously IndexError).
    if "ideal_expval_Z_global" in df.columns:
        target = "ideal_expval_Z_global"
    else:
        expval_cols = df.filter(like="expval").columns
        if len(expval_cols) == 0:
            return None, "### ❌ Error: No expectation-value target column found."
        target = expval_cols[0]

    # Data cleaning
    train_df = df.dropna(subset=features + [target])
    if len(train_df) < 50:
        return None, "### ⚠️ Warning: Dataset too small for reliable training."

    X, y = train_df[features], train_df[target]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Model Pipeline
    model = RandomForestRegressor(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)

    # Plotting
    sns.set_theme(style="whitegrid", context="notebook")
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))

    # 1. Parity Plot: predicted vs. actual with the y=x reference line
    axes[0].scatter(y_test, preds, alpha=0.4, color='#34495e')
    axes[0].plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
    axes[0].set_title(f"Regression Accuracy (R²: {r2_score(y_test, preds):.3f})")
    axes[0].set_xlabel("Actual")
    axes[0].set_ylabel("Predicted")

    # 2. Importance: top 12 features by impurity-based importance
    importances = model.feature_importances_
    indices = np.argsort(importances)[-12:]
    axes[1].barh([features[i] for i in indices], importances[indices], color='#1abc9c')
    axes[1].set_title("Top Structural Predictors")

    # 3. Error Analysis: residual distribution on the held-out split
    sns.histplot(y_test - preds, kde=True, ax=axes[2], color='#e67e22')
    axes[2].set_title("Residuals Distribution")

    plt.tight_layout()
    result_text = f"**Model Performance on {ds_name}**\n**MAE:** {mean_absolute_error(y_test, preds):.4f}"

    return fig, result_text
174
 
175
# --- GRADIO INTERFACE ---
# Two-tab layout: a data explorer (table + QASM viewers + metadata header)
# and an ML baseline trainer. Callbacks are wired at the bottom of the block.

with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Analytics") as demo:
    gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark Suite")

    with gr.Tabs():
        with gr.TabItem("🔎 Dataset Explorer"):
            # Header with parsed metadata from JSON; replaced on load/change
            # by update_explorer_view via generate_meta_markdown.
            metadata_box = gr.Markdown("### Synchronizing with Hugging Face...")

            with gr.Row():
                ds_select = gr.Dropdown(choices=list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset Pack")
                split_select = gr.Dropdown(choices=["train"], value="train", label="Subset")

            data_table = gr.Dataframe(interactive=False)

            with gr.Row():
                # NOTE(review): no QASM syntax mode is available, so the
                # viewers use language="python" purely for monospace display.
                code_raw = gr.Code(label="Source Circuit (QASM)", language="python")
                code_tr = gr.Code(label="Transpiled (Hardware-Ready)", language="python")

        with gr.TabItem("🤖 ML Training Baseline"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Training Configuration")
                    ml_ds_select = gr.Dropdown(choices=list(REPO_CONFIG.keys()), value="Core (Clean)", label="Source Dataset")
                    # Choices start empty; sync_ml_inputs fills them on load
                    # and whenever the source dataset changes.
                    ml_features = gr.CheckboxGroup(label="Structural Metrics", choices=[])
                    btn_train = gr.Button("Execute Baseline Training", variant="primary")
                with gr.Column(scale=2):
                    plot_output = gr.Plot()
                    txt_output = gr.Markdown()

    # Footer/Resources
    gr.Markdown("""
    ---
    ### 🔬 Research Credits
    **QSBench** is an open-source framework for noise-aware Quantum Machine Learning benchmarking.
    - [GitHub Repository](https://github.com/QSBench/QSBench-Demo) | [Official Website](https://qsbench.github.io)
    """)

    # Event Handlers
    # Dataset change refreshes splits, preview table, both QASM viewers and
    # the metadata header in one callback.
    ds_select.change(update_explorer_view, [ds_select, split_select], [split_select, data_table, code_raw, code_tr, metadata_box])
    # Changing the ML dataset re-derives the available feature checkboxes.
    ml_ds_select.change(sync_ml_inputs, [ml_ds_select], [ml_features])
    btn_train.click(train_baseline_model, [ml_ds_select, ml_features], [plot_output, txt_output])

    # Initial Load: populate both tabs once at startup.
    demo.load(update_explorer_view, [ds_select, split_select], [split_select, data_table, code_raw, code_tr, metadata_box])
    demo.load(sync_ml_inputs, [ml_ds_select], [ml_features])

if __name__ == "__main__":
    demo.launch()