QSBench commited on
Commit
05a1756
Β·
verified Β·
1 Parent(s): 7aaa54b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -0
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import re
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from datasets import load_dataset
11
+
12
+ from sklearn.cluster import KMeans
13
+ from sklearn.decomposition import PCA
14
+ from sklearn.impute import SimpleImputer
15
+ from sklearn.metrics import silhouette_score
16
+ from sklearn.pipeline import Pipeline
17
+ from sklearn.preprocessing import StandardScaler
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # ========================= CONFIG =========================
23
+ APP_TITLE = "Circuit Complexity Clustering"
24
+ APP_SUBTITLE = (
25
+ "Unsupervised grouping of quantum circuits by structural complexity "
26
+ "using only topology and gate features β€” no labels required."
27
+ )
28
+
29
+ REPO_CONFIG = {
30
+ "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
31
+ "Depolarizing Noise": "QSBench/QSBench-Depolarizing-Demo-v1.0.0",
32
+ "Amplitude Damping": "QSBench/QSBench-Amplitude-v1.0.0-demo",
33
+ "Transpilation (10q)": "QSBench/QSBench-Transpilation-v1.0.0-demo",
34
+ }
35
+
36
+ NON_FEATURE_COLS = {
37
+ "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
38
+ "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
39
+ "noise_type", "noise_prob", "observable_bases", "observable_mode",
40
+ "backend_device", "precision_mode", "circuit_signature",
41
+ "entanglement", "meyer_wallach", "noise_label",
42
+ }
43
+
44
+ SOFT_EXCLUDE_PATTERNS = ["ideal_", "noisy_", "error_", "sign_ideal_", "sign_noisy_"]
45
+
46
+ _ASSET_CACHE: Dict[str, pd.DataFrame] = {}
47
+
48
+
49
+ def safe_parse(value):
50
+ """Safely parse stringified Python literals."""
51
+ if isinstance(value, str):
52
+ try:
53
+ return ast.literal_eval(value)
54
+ except Exception:
55
+ return value
56
+ return value
57
+
58
+
59
+ def adjacency_features(adj_value) -> Dict[str, float]:
60
+ """Derive basic graph features from an adjacency matrix."""
61
+ parsed = safe_parse(adj_value)
62
+ if not isinstance(parsed, list) or len(parsed) == 0:
63
+ return {
64
+ "adj_edge_count": np.nan,
65
+ "adj_density": np.nan,
66
+ "adj_degree_mean": np.nan,
67
+ "adj_degree_std": np.nan,
68
+ }
69
+ try:
70
+ arr = np.array(parsed, dtype=float)
71
+ n = arr.shape[0]
72
+ edge_count = float(np.triu(arr, k=1).sum())
73
+ possible_edges = float(n * (n - 1) / 2)
74
+ density = edge_count / possible_edges if possible_edges > 0 else np.nan
75
+ degrees = arr.sum(axis=1)
76
+ return {
77
+ "adj_edge_count": edge_count,
78
+ "adj_density": density,
79
+ "adj_degree_mean": float(np.mean(degrees)),
80
+ "adj_degree_std": float(np.std(degrees)),
81
+ }
82
+ except Exception:
83
+ return {
84
+ "adj_edge_count": np.nan,
85
+ "adj_density": np.nan,
86
+ "adj_degree_mean": np.nan,
87
+ "adj_degree_std": np.nan,
88
+ }
89
+
90
+
91
+ def qasm_features(qasm_value) -> Dict[str, float]:
92
+ """Extract lightweight statistics from QASM text."""
93
+ if not isinstance(qasm_value, str) or not qasm_value.strip():
94
+ return {
95
+ "qasm_length": np.nan,
96
+ "qasm_line_count": np.nan,
97
+ "qasm_gate_keyword_count": np.nan,
98
+ "qasm_measure_count": np.nan,
99
+ "qasm_comment_count": np.nan,
100
+ }
101
+ text = qasm_value
102
+ lines = [line for line in text.splitlines() if line.strip()]
103
+ gate_keywords = re.findall(
104
+ r"\b(cx|h|x|y|z|rx|ry|rz|u1|u2|u3|u|swap|cz|ccx|rxx|ryy|rzz)\b",
105
+ text,
106
+ flags=re.IGNORECASE,
107
+ )
108
+ measure_count = len(re.findall(r"\bmeasure\b", text, flags=re.IGNORECASE))
109
+ comment_count = sum(1 for line in lines if line.strip().startswith("//"))
110
+ return {
111
+ "qasm_length": float(len(text)),
112
+ "qasm_line_count": float(len(lines)),
113
+ "qasm_gate_keyword_count": float(len(gate_keywords)),
114
+ "qasm_measure_count": float(measure_count),
115
+ "qasm_comment_count": float(comment_count),
116
+ }
117
+
118
+
119
+ def enrich_dataframe(df: pd.DataFrame) -> pd.DataFrame:
120
+ """Add derived numeric features for clustering."""
121
+ df = df.copy()
122
+ if "adjacency" in df.columns:
123
+ adj_df = df["adjacency"].apply(adjacency_features).apply(pd.Series)
124
+ df = pd.concat([df, adj_df], axis=1)
125
+ qasm_source = "qasm_transpiled" if "qasm_transpiled" in df.columns else "qasm_raw"
126
+ if qasm_source in df.columns:
127
+ qasm_df = df[qasm_source].apply(qasm_features).apply(pd.Series)
128
+ df = pd.concat([df, qasm_df], axis=1)
129
+ return df
130
+
131
+
132
+ def load_dataset_df(dataset_key: str) -> pd.DataFrame:
133
+ """Load a dataset shard from Hugging Face and cache it in memory."""
134
+ if dataset_key not in _ASSET_CACHE:
135
+ logger.info("Loading dataset from Hugging Face: %s", dataset_key)
136
+ ds = load_dataset(REPO_CONFIG[dataset_key])
137
+ df = pd.DataFrame(ds["train"])
138
+ df = enrich_dataframe(df)
139
+ _ASSET_CACHE[dataset_key] = df
140
+ return _ASSET_CACHE[dataset_key]
141
+
142
+
143
+ def load_guide_content() -> str:
144
+ """Load the markdown guide if it exists."""
145
+ try:
146
+ with open("GUIDE.md", "r", encoding="utf-8") as f:
147
+ return f.read()
148
+ except FileNotFoundError:
149
+ return "# Guide\n\nGuide file not found."
150
+
151
+
152
+ def get_available_feature_columns(df: pd.DataFrame) -> List[str]:
153
+ """Collect numeric feature columns, excluding metadata."""
154
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
155
+ features = []
156
+ for col in numeric_cols:
157
+ if col in NON_FEATURE_COLS:
158
+ continue
159
+ if any(pattern in col for pattern in SOFT_EXCLUDE_PATTERNS):
160
+ continue
161
+ features.append(col)
162
+ return sorted(features)
163
+
164
+
165
+ def default_feature_selection(features: List[str]) -> List[str]:
166
+ """Select a stable default feature subset."""
167
+ preferred = [
168
+ "gate_entropy", "adj_density", "adj_degree_mean", "adj_degree_std",
169
+ "depth", "total_gates", "single_qubit_gates", "two_qubit_gates",
170
+ "cx_count", "qasm_length", "qasm_line_count", "qasm_gate_keyword_count",
171
+ ]
172
+ selected = [feature for feature in preferred if feature in features]
173
+ return selected[:10] if selected else features[:10]
174
+
175
+
176
+ def run_clustering(
177
+ dataset_key: str,
178
+ feature_columns: List[str],
179
+ n_clusters: int,
180
+ random_state: float,
181
+ ) -> Tuple[Optional[plt.Figure], str, pd.DataFrame]:
182
+ """Run K-Means clustering and return PCA plot + metrics."""
183
+ if not feature_columns:
184
+ return None, "### ❌ Please select at least one feature.", None
185
+
186
+ df = load_dataset_df(dataset_key)
187
+ train_df = df.dropna(subset=feature_columns).copy()
188
+
189
+ if len(train_df) < 30:
190
+ return None, "### ❌ Not enough rows after filtering missing values.", None
191
+
192
+ X = train_df[feature_columns]
193
+
194
+ pipeline = Pipeline([
195
+ ("imputer", SimpleImputer(strategy="median")),
196
+ ("scaler", StandardScaler()),
197
+ ("pca", PCA(n_components=2, random_state=int(random_state))),
198
+ ("kmeans", KMeans(n_clusters=n_clusters, random_state=int(random_state), n_init=10))
199
+ ])
200
+
201
+ pipeline.fit(X)
202
+ labels = pipeline.named_steps["kmeans"].labels_
203
+ pca_coords = pipeline.named_steps["pca"].transform(
204
+ pipeline.named_steps["scaler"].transform(
205
+ pipeline.named_steps["imputer"].transform(X)
206
+ )
207
+ )
208
+
209
+ sil_score = silhouette_score(X, labels)
210
+
211
+ # Plot
212
+ fig, ax = plt.subplots(figsize=(10, 8))
213
+ scatter = ax.scatter(pca_coords[:, 0], pca_coords[:, 1], c=labels, cmap="tab10", s=30, alpha=0.8)
214
+ ax.set_title(f"Circuit Complexity Clusters (K={n_clusters})")
215
+ ax.set_xlabel("PCA Component 1")
216
+ ax.set_ylabel("PCA Component 2")
217
+ ax.grid(True, alpha=0.3)
218
+ plt.colorbar(scatter, ax=ax, label="Cluster")
219
+ plt.tight_layout()
220
+
221
+ # Cluster summary
222
+ summary = train_df.copy()
223
+ summary["cluster"] = labels
224
+ cluster_summary = summary.groupby("cluster").size().reset_index()
225
+ cluster_summary.columns = ["Cluster", "Number of Circuits"]
226
+
227
+ metrics_text = (
228
+ f"### Clustering Results\n\n"
229
+ f"**Number of circuits clustered:** {len(train_df):,}\n"
230
+ f"**Number of clusters:** {n_clusters}\n"
231
+ f"**Silhouette Score:** {sil_score:.4f} (closer to 1 = better separation)\n\n"
232
+ f"**Cluster sizes:**\n"
233
+ f"{cluster_summary.to_markdown(index=False)}"
234
+ )
235
+
236
+ return fig, metrics_text, cluster_summary
237
+
238
+
239
+ CUSTOM_CSS = """
240
+ .gradio-container {
241
+ max-width: 1400px !important;
242
+ }
243
+ footer {
244
+ margin-top: 1rem;
245
+ }
246
+ """
247
+
248
+ with gr.Blocks(title=APP_TITLE) as demo:
249
+ gr.Markdown(f"# 🌌 {APP_TITLE}")
250
+ gr.Markdown(APP_SUBTITLE)
251
+
252
+ with gr.Tabs():
253
+ with gr.TabItem("πŸ”Ž Explorer"):
254
+ dataset_dropdown = gr.Dropdown(
255
+ list(REPO_CONFIG.keys()),
256
+ value="Amplitude Damping",
257
+ label="Dataset",
258
+ )
259
+ split_dropdown = gr.Dropdown(
260
+ ["train"],
261
+ value="train",
262
+ label="Split",
263
+ )
264
+ profile_box = gr.Markdown(value="### Loading dataset...")
265
+ summary_box = gr.Markdown(value="### Loading split summary...")
266
+ explorer_df = gr.Dataframe(label="Preview", interactive=False)
267
+
268
+ with gr.Row():
269
+ raw_qasm = gr.Code(label="Raw QASM", language=None)
270
+ transpiled_qasm = gr.Code(label="Transpiled QASM", language=None)
271
+
272
+ with gr.TabItem("🧠 Clustering"):
273
+ feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
274
+ n_clusters = gr.Slider(2, 10, value=4, step=1, label="Number of Clusters")
275
+ seed = gr.Number(value=42, precision=0, label="Random Seed")
276
+ run_btn = gr.Button("πŸš€ Run K-Means Clustering", variant="primary")
277
+
278
+ plot = gr.Plot()
279
+ metrics = gr.Markdown()
280
+ cluster_table = gr.Dataframe(label="Cluster Sizes")
281
+
282
+ with gr.TabItem("πŸ“– Guide"):
283
+ gr.Markdown(load_guide_content())
284
+
285
+ gr.Markdown("---")
286
+ gr.Markdown(
287
+ "### πŸ”— Links\n"
288
+ "[Website](https://qsbench.github.io) | "
289
+ "[Hugging Face](https://huggingface.co/QSBench) | "
290
+ "[GitHub](https://github.com/QSBench)"
291
+ )
292
+
293
+ # Callbacks
294
+ def refresh_explorer(dataset_key: str, split_name: str):
295
+ df = load_dataset_df(dataset_key)
296
+ splits = df["split"].dropna().unique().tolist() if "split" in df.columns else ["train"]
297
+ if not splits:
298
+ splits = ["train"]
299
+ if split_name not in splits:
300
+ split_name = splits[0]
301
+ filtered = df[df["split"] == split_name] if "split" in df.columns else df
302
+ display_df = filtered.head(12).copy()
303
+ raw = display_df["qasm_raw"].iloc[0] if not display_df.empty else "// N/A"
304
+ transpiled = display_df["qasm_transpiled"].iloc[0] if not display_df.empty else "// N/A"
305
+ profile = f"### Dataset profile\n\n**Rows:** {len(df):,}\n**Columns:** {len(df.columns):,}"
306
+ summary = f"### Split summary\n\n**Dataset:** `{dataset_key}`\n**Available splits:** {', '.join(splits)}\n**Preview rows:** {len(display_df)}"
307
+ return (
308
+ gr.update(choices=splits, value=split_name),
309
+ display_df,
310
+ raw,
311
+ transpiled,
312
+ profile,
313
+ summary,
314
+ )
315
+
316
+ def sync_feature_picker(dataset_key: str):
317
+ df = load_dataset_df(dataset_key)
318
+ features = get_available_feature_columns(df)
319
+ defaults = default_feature_selection(features)
320
+ return gr.update(choices=features, value=defaults)
321
+
322
+ dataset_dropdown.change(
323
+ refresh_explorer,
324
+ [dataset_dropdown, split_dropdown],
325
+ [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
326
+ )
327
+ split_dropdown.change(
328
+ refresh_explorer,
329
+ [dataset_dropdown, split_dropdown],
330
+ [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
331
+ )
332
+ dataset_dropdown.change(sync_feature_picker, [dataset_dropdown], [feature_picker])
333
+
334
+ run_btn.click(
335
+ run_clustering,
336
+ [dataset_dropdown, feature_picker, n_clusters, seed],
337
+ [plot, metrics, cluster_table],
338
+ )
339
+
340
+ demo.load(
341
+ refresh_explorer,
342
+ [dataset_dropdown, split_dropdown],
343
+ [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
344
+ )
345
+ demo.load(sync_feature_picker, [dataset_dropdown], [feature_picker])
346
+
347
+
348
+ if __name__ == "__main__":
349
+ demo.launch(theme=gr.themes.Soft(), css=CUSTOM_CSS)