QSBench committed on
Commit 943ad56 · verified · 1 Parent(s): 95d311a

Create app.py

Files changed (1)
app.py +496 -0
app.py ADDED
@@ -0,0 +1,496 @@
+ import ast
+ import logging
+ import re
+ from typing import Dict, List, Optional, Tuple
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from datasets import load_dataset
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.impute import SimpleImputer
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
+ from sklearn.model_selection import train_test_split
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import StandardScaler
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ APP_TITLE = "Noise Detection"
+ APP_SUBTITLE = "Classify quantum circuits into clean, depolarizing, amplitude_damping, or hardware-aware noise conditions."
+
+ REPO_CONFIG = {
+     "clean": {
+         "label": "clean",
+         "repo": "QSBench/QSBench-Core-v1.0.0-demo",
+     },
+     "depolarizing": {
+         "label": "depolarizing",
+         "repo": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
+     },
+     "amplitude_damping": {
+         "label": "amplitude_damping",
+         "repo": "QSBench/QSBench-Amplitude-v1.0.0-demo",
+     },
+     "hardware_aware": {
+         "label": "hardware_aware",
+         "repo": "QSBench/QSBench-Transpilation-v1.0.0-demo",
+     },
+ }
+
+ CLASS_ORDER = ["clean", "depolarizing", "amplitude_damping", "hardware_aware"]
+
+ NON_FEATURE_COLS = {
+     "sample_id",
+     "sample_seed",
+     "circuit_hash",
+     "split",
+     "circuit_qasm",
+     "qasm_raw",
+     "qasm_transpiled",
+     "circuit_type_resolved",
+     "circuit_type_requested",
+     "noise_type",
+     "noise_prob",
+     "observable_bases",
+     "observable_mode",
+     "backend_device",
+     "precision_mode",
+     "circuit_signature",
+     "entanglement",
+     "meyer_wallach",
+     "cx_count",
+     "noise_label",
+ }
+
+ SOFT_EXCLUDE_PATTERNS = ["ideal_", "noisy_", "error_", "sign_ideal_", "sign_noisy_"]
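+ # Columns with these prefixes appear to hold per-observable ideal/noisy
+ # expectation values and error terms; excluding them presumably keeps
+ # label-revealing measurement outcomes out of the feature set.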
+ _ASSET_CACHE: Dict[str, pd.DataFrame] = {}
+ _COMBINED_CACHE: Optional[pd.DataFrame] = None
+
+
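+ # Example: safe_parse("[[0, 1], [1, 0]]") returns the nested list, while a
+ # string that is not a Python literal (e.g. raw QASM) is returned unchanged.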
+ def safe_parse(value):
+     """Safely parse stringified Python literals."""
+     if isinstance(value, str):
+         try:
+             return ast.literal_eval(value)
+         except Exception:
+             return value
+     return value
+
+
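+ # For a symmetric 0/1 adjacency matrix, only the upper triangle is counted, so
+ # e.g. adjacency_features([[0, 1], [1, 0]]) gives edge count 1.0, density 1.0,
+ # degree mean 1.0 and degree std 0.0.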
+ def adjacency_features(adj_value) -> Dict[str, float]:
+     """Derive compact graph features from an adjacency matrix."""
+     parsed = safe_parse(adj_value)
+     if not isinstance(parsed, list) or len(parsed) == 0:
+         return {
+             "adj_edge_count": np.nan,
+             "adj_density": np.nan,
+             "adj_degree_mean": np.nan,
+             "adj_degree_std": np.nan,
+         }
+
+     try:
+         arr = np.array(parsed, dtype=float)
+         n = arr.shape[0]
+         edge_count = float(np.triu(arr, k=1).sum())
+         possible_edges = float(n * (n - 1) / 2)
+         density = edge_count / possible_edges if possible_edges > 0 else np.nan
+         degrees = arr.sum(axis=1)
+         return {
+             "adj_edge_count": edge_count,
+             "adj_density": density,
+             "adj_degree_mean": float(np.mean(degrees)),
+             "adj_degree_std": float(np.std(degrees)),
+         }
+     except Exception:
+         return {
+             "adj_edge_count": np.nan,
+             "adj_density": np.nan,
+             "adj_degree_mean": np.nan,
+             "adj_degree_std": np.nan,
+         }
+
+
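+ # Sketch: for "h q[0];\ncx q[0],q[1];\nmeasure q -> c;" this yields 3 non-empty
+ # lines, 2 gate keywords (h, cx) and 1 measure statement.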
+ def qasm_features(qasm_value) -> Dict[str, float]:
+     """Extract lightweight text statistics from QASM."""
+     if not isinstance(qasm_value, str) or not qasm_value.strip():
+         return {
+             "qasm_length": np.nan,
+             "qasm_line_count": np.nan,
+             "qasm_gate_keyword_count": np.nan,
+             "qasm_measure_count": np.nan,
+             "qasm_comment_count": np.nan,
+         }
+
+     text = qasm_value
+     lines = [line for line in text.splitlines() if line.strip()]
+     gate_keywords = re.findall(
+         r"\b(cx|h|x|y|z|rx|ry|rz|u1|u2|u3|u|swap|cz|ccx|rxx|ryy|rzz)\b",
+         text,
+         flags=re.IGNORECASE,
+     )
+     measure_count = len(re.findall(r"\bmeasure\b", text, flags=re.IGNORECASE))
+     comment_count = sum(1 for line in lines if line.strip().startswith("//"))
+
+     return {
+         "qasm_length": float(len(text)),
+         "qasm_line_count": float(len(lines)),
+         "qasm_gate_keyword_count": float(len(gate_keywords)),
+         "qasm_measure_count": float(measure_count),
+         "qasm_comment_count": float(comment_count),
+     }
+
+
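+ # Each per-row feature dict from the helpers above is expanded into dedicated
+ # numeric columns via .apply(pd.Series) and concatenated back onto the frame.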
+ def enrich_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+     """Add derived numeric features for classification."""
+     df = df.copy()
+
+     if "adjacency" in df.columns:
+         adj_df = df["adjacency"].apply(adjacency_features).apply(pd.Series)
+         df = pd.concat([df, adj_df], axis=1)
+
+     qasm_source = "qasm_transpiled" if "qasm_transpiled" in df.columns else "qasm_raw"
+     if qasm_source in df.columns:
+         qasm_df = df[qasm_source].apply(qasm_features).apply(pd.Series)
+         df = pd.concat([df, qasm_df], axis=1)
+
+     return df
+
+
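+ # Note: the loaders below assume each Hub repo exposes a "train" split, and
+ # both memoize into the module-level caches so repeated UI callbacks avoid
+ # re-downloading.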
+ def load_single_dataset(dataset_key: str) -> pd.DataFrame:
+     """Load a single dataset shard from Hugging Face and cache it."""
+     if dataset_key not in _ASSET_CACHE:
+         logger.info("Loading dataset: %s", dataset_key)
+         ds = load_dataset(REPO_CONFIG[dataset_key]["repo"])
+         df = pd.DataFrame(ds["train"])
+         df = enrich_dataframe(df)
+         df["noise_label"] = REPO_CONFIG[dataset_key]["label"]
+         _ASSET_CACHE[dataset_key] = df
+     return _ASSET_CACHE[dataset_key]
+
+
+ def load_combined_dataset() -> pd.DataFrame:
+     """Load and merge all noise-condition datasets."""
+     global _COMBINED_CACHE
+     if _COMBINED_CACHE is None:
+         frames = [load_single_dataset(key) for key in REPO_CONFIG.keys()]
+         combined = pd.concat(frames, ignore_index=True)
+         combined = combined[combined["noise_label"].isin(CLASS_ORDER)].copy()
+         _COMBINED_CACHE = combined
+     return _COMBINED_CACHE
+
+
+ def load_guide_content() -> str:
+     """Load the markdown guide if it exists."""
+     try:
+         with open("GUIDE.md", "r", encoding="utf-8") as f:
+             return f.read()
+     except FileNotFoundError:
+         return "# Guide\n\nGuide file not found."
+
+
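+ # E.g. a (hypothetical) numeric column named "ideal_exp_z" would be dropped by
+ # the pattern filter, while "depth" or "gate_entropy" would be kept as inputs.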
+ def get_available_feature_columns(df: pd.DataFrame) -> List[str]:
+     """Return numeric feature columns excluding metadata and the target."""
+     numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
+     features = []
+     for col in numeric_cols:
+         if col in NON_FEATURE_COLS:
+             continue
+         if any(pattern in col for pattern in SOFT_EXCLUDE_PATTERNS):
+             continue
+         features.append(col)
+     return sorted(features)
+
+
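+ # N.B. "cx_count" is also listed in NON_FEATURE_COLS, so it is filtered out
+ # upstream and can never actually be selected from `features` here.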
+ def default_feature_selection(features: List[str]) -> List[str]:
+     """Pick a stable set of default features."""
+     preferred = [
+         "gate_entropy",
+         "adj_density",
+         "adj_degree_mean",
+         "adj_degree_std",
+         "depth",
+         "total_gates",
+         "single_qubit_gates",
+         "two_qubit_gates",
+         "cx_count",
+         "qasm_length",
+         "qasm_line_count",
+         "qasm_gate_keyword_count",
+     ]
+     selected = [feature for feature in preferred if feature in features]
+     return selected[:8] if selected else features[:8]
+
+
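+ # The middle panel below is a two-bin histogram of per-sample correctness
+ # (0 = correct, 1 = incorrect), a classification analogue of a residual plot.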
+ def make_classification_figure(
+     y_true: np.ndarray,
+     y_pred: np.ndarray,
+     class_names: List[str],
+     feature_names: Optional[List[str]] = None,
+     importances: Optional[np.ndarray] = None,
+ ) -> plt.Figure:
+     """Create a compact classification summary figure."""
+     fig = plt.figure(figsize=(20, 6))
+     gs = fig.add_gridspec(1, 3)
+
+     ax1 = fig.add_subplot(gs[0, 0])
+     ax2 = fig.add_subplot(gs[0, 1])
+     ax3 = fig.add_subplot(gs[0, 2])
+
+     cm = confusion_matrix(y_true, y_pred, labels=class_names)
+     im = ax1.imshow(cm, interpolation="nearest")
+     ax1.set_title("Confusion Matrix")
+     ax1.set_xlabel("Predicted")
+     ax1.set_ylabel("Actual")
+     ax1.set_xticks(np.arange(len(class_names)))
+     ax1.set_yticks(np.arange(len(class_names)))
+     ax1.set_xticklabels(class_names, rotation=45, ha="right")
+     ax1.set_yticklabels(class_names)
+     for i in range(cm.shape[0]):
+         for j in range(cm.shape[1]):
+             ax1.text(j, i, cm[i, j], ha="center", va="center")
+     fig.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
+
+     residual_like = (y_true != y_pred).astype(int)
+     ax2.hist(residual_like, bins=[-0.5, 0.5, 1.5])
+     ax2.set_title("Correct vs Incorrect")
+     ax2.set_xlabel("0 = Correct, 1 = Incorrect")
+     ax2.set_ylabel("Count")
+
+     if importances is not None and feature_names is not None and len(importances) == len(feature_names):
+         idx = np.argsort(importances)[-10:]
+         ax3.barh([feature_names[i] for i in idx], importances[idx])
+         ax3.set_title("Top-10 Feature Importances")
+         ax3.set_xlabel("Importance")
+     else:
+         ax3.text(0.5, 0.5, "Feature importances are unavailable.", ha="center", va="center")
+         ax3.set_axis_off()
+
+     fig.tight_layout()
+     return fig
+
+
+ def build_dataset_profile(df: pd.DataFrame) -> str:
+     """Build a dataset summary for the explorer tab."""
+     return (
+         f"### Dataset profile\n\n"
+         f"**Rows:** {len(df):,} \n"
+         f"**Columns:** {len(df.columns):,} \n"
+         f"**Classes:** {', '.join(CLASS_ORDER)}"
+     )
+
+
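+ # refresh_explorer returns a gr.update that repopulates the split dropdown,
+ # plus the preview frame and the first row's raw/transpiled QASM for display.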
+ def refresh_explorer(dataset_key: str, split_name: str) -> Tuple[gr.update, pd.DataFrame, str, str, str, str]:
+     """Refresh the explorer view for the selected source dataset."""
+     df = load_single_dataset(dataset_key)
+     splits = df["split"].dropna().unique().tolist() if "split" in df.columns else ["train"]
+     if not splits:
+         splits = ["train"]
+
+     if split_name not in splits:
+         split_name = splits[0]
+
+     filtered = df[df["split"] == split_name] if "split" in df.columns else df
+     display_df = filtered.head(12).copy()
+
+     raw_qasm = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
+     transpiled_qasm = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
+
+     profile_box = build_dataset_profile(df)
+     summary_box = (
+         f"### Split summary\n\n"
+         f"**Dataset:** `{dataset_key}` \n"
+         f"**Label:** `{REPO_CONFIG[dataset_key]['label']}` \n"
+         f"**Available splits:** {', '.join(splits)} \n"
+         f"**Preview rows:** {len(display_df)}"
+     )
+
+     return (
+         gr.update(choices=splits, value=split_name),
+         display_df,
+         raw_qasm,
+         transpiled_qasm,
+         profile_box,
+         summary_box,
+     )
+
+
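+ # The dropdown value is intentionally unused (hence `_dataset_key`): the
+ # feature list is always derived from the combined, four-class dataset.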
+ def sync_feature_picker(_dataset_key: str) -> gr.update:
+     """Refresh the feature list from the combined dataset."""
+     df = load_combined_dataset()
+     features = get_available_feature_columns(df)
+     defaults = default_feature_selection(features)
+     return gr.update(choices=features, value=defaults)
+
+
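+ # Headless usage sketch (assuming the listed columns exist in the demo data):
+ #     fig, report_md = train_classifier(["depth", "total_gates"], 0.2, 200, 12, 42)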
+ def train_classifier(
+     feature_columns: List[str],
+     test_size: float,
+     n_estimators: int,
+     max_depth: float,
+     random_state: float,
+ ) -> Tuple[Optional[plt.Figure], str]:
+     """Train a four-class classifier and return metrics plus a plot."""
+     if not feature_columns:
+         return None, "### ❌ Please select at least one feature."
+
+     df = load_combined_dataset()
+     required_cols = feature_columns + ["noise_label"]
+     train_df = df.dropna(subset=required_cols).copy()
+     train_df = train_df[train_df["noise_label"].isin(CLASS_ORDER)]
+
+     if len(train_df) < 20:
+         return None, "### ❌ Not enough complete rows after filtering missing values."
+
+     X = train_df[feature_columns]
+     y = train_df["noise_label"]
+
+     seed = int(random_state)
+     depth = int(max_depth) if max_depth and int(max_depth) > 0 else None
+     trees = int(n_estimators)
+
+     try:
+         X_train, X_test, y_train, y_test = train_test_split(
+             X,
+             y,
+             test_size=test_size,
+             random_state=seed,
+             stratify=y,
+         )
+     except ValueError:
+         # Stratification can fail on very small classes; fall back to a plain split.
+         X_train, X_test, y_train, y_test = train_test_split(
+             X,
+             y,
+             test_size=test_size,
+             random_state=seed,
+         )
+
+     model = Pipeline(
+         steps=[
+             ("imputer", SimpleImputer(strategy="median")),
+             ("scaler", StandardScaler()),
+             (
+                 "classifier",
+                 RandomForestClassifier(
+                     n_estimators=trees,
+                     max_depth=depth,
+                     random_state=seed,
+                     n_jobs=-1,
+                 ),
+             ),
+         ]
+     )
+
+     model.fit(X_train, y_train)
+     y_pred = model.predict(X_test)
+
+     accuracy = float(accuracy_score(y_test, y_pred))
+     macro_f1 = float(f1_score(y_test, y_pred, average="macro"))
+     weighted_f1 = float(f1_score(y_test, y_pred, average="weighted"))
+
+     classifier = model.named_steps["classifier"]
+     importances = getattr(classifier, "feature_importances_", None)
+     fig = make_classification_figure(y_test.to_numpy(), y_pred, CLASS_ORDER, list(feature_columns), importances)
+
+     report = classification_report(y_test, y_pred, labels=CLASS_ORDER, output_dict=False, zero_division=0)
+     results = (
+         "### Classification results\n\n"
+         f"**Rows used:** {len(train_df):,} \n"
+         f"**Test size:** {test_size:.0%} \n"
+         f"**Accuracy:** {accuracy:.4f} \n"
+         f"**Macro F1:** {macro_f1:.4f} \n"
+         f"**Weighted F1:** {weighted_f1:.4f}\n\n"
+         "```text\n"
+         f"{report}"
+         "```"
+     )
+     return fig, results
+
+
+ CUSTOM_CSS = """
+ .gradio-container {
+     max-width: 1400px !important;
+ }
+ footer {
+     margin-top: 1rem;
+ }
+ """
+
+ # Note: theme and css are constructor arguments of gr.Blocks; launch() does not accept them.
+ with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
+     gr.Markdown(f"# 🌌 {APP_TITLE}")
+     gr.Markdown(APP_SUBTITLE)
+
+     with gr.Tabs():
+         with gr.TabItem("🔎 Explorer"):
+             dataset_dropdown = gr.Dropdown(
+                 list(REPO_CONFIG.keys()),
+                 value="clean",
+                 label="Dataset",
+             )
+             split_dropdown = gr.Dropdown(
+                 ["train"],
+                 value="train",
+                 label="Split",
+             )
+
+             profile_box = gr.Markdown(value="### Loading dataset...")
+             summary_box = gr.Markdown(value="### Loading split summary...")
+             explorer_df = gr.Dataframe(label="Preview", interactive=False)
+
+             with gr.Row():
+                 raw_qasm = gr.Code(label="Raw QASM", language=None)
+                 transpiled_qasm = gr.Code(label="Transpiled QASM", language=None)
+
+         with gr.TabItem("🧠 Classification"):
+             feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
+             test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test split")
+             n_estimators = gr.Slider(50, 400, value=200, step=10, label="Trees")
+             max_depth = gr.Slider(1, 30, value=12, step=1, label="Max depth")
+             seed = gr.Number(value=42, precision=0, label="Random seed")
+             run_btn = gr.Button("Train & Evaluate", variant="primary")
+             plot = gr.Plot()
+             metrics = gr.Markdown()
+
+         with gr.TabItem("📖 Guide"):
+             gr.Markdown(load_guide_content())
+
+     gr.Markdown("---")
+     gr.Markdown(
+         "### 🔗 Links\n"
+         "[Website](https://qsbench.github.io) | "
+         "[Hugging Face](https://huggingface.co/QSBench) | "
+         "[GitHub](https://github.com/QSBench)"
+     )
+
+     dataset_dropdown.change(
+         refresh_explorer,
+         [dataset_dropdown, split_dropdown],
+         [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
+     )
+
+     split_dropdown.change(
+         refresh_explorer,
+         [dataset_dropdown, split_dropdown],
+         [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
+     )
+
+     dataset_dropdown.change(sync_feature_picker, [dataset_dropdown], [feature_picker])
+
+     run_btn.click(
+         train_classifier,
+         [feature_picker, test_size, n_estimators, max_depth, seed],
+         [plot, metrics],
+     )
+
+     demo.load(
+         refresh_explorer,
+         [dataset_dropdown, split_dropdown],
+         [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
+     )
+     demo.load(sync_feature_picker, [dataset_dropdown], [feature_picker])
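+     # Both load() hooks run once per page load, so the explorer preview and the
+     # feature checklist are populated before any user interaction.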
+
+
+ if __name__ == "__main__":
+     demo.launch()