QSBench committed on
Commit
e77fa31
·
verified ·
1 Parent(s): 0ea2619

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import pandas as pd
5
+ import seaborn as sns
6
+ import logging
7
+ import requests
8
+ from typing import List, Tuple, Dict, Optional
9
+ from datasets import load_dataset
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
12
+ from sklearn.model_selection import train_test_split
13
+ from sklearn.preprocessing import LabelEncoder
14
+
15
# --- CONFIG & LOGGING ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_HF_DATASETS_URL = "https://huggingface.co/datasets"


def _repo_entry(repo: str, ref: str) -> Dict[str, str]:
    """Build one REPO_CONFIG entry for the dataset repository ``repo``.

    ``ref`` is the path segment that follows ``raw/`` in the metadata URLs.
    """
    meta_dir = f"{_HF_DATASETS_URL}/{repo}/raw/{ref}/meta"
    return {
        "repo": repo,
        "meta_url": f"{meta_dir}/meta.json",
        "report_url": f"{meta_dir}/report.json",
    }


# Display name -> dataset repo id plus the URLs of its meta.json / report.json.
# NOTE(review): the Core entry uses "metadata" after raw/ while every other
# entry uses "meta" -- confirm this difference is intentional.
REPO_CONFIG: Dict[str, Dict[str, str]] = {
    "Core (Clean)": _repo_entry("QSBench/QSBench-Core-v1.0.0-demo", "metadata"),
    "Depolarizing Noise": _repo_entry("QSBench/QSBench-Depolarizing-Demo-v1.0.0", "meta"),
    "Amplitude Damping": _repo_entry("QSBench/QSBench-Amplitude-v1.0.0-demo", "meta"),
    "Transpilation (10q)": _repo_entry("QSBench/QSBench-Transpilation-v1.0.0-demo", "meta"),
}
41
+
42
# Columns that must never be offered as model features.
NON_FEATURE_COLS = {
    # sample bookkeeping
    "sample_id",
    "sample_seed",
    "circuit_hash",
    "split",
    # circuit text and family labels
    "circuit_qasm",
    "qasm_raw",
    "qasm_transpiled",
    "circuit_type_resolved",
    "circuit_type_requested",
    "circuit_signature",
    "entanglement",
    # noise and observable settings
    "noise_type",
    "noise_prob",
    "observable_bases",
    "observable_mode",
    # execution environment
    "backend_device",
    "precision_mode",
    "shots",
    "gpu_requested",
    "gpu_available",
}

# Per-dataset memo filled by load_all_assets(), keyed by REPO_CONFIG display name.
_ASSET_CACHE = {}
51
+
52
def load_all_assets(key: str) -> Dict:
    """Load the dataset, metadata and report for one REPO_CONFIG entry.

    Results are memoised in ``_ASSET_CACHE`` so each dataset is downloaded at
    most once per process.

    Args:
        key: A display name that is a key of ``REPO_CONFIG``.

    Returns:
        A dict with three entries: ``"df"`` (the ``train`` split as a pandas
        DataFrame), ``"meta"`` and ``"report"`` (the decoded JSON documents).

    Raises:
        KeyError: If ``key`` is not present in ``REPO_CONFIG``.
        requests.RequestException: If a JSON download times out or the server
            answers with an HTTP error status.
    """
    cached = _ASSET_CACHE.get(key)
    if cached is not None:
        return cached

    entry = REPO_CONFIG[key]
    logger.info("Fetching %s...", key)
    ds = load_dataset(entry["repo"])

    def _fetch_json(url: str):
        # Without a timeout a stalled connection would block this call forever.
        response = requests.get(url, timeout=30)
        # Fail with the real HTTP status instead of a JSON decode error on an
        # HTML error page.
        response.raise_for_status()
        return response.json()

    assets = {
        "df": pd.DataFrame(ds["train"]),
        "meta": _fetch_json(entry["meta_url"]),
        "report": _fetch_json(entry["report_url"]),
    }
    # Only cache once everything has been fetched, so a failed download is
    # retried on the next call.
    _ASSET_CACHE[key] = assets
    return assets
60
+
61
+ # --- UI LOGIC ---
62
+
63
def load_guide_content():
    """Return the text of GUIDE.md from the current working directory.

    When the file does not exist, a Markdown error notice is returned instead
    of raising.
    """
    try:
        guide_file = open("GUIDE.md", "r", encoding="utf-8")
    except FileNotFoundError:
        return "### ⚠️ Error: GUIDE.md not found. Please ensure it is in the root directory."
    with guide_file:
        return guide_file.read()
70
+
71
def sync_ml_metrics(ds_name: str):
    """Build the feature-checkbox update for the selected dataset.

    Candidate features are the numeric columns of the dataset, minus anything
    listed in ``NON_FEATURE_COLS`` and minus any column whose name contains
    one of the markers ``ideal_``, ``noisy_``, ``error_`` or ``sign_``
    (matched anywhere in the name, not only at the start).

    Returns:
        A ``gr.update`` carrying the candidate features as choices and a
        default selection: the preferred structural metrics that exist in
        this dataset, or the first five candidates when none of them do.
    """
    frame = load_all_assets(ds_name)["df"]
    excluded_markers = ("ideal_", "noisy_", "error_", "sign_")

    valid_features = []
    for column in frame.select_dtypes(include=[np.number]).columns.tolist():
        if column in NON_FEATURE_COLS:
            continue
        if any(marker in column for marker in excluded_markers):
            continue
        valid_features.append(column)

    # Pre-select logical structural indicators
    preferred = ("gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count")
    defaults = [name for name in preferred if name in valid_features]
    if not defaults:
        defaults = valid_features[:5]
    return gr.update(choices=valid_features, value=defaults)
86
+
87
def train_classifier(ds_name: str, features: List[str]):
    """Train a random forest that predicts the circuit family from features.

    The target is the ``circuit_type_requested`` column. Rows with a missing
    value in any selected feature or in the target are dropped, the data is
    split 80/20 with a fixed seed, and the model is scored on the held-out
    part.

    Args:
        ds_name: Display name of the dataset (a key of ``REPO_CONFIG``).
        features: Names of the columns to use as model inputs.

    Returns:
        A ``(figure, markdown)`` pair. The figure shows the confusion matrix
        and the ten largest feature importances; the markdown holds the
        accuracy and the per-class report. On invalid input the figure is
        ``None`` and the markdown is an error message.
    """
    if not features:
        return None, "### ❌ Error: No features selected."

    df = load_all_assets(ds_name)["df"]

    target_col = "circuit_type_requested"
    if target_col not in df.columns:
        return None, f"### ❌ Error: Target column '{target_col}' not found."

    # The selection may have been made for a different dataset; report that
    # instead of failing with a KeyError.
    missing = [name for name in features if name not in df.columns]
    if missing:
        return None, f"### ❌ Error: Features not found in this dataset: {', '.join(missing)}."

    # Data Cleaning
    train_df = df.dropna(subset=[*features, target_col])
    if len(train_df) < 2:
        # train_test_split needs at least one row on each side of the split.
        return None, "### ❌ Error: Not enough complete rows to train a model."
    X = train_df[features]
    y = train_df[target_col]

    # Encoding targets
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)
    class_names = [str(name) for name in le.classes_]
    # Passing the full label set keeps the matrix and the report aligned with
    # class_names even when a class is absent from the test split.
    label_ids = np.arange(len(class_names))

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_encoded, test_size=0.2, random_state=42
    )

    # Classification Model
    clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)

    # Metrics
    acc = accuracy_score(y_test, preds)

    # Visualization
    sns.set_theme(style="whitegrid", context="talk")
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # 1. Confusion Matrix
    cm = confusion_matrix(y_test, preds, labels=label_ids)
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=axes[0],
        cbar=False,
    )
    axes[0].set_title(f"Confusion Matrix (Accuracy: {acc:.2%})")
    axes[0].set_xlabel("Predicted Family")
    axes[0].set_ylabel("Actual Family")

    # 2. Feature Importance
    importances = clf.feature_importances_
    top_indices = np.argsort(importances)[-10:]  # Top 10, ascending
    axes[1].barh([features[i] for i in top_indices], importances[top_indices], color="#16a085")
    axes[1].set_title("Top Structural Discriminators")

    fig.tight_layout()
    # Release pyplot's reference so figures do not pile up across clicks; the
    # Figure object itself stays usable for the caller.
    plt.close(fig)

    report_text = classification_report(
        y_test,
        preds,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    )
    summary = (
        "### 🏆 Classification Results\n"
        f"**Overall Accuracy:** {acc:.2%}\n\n"
        f"**Detailed Report:**\n```\n{report_text}\n```"
    )

    return fig, summary
141
+
142
def update_explorer(ds_name: str, split_name: str):
    """Refresh every widget of the Explorer tab for a dataset/split choice.

    Args:
        ds_name: Display name of the dataset (a key of ``REPO_CONFIG``).
        split_name: Requested value of the ``split`` column. If no row has
            that value, the first split found in the data is used instead.

    Returns:
        A 5-tuple: the split dropdown update, the first ten rows of the
        selected split, the raw QASM and the transpiled QASM of the first
        displayed row (``"// N/A"`` when unavailable), and the tab heading.
    """
    df = load_all_assets(ds_name)["df"]
    has_split = "split" in df.columns
    unique_splits = df["split"].unique().tolist() if has_split else ["train"]

    filtered_df = df
    # An empty split list means there are no rows to filter, and indexing
    # unique_splits[0] would raise.
    if has_split and unique_splits:
        filtered_df = df[df["split"] == split_name]
        if filtered_df.empty:
            split_name = unique_splits[0]
            filtered_df = df[df["split"] == split_name]

    display_df = filtered_df.head(10)

    def _first_cell(column: str) -> str:
        # Show the placeholder for a missing column, an empty view or a null cell.
        if column not in display_df.columns or display_df.empty:
            return "// N/A"
        value = display_df[column].iloc[0]
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return "// N/A"
        return value

    raw = _first_cell("qasm_raw")
    tr = _first_cell("qasm_transpiled")

    return (
        gr.update(choices=unique_splits, value=split_name),
        display_df,
        raw,
        tr,
        f"### 📋 {ds_name} Explorer",
    )
161
+
162
+ # --- INTERFACE ---
163
# Layout: a header, three tabs (explorer, classifier, guide) and a footer.
with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
    gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
    gr.Markdown("Identify circuit types (QFT, HEA, RANDOM, etc.) using high-level structural complexity metrics.")

    with gr.Tabs():
        with gr.TabItem("🔎 Dataset Explorer"):
            # Heading placeholder; update_explorer() replaces it on load.
            meta_txt = gr.Markdown("### Loading...")
            with gr.Row():
                ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset Type")
                sp_sel = gr.Dropdown(["train"], value="train", label="Subset (Split)")
            data_view = gr.Dataframe(interactive=False)
            with gr.Row():
                c_raw = gr.Code(label="Original QASM (Logic)", language="python")
                c_tr = gr.Code(label="Transpiled QASM (Hardware-ready)", language="python")

        with gr.TabItem("🧠 Classification Model"):
            gr.Markdown("Predict the **Circuit Family** by analyzing topology signatures.")
            with gr.Row():
                with gr.Column(scale=1):
                    ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
                    # Choices start empty; sync_ml_metrics() fills them on load.
                    ml_feat_sel = gr.CheckboxGroup(label="Structural Features", choices=[])
                    train_btn = gr.Button("Run Classification", variant="primary")
                with gr.Column(scale=2):
                    p_out = gr.Plot()
                    t_out = gr.Markdown()

        with gr.TabItem("📖 User Guide"):
            meth_md = gr.Markdown(value=load_guide_content())

    # Built line by line so the Markdown does not depend on source indentation.
    gr.Markdown(
        "\n---\n"
        "### 🔗 Project Resources\n"
        "[**🌐 Website**](https://qsbench.github.io) | "
        "[**🤗 Hugging Face**](https://huggingface.co/QSBench) | "
        "[**💻 GitHub**](https://github.com/QSBench)\n"
    )

    # --- EVENTS ---
    explorer_inputs = [ds_sel, sp_sel]
    explorer_outputs = [sp_sel, data_view, c_raw, c_tr, meta_txt]

    # Explorer events
    ds_sel.change(update_explorer, explorer_inputs, explorer_outputs)
    sp_sel.change(update_explorer, explorer_inputs, explorer_outputs)

    # ML events
    ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
    train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])

    # Initial Load
    demo.load(update_explorer, explorer_inputs, explorer_outputs)
    demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])

if __name__ == "__main__":
    demo.launch()