QSBench commited on
Commit
81da9d5
·
verified ·
1 Parent(s): 33e2b92

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +471 -0
app.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import gradio as gr
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.ensemble import RandomForestRegressor
13
+ from sklearn.impute import SimpleImputer
14
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
15
+ from sklearn.model_selection import train_test_split
16
+ from sklearn.pipeline import Pipeline
17
+ from sklearn.preprocessing import StandardScaler
18
+
19
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
APP_TITLE = "Entanglement Score Regression"
APP_SUBTITLE = "Predict the continuous Meyer-Wallach entanglement score from circuit topology and gate structure."

# Set this to the CSV file you place in the Space repository.
# You can also override it with an environment variable in Spaces.
DATA_PATH = os.getenv("QS_DATA_PATH", "QSBench-Amplitude-v1.0.0-demo_shard_00000.csv")

# Columns that should never be used as direct features: identifiers, raw text
# payloads, run metadata, and the regression target itself.
NON_FEATURE_COLS = {
    "sample_id",
    "sample_seed",
    "circuit_hash",
    "split",
    "circuit_qasm",
    "qasm_raw",
    "qasm_transpiled",
    "circuit_type_resolved",
    "circuit_type_requested",
    "noise_type",
    "noise_prob",
    "observable_bases",
    "observable_mode",
    "backend_device",
    "precision_mode",
    "circuit_signature",
    "entanglement",
    "meyer_wallach",  # target column
}

# Optional columns to visually hide from the feature picker because they are usually constant
# or less informative in small demo shards. Matching is by substring (see
# get_available_feature_columns).
SOFT_EXCLUDE_PATTERNS = ["ideal_", "noisy_", "error_", "sign_ideal_", "sign_noisy_"]

# Process-wide cache for the loaded dataset (keyed by "df") so UI callbacks
# do not re-read and re-enrich the CSV on every interaction.
_ASSET_CACHE: Dict[str, pd.DataFrame] = {}
62
+
63
+
64
+ # -----------------------------------------------------------------------------
65
+ # Data loading and feature engineering
66
+ # -----------------------------------------------------------------------------
67
+
68
def load_dataset_df() -> pd.DataFrame:
    """Load the demo shard from disk, enrich it, and memoize the result.

    Raises:
        FileNotFoundError: when DATA_PATH does not exist on disk.
    """
    cached = _ASSET_CACHE.get("df")
    if cached is None:
        if not os.path.exists(DATA_PATH):
            raise FileNotFoundError(
                f"Dataset file not found: {DATA_PATH}. "
                "Place the CSV in the Space repository or set QS_DATA_PATH."
            )

        logger.info("Loading dataset from %s", DATA_PATH)
        cached = enrich_dataframe(pd.read_csv(DATA_PATH))
        _ASSET_CACHE["df"] = cached
    return cached
82
+
83
+
84
def safe_parse(value):
    """Safely parse a string representation of a Python literal.

    Returns the evaluated object when *value* is a string containing a valid
    Python literal (list, dict, number, tuple, ...); any other input, or a
    string that fails to parse, is returned unchanged.
    """
    if isinstance(value, str):
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # These are the failure modes documented for ast.literal_eval
            # (malformed or excessively nested input); anything else is a
            # genuine bug and should propagate instead of being swallowed.
            return value
    return value
92
+
93
+
94
def adjacency_features(adj_value) -> Dict[str, float]:
    """Derive compact topology features from an adjacency matrix.

    Args:
        adj_value: A square adjacency matrix as a nested list, or its string
            representation (decoded via safe_parse).

    Returns:
        Edge count, edge density, and degree mean/std. Every value is NaN when
        the input cannot be interpreted as a numeric matrix.
    """
    # Single fallback payload instead of two duplicated literals, so the
    # feature-name set cannot drift between the two failure paths.
    nan_features = {
        "adj_edge_count": np.nan,
        "adj_density": np.nan,
        "adj_degree_mean": np.nan,
        "adj_degree_std": np.nan,
    }

    parsed = safe_parse(adj_value)
    if not isinstance(parsed, list) or len(parsed) == 0:
        return dict(nan_features)

    try:
        arr = np.array(parsed, dtype=float)
        n = arr.shape[0]
        # For an undirected adjacency matrix, a full sum counts each edge
        # twice, so only the strict upper triangle is tallied.
        edge_count = float(np.triu(arr, k=1).sum())
        possible_edges = float(n * (n - 1) / 2)
        density = edge_count / possible_edges if possible_edges > 0 else np.nan
        degrees = arr.sum(axis=1)
        return {
            "adj_edge_count": edge_count,
            "adj_density": density,
            "adj_degree_mean": float(np.mean(degrees)),
            "adj_degree_std": float(np.std(degrees)),
        }
    except Exception:
        # Ragged rows or non-numeric entries: degrade to NaNs rather than
        # failing the whole enrichment pass.
        return dict(nan_features)
126
+
127
+
128
def qasm_features(qasm_value) -> Dict[str, float]:
    """Compute simple string-level statistics from a QASM program text.

    Non-string or blank input yields a dict of NaNs so downstream code can
    always rely on the same keys being present.
    """
    if not isinstance(qasm_value, str) or not qasm_value.strip():
        return {
            "qasm_length": np.nan,
            "qasm_line_count": np.nan,
            "qasm_gate_keyword_count": np.nan,
            "qasm_measure_count": np.nan,
            "qasm_comment_count": np.nan,
        }

    text = qasm_value
    nonblank = [ln for ln in text.splitlines() if ln.strip()]
    gate_pattern = r"\b(cx|h|x|y|z|rx|ry|rz|u1|u2|u3|u|swap|cz|ccx|rxx|ryy|rzz)\b"
    gate_hits = re.findall(gate_pattern, text, flags=re.IGNORECASE)
    measure_hits = re.findall(r"\bmeasure\b", text, flags=re.IGNORECASE)
    comment_lines = [ln for ln in nonblank if ln.strip().startswith("//")]

    return {
        "qasm_length": float(len(text)),
        "qasm_line_count": float(len(nonblank)),
        "qasm_gate_keyword_count": float(len(gate_hits)),
        "qasm_measure_count": float(len(measure_hits)),
        "qasm_comment_count": float(len(comment_lines)),
    }
152
+
153
+
154
def enrich_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Append derived regression features to a copy of the input frame.

    Adds adjacency-topology columns, QASM text statistics (preferring the
    transpiled QASM when available), and normalizes a few object columns to
    the pandas string dtype.
    """
    out = df.copy()

    if "adjacency" in out.columns:
        topo = out["adjacency"].apply(adjacency_features).apply(pd.Series)
        out = pd.concat([out, topo], axis=1)

    qasm_col = "qasm_transpiled" if "qasm_transpiled" in out.columns else "qasm_raw"
    if qasm_col in out.columns:
        stats = out[qasm_col].apply(qasm_features).apply(pd.Series)
        out = pd.concat([out, stats], axis=1)

    # Normalize obvious object columns that can be safely treated as strings.
    for name in ("noise_type", "backend_device", "precision_mode", "observable_mode"):
        if name in out.columns:
            out[name] = out[name].astype("string")

    return out
173
+
174
+
175
def load_guide_content() -> str:
    """Return the contents of GUIDE.md, or a placeholder when it is absent."""
    guide_path = "GUIDE.md"
    if not os.path.exists(guide_path):
        # Fallback shown in the Guide tab until the manual is committed.
        return (
            "# Guide\n\n"
            "The guide file is not added yet. In the next step, we can build a full user manual "
            "with dataset description, model interpretation, and example workflows."
        )
    with open(guide_path, "r", encoding="utf-8") as f:
        return f.read()
186
+
187
+
188
+ # -----------------------------------------------------------------------------
189
+ # Feature selection helpers
190
+ # -----------------------------------------------------------------------------
191
+
192
def get_available_feature_columns(df: pd.DataFrame) -> List[str]:
    """Return sorted numeric columns eligible as model inputs.

    Excludes the target/metadata columns listed in NON_FEATURE_COLS and any
    column whose name contains one of the SOFT_EXCLUDE_PATTERNS substrings.
    """
    numeric = df.select_dtypes(include=[np.number]).columns
    eligible = [
        name
        for name in numeric
        if name not in NON_FEATURE_COLS
        and not any(pattern in name for pattern in SOFT_EXCLUDE_PATTERNS)
    ]
    return sorted(eligible)
203
+
204
+
205
def default_feature_selection(features: List[str]) -> List[str]:
    """Pick up to eight stable, high-value defaults from *features*.

    Preserves the preference order below, keeping only names that actually
    exist in the available feature list.
    """
    preferred = (
        "gate_entropy",
        "adjacency",
        "adj_density",
        "adj_degree_mean",
        "adj_degree_std",
        "depth",
        "total_gates",
        "single_qubit_gates",
        "two_qubit_gates",
        "cx_count",
        "qasm_length",
        "qasm_line_count",
        "qasm_gate_keyword_count",
    )
    available = set(features)
    picked = [name for name in preferred if name in available]
    return picked[:8]
223
+
224
+
225
+ # -----------------------------------------------------------------------------
226
+ # Visualization helpers
227
+ # -----------------------------------------------------------------------------
228
+
229
def make_regression_figure(y_true: np.ndarray, y_pred: np.ndarray, feature_names: Optional[List[str]] = None, importances: Optional[np.ndarray] = None) -> plt.Figure:
    """Build a three-panel regression summary figure.

    Panels: actual-vs-predicted scatter with a y=x reference, residual
    histogram, and (when available) the top-10 feature importances.
    """
    fig, (scatter_ax, resid_ax, imp_ax) = plt.subplots(1, 3, figsize=(20, 6))

    # Panel 1: actual vs predicted, with a dashed identity line spanning
    # the combined range of both series.
    scatter_ax.scatter(y_true, y_pred, alpha=0.75)
    lo = min(float(np.min(y_true)), float(np.min(y_pred)))
    hi = max(float(np.max(y_true)), float(np.max(y_pred)))
    scatter_ax.plot([lo, hi], [lo, hi], linestyle="--")
    scatter_ax.set_title("Actual vs Predicted")
    scatter_ax.set_xlabel("Actual Meyer-Wallach")
    scatter_ax.set_ylabel("Predicted Meyer-Wallach")

    # Panel 2: residual histogram.
    resid_ax.hist(y_true - y_pred, bins=20)
    resid_ax.set_title("Residual Distribution")
    resid_ax.set_xlabel("Residual")
    resid_ax.set_ylabel("Count")

    # Panel 3: top-10 feature importances, only when names and values align.
    have_importances = (
        importances is not None
        and feature_names is not None
        and len(importances) == len(feature_names)
    )
    if have_importances:
        order = np.argsort(importances)[-10:]
        imp_ax.barh([feature_names[i] for i in order], importances[order])
        imp_ax.set_title("Top-10 Feature Importances")
        imp_ax.set_xlabel("Importance")
    else:
        imp_ax.text(0.5, 0.5, "Feature importances are unavailable.", ha="center", va="center")
        imp_ax.set_axis_off()

    fig.tight_layout()
    return fig
266
+
267
+
268
+ # -----------------------------------------------------------------------------
269
+ # UI callbacks
270
+ # -----------------------------------------------------------------------------
271
+
272
def refresh_explorer(split_name: str) -> Tuple[dict, pd.DataFrame, str, str, str, str]:
    """Refresh the explorer tab outputs for the selected split.

    Args:
        split_name: Split requested by the dropdown; coerced to the first
            available split when it is not present in the dataset.

    Returns:
        (dropdown update, preview dataframe, raw QASM, transpiled QASM,
        dataset-overview markdown, split-summary markdown).
    """
    df = load_dataset_df()
    splits = df["split"].dropna().unique().tolist() if "split" in df.columns else ["train"]
    if not splits:
        splits = ["train"]

    if split_name not in splits:
        split_name = splits[0]

    filtered = df[df["split"] == split_name] if "split" in df.columns else df
    display_df = filtered.head(12).copy()

    # Show the first preview row's QASM when the column exists and is non-empty.
    raw_qasm = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
    transpiled_qasm = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"

    # Guard against shards that lack the target column; the original code
    # raised KeyError here and broke the whole explorer refresh.
    if "meyer_wallach" in df.columns:
        target_range = f"{df['meyer_wallach'].min():.4f} → {df['meyer_wallach'].max():.4f}"
    else:
        target_range = "n/a (column missing)"

    target_info = (
        f"### Dataset overview\n\n"
        f"**Rows:** {len(df):,} \n"
        f"**Visible split:** `{split_name}` \n"
        f"**Target:** `meyer_wallach` \n"
        f"**Target range:** {target_range}"
    )

    summary = (
        f"### Split summary\n\n"
        f"**Available splits:** {', '.join(splits)} \n"
        f"**Preview rows:** {len(display_df)}"
    )

    return (
        gr.update(choices=splits, value=split_name),
        display_df,
        raw_qasm,
        transpiled_qasm,
        target_info,
        summary,
    )
310
+
311
+
312
def sync_feature_picker() -> gr.update:
    """Repopulate the feature picker from the loaded dataset."""
    frame = load_dataset_df()
    choices = get_available_feature_columns(frame)
    return gr.update(choices=choices, value=default_feature_selection(choices))
318
+
319
+
320
def train_regressor(feature_columns: List[str], test_size: float, n_estimators: int, max_depth: int, random_state: int) -> Tuple[Optional[plt.Figure], str]:
    """Train a random-forest regressor on selected features and report metrics.

    Args:
        feature_columns: Numeric dataset columns to use as model inputs.
        test_size: Held-out fraction for evaluation (0-1).
        n_estimators: Number of trees in the forest.
        max_depth: Maximum tree depth; values <= 0 mean unlimited.
        random_state: Seed for the split and the forest.

    Returns:
        (figure, markdown) on success, or (None, error markdown) when the
        inputs are unusable.
    """
    if not feature_columns:
        return None, "### ❌ Please select at least one feature."

    # Gradio sliders and Number boxes may deliver floats; sklearn requires
    # genuine ints for these hyperparameters, so coerce defensively.
    n_estimators = int(n_estimators)
    max_depth = int(max_depth)
    random_state = int(random_state)

    df = load_dataset_df()
    required_cols = feature_columns + ["meyer_wallach"]

    # A stale UI selection could reference columns that no longer exist;
    # dropna(subset=...) would raise KeyError, so fail with a clear message.
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        return None, f"### ❌ Missing columns in dataset: {', '.join(missing)}"

    train_df = df.dropna(subset=required_cols).copy()

    if len(train_df) < 10:
        return None, "### ❌ Not enough clean rows after filtering missing values."

    X = train_df[feature_columns]
    y = train_df["meyer_wallach"]

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
    )

    # Random forest works well for small, tabular demo data and gives feature
    # importances. The scaler is a no-op for tree splits but keeps the
    # pipeline reusable if a linear model is swapped in later.
    model = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
            (
                "regressor",
                RandomForestRegressor(
                    n_estimators=n_estimators,
                    max_depth=max_depth if max_depth > 0 else None,
                    random_state=random_state,
                    n_jobs=-1,
                ),
            ),
        ]
    )

    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))
    r2 = float(r2_score(y_test, y_pred))

    regressor = model.named_steps["regressor"]
    importances = getattr(regressor, "feature_importances_", None)
    fig = make_regression_figure(y_test.to_numpy(), y_pred, list(feature_columns), importances)

    results = (
        "### Regression results\n\n"
        f"**Rows used:** {len(train_df):,} \n"
        f"**Test size:** {test_size:.0%} \n"
        f"**RMSE:** {rmse:.4f} \n"
        f"**MAE:** {mae:.4f} \n"
        f"**R²:** {r2:.4f}\n\n"
        "The closer the scatter points are to the diagonal line, the better the model."
    )
    return fig, results
380
+
381
+
382
def build_dataset_profile() -> str:
    """Generate a compact markdown dataset summary for the explorer tab."""
    frame = load_dataset_df()
    mw = frame["meyer_wallach"]
    # Each middle segment keeps a trailing space so joining with "\n"
    # reproduces markdown hard line breaks (" \n").
    segments = [
        "### Dataset profile\n",
        f"**Rows:** {len(frame):,} ",
        f"**Columns:** {len(frame.columns):,} ",
        f"**Meyer-Wallach mean:** {mw.mean():.4f} ",
        f"**Meyer-Wallach std:** {mw.std():.4f} ",
        f"**Meyer-Wallach min/max:** {mw.min():.4f} / {mw.max():.4f}",
    ]
    return "\n".join(segments)
394
+
395
+
396
+ # -----------------------------------------------------------------------------
397
+ # UI
398
+ # -----------------------------------------------------------------------------
399
# Small CSS overrides applied to the whole Blocks app: extra footer spacing
# and a wider container than the Gradio default.
CUSTOM_CSS = """
footer {
margin-top: 1rem;
}
.gradio-container {
max-width: 1400px !important;
}
"""

# Top-level UI layout: three tabs (data explorer, regression playground,
# guide) plus a links footer. Event wiring happens after all components
# are declared so every handler can reference them.
with gr.Blocks(theme=gr.themes.Soft(), title=APP_TITLE, css=CUSTOM_CSS) as demo:
    gr.Markdown(f"# 🌌 {APP_TITLE}")
    gr.Markdown(APP_SUBTITLE)

    with gr.Tabs():
        with gr.TabItem("🔎 Explorer"):
            with gr.Row():
                # Populated with the real split list by refresh_explorer on load.
                split_dropdown = gr.Dropdown(label="Split", choices=["train"], value="train")
                profile_box = gr.Markdown(value="### Loading dataset...")

            with gr.Row():
                explorer_summary = gr.Markdown(value="### Loading split summary...")

            explorer_df = gr.Dataframe(interactive=False, label="Preview rows")

            with gr.Row():
                raw_qasm_code = gr.Code(label="Raw QASM", language="text")
                transpiled_qasm_code = gr.Code(label="Transpiled QASM", language="text")

        with gr.TabItem("🧠 Regression"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Choices are filled by sync_feature_picker on app load.
                    feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
                    test_size_slider = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test split")
                    n_estimators_slider = gr.Slider(50, 400, value=200, step=10, label="Number of trees")
                    max_depth_slider = gr.Slider(2, 30, value=12, step=1, label="Max tree depth")
                    random_state_number = gr.Number(value=42, precision=0, label="Random seed")
                    train_btn = gr.Button("Train & Evaluate", variant="primary")
                with gr.Column(scale=2):
                    plot_output = gr.Plot()
                    metrics_output = gr.Markdown()

        with gr.TabItem("📖 Guide"):
            gr.Markdown(load_guide_content())

    gr.Markdown("---")
    gr.Markdown(
        "### 🔗 Links \n"
        "[Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)"
    )

    # Bind events.
    # Changing the split re-renders the whole explorer tab.
    split_dropdown.change(
        refresh_explorer,
        inputs=[split_dropdown],
        outputs=[split_dropdown, explorer_df, raw_qasm_code, transpiled_qasm_code, profile_box, explorer_summary],
    )

    train_btn.click(
        train_regressor,
        inputs=[feature_picker, test_size_slider, n_estimators_slider, max_depth_slider, random_state_number],
        outputs=[plot_output, metrics_output],
    )

    # Initial population on page load: explorer contents and feature choices.
    demo.load(
        refresh_explorer,
        inputs=[split_dropdown],
        outputs=[split_dropdown, explorer_df, raw_qasm_code, transpiled_qasm_code, profile_box, explorer_summary],
    )
    demo.load(sync_feature_picker, outputs=[feature_picker])


if __name__ == "__main__":
    demo.launch()