QSBench committed on
Commit
d840d30
·
verified ·
1 Parent(s): df2604b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +454 -0
app.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import re
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from datasets import load_dataset
11
+ from sklearn.ensemble import RandomForestRegressor
12
+ from sklearn.impute import SimpleImputer
13
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
14
+ from sklearn.model_selection import train_test_split
15
+ from sklearn.pipeline import Pipeline
16
+ from sklearn.preprocessing import StandardScaler
17
+
18
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Heading and tagline text for the app.
APP_TITLE = "CNOT Count Regression"
APP_SUBTITLE = (
    "Predict the number of CNOT gates (cx_count) "
    "from circuit topology and gate structure."
)

# Dataset label -> Hugging Face repository id handed to ``load_dataset``.
# Insertion order matters: the keys are listed in this order in the UI.
REPO_CONFIG = {
    "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
    "Depolarizing Noise": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
    "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo",
    "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo",
}

# Column the regressor predicts.
TARGET_COL = "cx_count"

# Metadata columns that must never be offered as model inputs.
_METADATA_COLS = (
    "sample_id",
    "sample_seed",
    "circuit_hash",
    "split",
    "circuit_qasm",
    "qasm_raw",
    "qasm_transpiled",
    "circuit_type_resolved",
    "circuit_type_requested",
    "noise_type",
    "noise_prob",
    "observable_bases",
    "observable_mode",
    "backend_device",
    "precision_mode",
    "circuit_signature",
    "entanglement",
)

# The target itself is excluded alongside the metadata columns.
NON_FEATURE_COLS = set(_METADATA_COLS) | {TARGET_COL}

# A column whose name contains any of these substrings is also excluded.
SOFT_EXCLUDE_PATTERNS = [
    "ideal_",
    "noisy_",
    "error_",
    "sign_ideal_",
    "sign_noisy_",
]

# Enriched DataFrames keyed by dataset label, filled lazily on first load.
_ASSET_CACHE: Dict[str, pd.DataFrame] = {}
56
+
57
+
58
def load_dataset_df(dataset_key: str) -> pd.DataFrame:
    """Return the enriched DataFrame for ``dataset_key``, loading it on first use.

    The key is looked up in ``REPO_CONFIG`` to find the repository to load.
    The ``"train"`` entry of the loaded dataset is converted to a DataFrame,
    passed through ``enrich_dataframe`` and stored in ``_ASSET_CACHE``, so
    later calls with the same key return the cached frame.
    """
    try:
        return _ASSET_CACHE[dataset_key]
    except KeyError:
        # Not cached yet: fall through and load it.
        pass

    logger.info("Loading dataset from Hugging Face: %s", dataset_key)
    repo_id = REPO_CONFIG[dataset_key]
    raw_frame = pd.DataFrame(load_dataset(repo_id)["train"])
    _ASSET_CACHE[dataset_key] = enrich_dataframe(raw_frame)
    return _ASSET_CACHE[dataset_key]
67
+
68
+
69
def safe_parse(value):
    """Return the Python literal encoded in ``value`` when it is a parseable string.

    Non-string inputs, and strings that ``ast.literal_eval`` cannot evaluate,
    are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    try:
        parsed = ast.literal_eval(value)
    except Exception:
        # Best effort: text that is not a valid literal stays a plain string.
        return value
    return parsed
77
+
78
+
79
def adjacency_features(adj_value) -> Dict[str, float]:
    """Derive basic graph features from an adjacency matrix.

    Args:
        adj_value: The adjacency matrix as a nested list/tuple, a NumPy
            array, or a string holding a Python literal of such a structure.

    Returns:
        A dict with the keys ``adj_edge_count`` (sum of the strict upper
        triangle), ``adj_density`` (edge count divided by ``n * (n - 1) / 2``),
        ``adj_degree_mean`` and ``adj_degree_std`` (mean and standard
        deviation of the row sums). Every value is ``NaN`` when the input
        cannot be interpreted as a non-empty square numeric matrix; the
        density alone is ``NaN`` for a 1x1 matrix.
    """
    missing = dict.fromkeys(
        ("adj_edge_count", "adj_density", "adj_degree_mean", "adj_degree_std"),
        np.nan,
    )

    # Only text needs literal parsing; structured values are used as they are.
    parsed = safe_parse(adj_value) if isinstance(adj_value, str) else adj_value
    if not isinstance(parsed, (list, tuple, np.ndarray)):
        return missing

    try:
        arr = np.asarray(parsed, dtype=float)
    except (TypeError, ValueError):
        # Ragged rows or non-numeric entries cannot form a matrix.
        return missing

    # The upper-triangle and density formulas are only meaningful for a
    # square 2-D matrix; anything else would yield misleading numbers.
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[0] != arr.shape[1]:
        return missing

    n = arr.shape[0]
    edge_count = float(np.triu(arr, k=1).sum())
    possible_edges = float(n * (n - 1) / 2)
    density = edge_count / possible_edges if possible_edges > 0 else np.nan
    degrees = arr.sum(axis=1)
    return {
        "adj_edge_count": edge_count,
        "adj_density": density,
        "adj_degree_mean": float(np.mean(degrees)),
        "adj_degree_std": float(np.std(degrees)),
    }
110
+
111
+
112
def qasm_features(qasm_value) -> Dict[str, float]:
    """Compute simple text statistics for a QASM program.

    Returns a dict with the total character count, the number of non-blank
    lines, the number of whole-word gate-name matches, the number of
    ``measure`` occurrences and the number of non-blank lines starting with
    ``//``. All values are ``NaN`` when the input is not a string or is
    blank.
    """
    feature_names = (
        "qasm_length",
        "qasm_line_count",
        "qasm_gate_keyword_count",
        "qasm_measure_count",
        "qasm_comment_count",
    )
    if not (isinstance(qasm_value, str) and qasm_value.strip()):
        return {name: np.nan for name in feature_names}

    non_blank = [ln for ln in qasm_value.splitlines() if ln.strip()]
    gate_hits = re.findall(
        r"\b(cx|h|x|y|z|rx|ry|rz|u1|u2|u3|u|swap|cz|ccx|rxx|ryy|rzz)\b",
        qasm_value,
        flags=re.IGNORECASE,
    )
    measure_hits = re.findall(r"\bmeasure\b", qasm_value, flags=re.IGNORECASE)
    comment_lines = [ln for ln in non_blank if ln.strip().startswith("//")]

    counts = (
        len(qasm_value),
        len(non_blank),
        len(gate_hits),
        len(measure_hits),
        len(comment_lines),
    )
    return {name: float(count) for name, count in zip(feature_names, counts)}
140
+
141
+
142
def enrich_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with derived numeric feature columns appended.

    When an ``adjacency`` column is present, the outputs of
    ``adjacency_features`` are added. QASM statistics from ``qasm_features``
    are computed from ``qasm_transpiled`` when that column exists, otherwise
    from ``qasm_raw`` when available. The input frame is not modified.

    Args:
        df: Frame holding the raw dataset rows.

    Returns:
        A new DataFrame with the original columns followed by any derived
        feature columns, aligned on the original index.
    """
    enriched = df.copy()
    derived_frames = []

    if "adjacency" in enriched.columns:
        # Build the feature frame in one pass from the per-row dicts instead
        # of ``.apply(pd.Series)``, which constructs a Series for every row.
        derived_frames.append(
            pd.DataFrame(
                [adjacency_features(value) for value in enriched["adjacency"]],
                index=enriched.index,
            )
        )

    qasm_source = "qasm_transpiled" if "qasm_transpiled" in enriched.columns else "qasm_raw"
    if qasm_source in enriched.columns:
        derived_frames.append(
            pd.DataFrame(
                [qasm_features(value) for value in enriched[qasm_source]],
                index=enriched.index,
            )
        )

    if derived_frames:
        enriched = pd.concat([enriched, *derived_frames], axis=1)
    return enriched
156
+
157
+
158
def load_guide_content() -> str:
    """Return the text of ``GUIDE.md``, read as UTF-8 from the working directory.

    When the file does not exist, a short placeholder markdown document is
    returned instead.
    """
    try:
        guide_file = open("GUIDE.md", "r", encoding="utf-8")
    except FileNotFoundError:
        return "# Guide\n\nGuide file not found."
    with guide_file:
        return guide_file.read()
165
+
166
+
167
def get_available_feature_columns(df: pd.DataFrame) -> List[str]:
    """Return the sorted names of numeric columns usable as model inputs.

    A numeric column is dropped when it is listed in ``NON_FEATURE_COLS`` or
    when its name contains any substring from ``SOFT_EXCLUDE_PATTERNS``.
    """
    numeric_names = df.select_dtypes(include=[np.number]).columns.tolist()
    return sorted(
        name
        for name in numeric_names
        if name not in NON_FEATURE_COLS
        and not any(fragment in name for fragment in SOFT_EXCLUDE_PATTERNS)
    )
178
+
179
+
180
def default_feature_selection(features: List[str]) -> List[str]:
    """Pick up to eight default features from ``features``.

    Preferred feature names are taken in their fixed priority order when they
    are available; if none of them is available, the first eight entries of
    ``features`` are returned instead.
    """
    priority_order = (
        "gate_entropy",
        "adj_density",
        "adj_degree_mean",
        "adj_degree_std",
        "depth",
        "total_gates",
        "single_qubit_gates",
        "two_qubit_gates",
        "qasm_length",
        "qasm_line_count",
        "qasm_gate_keyword_count",
    )
    picks = []
    for name in priority_order:
        if name in features:
            picks.append(name)

    if picks:
        return picks[:8]
    return features[:8]
197
+
198
+
199
def make_regression_figure(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    feature_names: Optional[List[str]] = None,
    importances: Optional[np.ndarray] = None,
) -> plt.Figure:
    """Create a three-panel regression summary figure.

    The panels show, left to right: actual vs. predicted values with a dashed
    diagonal, a 20-bin histogram of the residuals (``y_true - y_pred``), and a
    horizontal bar chart of the ten largest feature importances.

    Args:
        y_true: Observed target values.
        y_pred: Predicted target values, aligned with ``y_true``.
        feature_names: Names matching ``importances`` element-wise.
        importances: One importance score per feature. When this or
            ``feature_names`` is missing, or their lengths differ, the third
            panel shows a placeholder message instead.

    Returns:
        The populated figure. It is detached from pyplot's figure registry
        before being returned, so repeated calls do not accumulate open
        figures; the returned object can still be rendered or saved.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    fig = plt.figure(figsize=(20, 6))
    gs = fig.add_gridspec(1, 3)

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[0, 2])

    ax1.scatter(y_true, y_pred, alpha=0.75)
    # min/max of an empty array raises, so only draw the diagonal with data.
    if y_true.size and y_pred.size:
        min_v = min(float(np.min(y_true)), float(np.min(y_pred)))
        max_v = max(float(np.max(y_true)), float(np.max(y_pred)))
        ax1.plot([min_v, max_v], [min_v, max_v], linestyle="--")
    ax1.set_title("Actual vs Predicted")
    ax1.set_xlabel("Actual cx_count")
    ax1.set_ylabel("Predicted cx_count")

    residuals = y_true - y_pred
    ax2.hist(residuals, bins=20)
    ax2.set_title("Residual Distribution")
    ax2.set_xlabel("Residual")
    ax2.set_ylabel("Count")

    if importances is not None and feature_names is not None and len(importances) == len(feature_names):
        # Accept plain sequences as well as arrays for fancy indexing below.
        scores = np.asarray(importances, dtype=float)
        idx = np.argsort(scores)[-10:]
        ax3.barh([feature_names[i] for i in idx], scores[idx])
        ax3.set_title("Top-10 Feature Importances")
        ax3.set_xlabel("Importance")
    else:
        ax3.text(0.5, 0.5, "Feature importances are unavailable.", ha="center", va="center")
        ax3.set_axis_off()

    fig.tight_layout()
    # pyplot keeps every figure it creates alive until closed; unregister this
    # one so a long-running app does not leak a figure per training run.
    plt.close(fig)
    return fig
238
+
239
+
240
def build_dataset_profile(df: pd.DataFrame) -> str:
    """Return a markdown summary of ``df`` and its ``TARGET_COL`` statistics.

    The summary lists the row and column counts plus the mean, standard
    deviation, minimum and maximum of the target column. When the target
    column is absent, a short "not found" message is returned instead.
    """
    if TARGET_COL not in df.columns:
        return "### Dataset profile\n\nTarget column not found."

    target = df[TARGET_COL]
    sections = [
        f"### Dataset profile\n\n",
        f"**Rows:** {len(df):,} \n",
        f"**Columns:** {len(df.columns):,} \n",
        f"**{TARGET_COL} mean:** {target.mean():.4f} \n",
        f"**{TARGET_COL} std:** {target.std():.4f} \n",
        f"**{TARGET_COL} min/max:** {target.min():.4f} / {target.max():.4f}",
    ]
    return "".join(sections)
254
+
255
+
256
def refresh_explorer(dataset_key: str, split_name: str) -> Tuple[gr.update, pd.DataFrame, str, str, str, str]:
    """Recompute every explorer-tab output for the chosen dataset and split.

    Returns, in order: an update for the split dropdown (choices and value),
    a preview of at most 12 rows, the first preview row's raw and transpiled
    QASM (``"// N/A"`` when unavailable), the dataset profile markdown and
    the split summary markdown. A ``split_name`` that is not available falls
    back to the first available split.
    """
    frame = load_dataset_df(dataset_key)
    has_split_column = "split" in frame.columns

    if has_split_column:
        available = frame["split"].dropna().unique().tolist()
    else:
        available = ["train"]
    available = available or ["train"]

    if split_name not in available:
        split_name = available[0]

    subset = frame[frame["split"] == split_name] if has_split_column else frame
    preview = subset.head(12).copy()

    def first_cell(column: str):
        # Value from the first preview row, or a placeholder without one.
        if column in preview.columns and not preview.empty:
            return preview[column].iloc[0]
        return "// N/A"

    raw_text = first_cell("qasm_raw")
    transpiled_text = first_cell("qasm_transpiled")

    # The profile describes the whole dataset, not just the selected split.
    profile_md = build_dataset_profile(frame)
    summary_md = (
        f"### Split summary\n\n"
        f"**Dataset:** `{dataset_key}` \n"
        f"**Available splits:** {', '.join(available)} \n"
        f"**Preview rows:** {len(preview)}"
    )

    split_update = gr.update(choices=available, value=split_name)
    return split_update, preview, raw_text, transpiled_text, profile_md, summary_md
288
+
289
+
290
def sync_feature_picker(dataset_key: str) -> gr.update:
    """Return an update giving the feature picker this dataset's columns.

    The choices are the dataset's available feature columns and the selected
    value is the default subset of them.
    """
    available = get_available_feature_columns(load_dataset_df(dataset_key))
    return gr.update(
        choices=available,
        value=default_feature_selection(available),
    )
296
+
297
+
298
def train_regressor(
    dataset_key: str,
    feature_columns: List[str],
    test_size: float,
    n_estimators: int,
    max_depth: float,
    random_state: float,
) -> Tuple[Optional[plt.Figure], str]:
    """Train a random-forest regressor for ``TARGET_COL`` and report its fit.

    Args:
        dataset_key: Key of the dataset to load via ``load_dataset_df``.
        feature_columns: Names of the input columns.
        test_size: Fraction of rows held out for evaluation.
        n_estimators: Number of trees in the forest.
        max_depth: Maximum tree depth; a falsy or non-positive value means
            unlimited depth.
        random_state: Seed for the split and the forest; ``None`` falls back
            to 42.

    Returns:
        ``(figure, markdown)``. On a validation failure (no features
        selected, columns missing from the dataset, or fewer than 10 rows
        without missing values) the figure is ``None`` and the markdown
        explains the problem; otherwise the markdown reports RMSE, MAE and R²
        on the held-out rows.
    """
    if not feature_columns:
        return None, "### ❌ Please select at least one feature."

    df = load_dataset_df(dataset_key)
    required_cols = list(feature_columns) + [TARGET_COL]

    # A selection can name columns this dataset does not have; report that
    # to the user instead of failing with a KeyError inside dropna().
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        listed = ", ".join(f"`{col}`" for col in missing_cols)
        return None, f"### ❌ Columns not found in this dataset: {listed}"

    train_df = df.dropna(subset=required_cols).copy()

    if len(train_df) < 10:
        return None, "### ❌ Not enough clean rows after filtering missing values."

    X = train_df[feature_columns]
    y = train_df[TARGET_COL]

    # An emptied seed field arrives as None; fall back to a fixed seed.
    seed = 42 if random_state is None else int(random_state)
    depth = int(max_depth) if max_depth and int(max_depth) > 0 else None
    trees = int(n_estimators)

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=seed,
    )

    model = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
            (
                "regressor",
                RandomForestRegressor(
                    n_estimators=trees,
                    max_depth=depth,
                    random_state=seed,
                    n_jobs=-1,
                ),
            ),
        ]
    )

    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))
    r2 = float(r2_score(y_test, y_pred))

    regressor = model.named_steps["regressor"]
    importances = getattr(regressor, "feature_importances_", None)
    fig = make_regression_figure(y_test.to_numpy(), y_pred, list(feature_columns), importances)

    results = (
        "### Regression results\n\n"
        f"**Rows used:** {len(train_df):,} \n"
        f"**Test size:** {test_size:.0%} \n"
        f"**RMSE:** {rmse:.4f} \n"
        f"**MAE:** {mae:.4f} \n"
        f"**R²:** {r2:.4f}\n\n"
        "The closer the scatter points are to the diagonal line, the better the model."
    )
    return fig, results
368
+
369
+
370
# Stylesheet passed to ``demo.launch``.
CUSTOM_CSS = """
.gradio-container {
max-width: 1400px !important;
}
footer {
margin-top: 1rem;
}
"""

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# 🌌 {APP_TITLE}")
    gr.Markdown(APP_SUBTITLE)

    with gr.Tabs():
        # Tab 1: dataset and split selection, summaries, preview and QASM.
        with gr.TabItem("🔎 Explorer"):
            dataset_dropdown = gr.Dropdown(
                list(REPO_CONFIG.keys()),
                value="Amplitude Damping",
                label="Dataset",
            )
            split_dropdown = gr.Dropdown(["train"], value="train", label="Split")

            profile_box = gr.Markdown(value="### Loading dataset...")
            summary_box = gr.Markdown(value="### Loading split summary...")
            explorer_df = gr.Dataframe(label="Preview", interactive=False)

            with gr.Row():
                raw_qasm = gr.Code(label="Raw QASM", language=None)
                transpiled_qasm = gr.Code(label="Transpiled QASM", language=None)

        # Tab 2: model configuration, training trigger and results.
        with gr.TabItem("🧠 Regression"):
            feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
            test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test split")
            n_estimators = gr.Slider(50, 400, value=200, step=10, label="Trees")
            max_depth = gr.Slider(1, 30, value=12, step=1, label="Max depth")
            seed = gr.Number(value=42, precision=0, label="Random seed")
            run_btn = gr.Button("Train & Evaluate", variant="primary")
            plot = gr.Plot()
            metrics = gr.Markdown()

        # Tab 3: static guide text.
        with gr.TabItem("📖 Guide"):
            gr.Markdown(load_guide_content())

    gr.Markdown("---")
    gr.Markdown(
        "### 🔗 Links\n"
        "[Website](https://qsbench.github.io) | "
        "[Hugging Face](https://huggingface.co/QSBench) | "
        "[GitHub](https://github.com/QSBench)"
    )

    # Every explorer refresh reads the same inputs and fills the same outputs.
    explorer_inputs = [dataset_dropdown, split_dropdown]
    explorer_outputs = [
        split_dropdown,
        explorer_df,
        raw_qasm,
        transpiled_qasm,
        profile_box,
        summary_box,
    ]

    dataset_dropdown.change(refresh_explorer, explorer_inputs, explorer_outputs)
    split_dropdown.change(refresh_explorer, explorer_inputs, explorer_outputs)

    # Changing the dataset also resets the selectable features.
    dataset_dropdown.change(sync_feature_picker, [dataset_dropdown], [feature_picker])

    run_btn.click(
        train_regressor,
        [dataset_dropdown, feature_picker, test_size, n_estimators, max_depth, seed],
        [plot, metrics],
    )

    # Populate the explorer and the feature picker when the page first loads.
    demo.load(refresh_explorer, explorer_inputs, explorer_outputs)
    demo.load(sync_feature_picker, [dataset_dropdown], [feature_picker])


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft(), css=CUSTOM_CSS)