QSBench committed on
Commit
30add96
·
verified ·
1 Parent(s): a0ad65f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -83
app.py CHANGED
@@ -39,7 +39,7 @@ REPO_CONFIG = {
39
  }
40
  }
41
 
42
- # Columns that are NOT features
43
  NON_FEATURE_COLS = {
44
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
45
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
@@ -61,150 +61,115 @@ def load_all_assets(key: str) -> Dict:
61
  # --- UI LOGIC ---
62
 
63
  def load_guide_content():
64
- """Reads the content of GUIDE.md from the local directory."""
65
  try:
66
  with open("GUIDE.md", "r", encoding="utf-8") as f:
67
  return f.read()
68
- except FileNotFoundError:
69
- return "### ⚠️ Error: GUIDE.md not found. Please ensure it is in the root directory."
70
 
71
  def sync_ml_metrics(ds_name: str):
72
- """Extracts numerical features available for classification."""
73
  assets = load_all_assets(ds_name)
74
  df = assets["df"]
75
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
76
-
77
- valid_features = [
78
- c for c in numeric_cols
79
- if c not in NON_FEATURE_COLS
80
- and not any(prefix in c for prefix in ["ideal_", "noisy_", "error_", "sign_"])
81
- ]
82
-
83
- # Pre-select logical structural indicators
84
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
85
- return gr.update(choices=valid_features, value=defaults or valid_features[:5])
86
 
87
  def train_classifier(ds_name: str, features: List[str]):
88
- """Trains a Classifier to identify the Circuit Family based on topology."""
89
- if not features: return None, "### ❌ Error: No features selected."
90
  assets = load_all_assets(ds_name)
91
  df = assets["df"]
92
 
93
- target_col = "circuit_type_requested"
94
- if target_col not in df.columns:
95
- return None, f"### ❌ Error: Target column '{target_col}' not found."
96
-
97
- # Data Cleaning
98
- train_df = df.dropna(subset=features + [target_col])
99
- X = train_df[features]
100
- y = train_df[target_col]
101
 
102
- # Encoding targets
103
  le = LabelEncoder()
104
  y_encoded = le.fit_transform(y)
105
- class_names = le.classes_
106
-
107
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
108
 
109
- # Classification Model
110
- clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
111
- clf.fit(X_train, y_train)
112
  preds = clf.predict(X_test)
113
 
114
- # Metrics
115
- acc = accuracy_score(y_test, preds)
116
-
117
- # Visualization
118
- sns.set_theme(style="whitegrid", context="talk")
119
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
120
 
121
- # 1. Confusion Matrix
122
  cm = confusion_matrix(y_test, preds)
123
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
124
- xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
125
- axes[0].set_title(f"Confusion Matrix (Accuracy: {acc:.2%})")
126
- axes[0].set_xlabel("Predicted Family")
127
- axes[0].set_ylabel("Actual Family")
128
 
129
- # 2. Feature Importance
130
  importances = clf.feature_importances_
131
- indices = np.argsort(importances)[-10:] # Top 10
132
- axes[1].barh([features[i] for i in indices], importances[indices], color='#16a085')
133
- axes[1].set_title("Top Structural Discriminators")
134
 
135
  plt.tight_layout()
136
-
137
- report_dict = classification_report(y_test, preds, target_names=class_names)
138
- summary = f"### 🏆 Classification Results\n**Overall Accuracy:** {acc:.2%}\n\n**Detailed Report:**\n```\n{report_dict}\n```"
139
-
140
- return fig, summary
141
 
142
  def update_explorer(ds_name: str, split_name: str):
143
- """Updates the data view for the Explorer tab."""
144
  assets = load_all_assets(ds_name)
145
  df = assets["df"]
146
- unique_splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
147
 
148
- if "split" in df.columns:
149
- filtered_df = df[df["split"] == split_name]
150
- if filtered_df.empty:
151
- split_name = unique_splits[0]
152
- filtered_df = df[df["split"] == split_name]
153
- else:
154
- filtered_df = df
155
-
156
- display_df = filtered_df.head(10)
 
157
  raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
158
  tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
159
 
160
- return gr.update(choices=unique_splits, value=split_name), display_df, raw, tr, f"### 📋 {ds_name} Explorer"
 
 
 
 
 
 
161
 
162
  # --- INTERFACE ---
163
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
164
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
165
- gr.Markdown("Identify circuit types (QFT, HEA, RANDOM, etc.) using high-level structural complexity metrics.")
166
 
167
  with gr.Tabs():
168
- with gr.TabItem("🔎 Dataset Explorer"):
169
  meta_txt = gr.Markdown("### Loading...")
170
  with gr.Row():
171
- ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset Type")
172
- sp_sel = gr.Dropdown(["train"], value="train", label="Subset (Split)")
173
  data_view = gr.Dataframe(interactive=False)
174
  with gr.Row():
175
- c_raw = gr.Code(label="Original QASM (Logic)", language="python")
176
- c_tr = gr.Code(label="Transpiled QASM (Hardware-ready)", language="python")
177
 
178
- with gr.TabItem("🧠 Classification Model"):
179
- gr.Markdown("Predict the **Circuit Family** by analyzing topology signatures.")
180
  with gr.Row():
181
  with gr.Column(scale=1):
182
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
183
- ml_feat_sel = gr.CheckboxGroup(label="Structural Features", choices=[])
184
- train_btn = gr.Button("Run Classification", variant="primary")
185
  with gr.Column(scale=2):
186
  p_out = gr.Plot()
187
  t_out = gr.Markdown()
188
 
189
- with gr.TabItem("📖 User Guide"):
190
- meth_md = gr.Markdown(value=load_guide_content())
191
 
192
- gr.Markdown(f"""
193
- ---
194
- ### 🔗 Project Resources
195
- [**🌐 Website**](https://qsbench.github.io) | [**🤗 Hugging Face**](https://huggingface.co/QSBench) | [**💻 GitHub**](https://github.com/QSBench)
196
- """)
197
 
198
- # --- EVENTS ---
199
- # Explorer events
200
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
201
  sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
202
 
203
- # ML events
204
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
205
  train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
206
 
207
- # Initial Load
208
  demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
209
  demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
210
 
 
39
  }
40
  }
41
 
42
+ TARGET_FAMILIES = ['QFT', 'HEA', 'RANDOM', 'EFFICIENT', 'REAL_AMPLITUDES']
43
  NON_FEATURE_COLS = {
44
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
45
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
 
61
  # --- UI LOGIC ---
62
 
63
  def load_guide_content():
 
64
  try:
65
  with open("GUIDE.md", "r", encoding="utf-8") as f:
66
  return f.read()
67
+ except:
68
+ return "### ⚠️ GUIDE.md not found."
69
 
70
  def sync_ml_metrics(ds_name: str):
 
71
  assets = load_all_assets(ds_name)
72
  df = assets["df"]
73
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
74
+ valid_features = [c for c in numeric_cols if c not in NON_FEATURE_COLS and not any(p in c for p in ["ideal_", "noisy_", "error_"])]
 
 
 
 
 
 
 
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
+ return gr.update(choices=valid_features, value=defaults)
77
 
78
  def train_classifier(ds_name: str, features: List[str]):
79
+ if not features: return None, "### ❌ Select features first."
 
80
  assets = load_all_assets(ds_name)
81
  df = assets["df"]
82
 
83
+ # Filter for the 5 target families only
84
+ train_df = df[df['circuit_type_requested'].isin(TARGET_FAMILIES)].dropna(subset=features)
85
+ X, y = train_df[features], train_df['circuit_type_requested']
 
 
 
 
 
86
 
 
87
  le = LabelEncoder()
88
  y_encoded = le.fit_transform(y)
 
 
89
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
90
 
91
+ clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
 
 
92
  preds = clf.predict(X_test)
93
 
94
+ sns.set_theme(style="whitegrid")
 
 
 
 
95
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
96
 
 
97
  cm = confusion_matrix(y_test, preds)
98
+ sns.heatmap(cm, annot=True, fmt='d', cmap='magma', xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0], cbar=False)
99
+ axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
 
 
 
100
 
 
101
  importances = clf.feature_importances_
102
+ idx = np.argsort(importances)[-10:]
103
+ axes[1].barh([features[i] for i in idx], importances[idx], color='#3498db')
104
+ axes[1].set_title("Feature Importance")
105
 
106
  plt.tight_layout()
107
+ report = classification_report(y_test, preds, target_names=le.classes_)
108
+ return fig, f"### 🏆 Results\n```\n{report}\n```"
 
 
 
109
 
110
  def update_explorer(ds_name: str, split_name: str):
 
111
  assets = load_all_assets(ds_name)
112
  df = assets["df"]
 
113
 
114
+ # Identify splits
115
+ splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
116
+
117
+ # Ensure current split_name exists in this dataset
118
+ if split_name not in splits:
119
+ split_name = splits[0]
120
+
121
+ filtered = df[df["split"] == split_name] if "split" in df.columns else df
122
+ display_df = filtered.head(10)
123
+
124
  raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
125
  tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
126
 
127
+ return (
128
+ gr.update(choices=splits, value=split_name),
129
+ display_df,
130
+ raw,
131
+ tr,
132
+ f"### 📋 {ds_name} Explorer"
133
+ )
134
 
135
  # --- INTERFACE ---
136
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
137
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
 
138
 
139
  with gr.Tabs():
140
+ with gr.TabItem("🔎 Explorer"):
141
  meta_txt = gr.Markdown("### Loading...")
142
  with gr.Row():
143
+ ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
144
+ sp_sel = gr.Dropdown(["train"], value="train", label="Split")
145
  data_view = gr.Dataframe(interactive=False)
146
  with gr.Row():
147
+ c_raw = gr.Code(label="Logic QASM", language="python")
148
+ c_tr = gr.Code(label="Transpiled QASM", language="python")
149
 
150
+ with gr.TabItem("🧠 Classification"):
 
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
154
+ ml_feat_sel = gr.CheckboxGroup(label="Features", choices=[])
155
+ train_btn = gr.Button("Run Analysis", variant="primary")
156
  with gr.Column(scale=2):
157
  p_out = gr.Plot()
158
  t_out = gr.Markdown()
159
 
160
+ with gr.TabItem("📖 Guide"):
161
+ gr.Markdown(load_guide_content())
162
 
163
+ gr.Markdown("--- \n ### 🔗 [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
 
 
 
 
164
 
165
+ # --- UPDATED EVENT LOGIC ---
166
+ # Triggering the same function for both dropdowns
167
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
168
  sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
169
 
 
170
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
171
  train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
172
 
 
173
  demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
174
  demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
175