QSBench committed on
Commit
980efa8
·
verified ·
1 Parent(s): 9c99c87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -48
app.py CHANGED
@@ -12,10 +12,11 @@ from sklearn.metrics import accuracy_score, confusion_matrix, classification_rep
12
  from sklearn.model_selection import train_test_split
13
  from sklearn.preprocessing import LabelEncoder
14
 
15
- # --- CONFIG & LOGGING ---
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
19
  REPO_CONFIG = {
20
  "Core (Clean)": {
21
  "repo": "QSBench/QSBench-Core-v1.0.0-demo",
@@ -39,6 +40,7 @@ REPO_CONFIG = {
39
  }
40
  }
41
 
 
42
  NON_FEATURE_COLS = {
43
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
44
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
@@ -49,25 +51,31 @@ NON_FEATURE_COLS = {
49
  _ASSET_CACHE = {}
50
 
51
  def load_all_assets(key: str) -> Dict:
52
- """Fetch dataset and metadata from Hugging Face."""
 
 
53
  if key not in _ASSET_CACHE:
54
- logger.info(f"Fetching {key}...")
55
  ds = load_dataset(REPO_CONFIG[key]["repo"])
56
  meta = requests.get(REPO_CONFIG[key]["meta_url"]).json()
57
  report = requests.get(REPO_CONFIG[key]["report_url"]).json()
58
  _ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
59
  return _ASSET_CACHE[key]
60
 
61
- def load_guide_content():
62
- """Load content for the methodology tab."""
 
 
63
  try:
64
  with open("GUIDE.md", "r", encoding="utf-8") as f:
65
  return f.read()
66
- except:
67
- return "### ⚠️ GUIDE.md not found. Please upload it to the root directory."
68
 
69
- def sync_ml_metrics(ds_name: str):
70
- """Identify numerical features for classification."""
 
 
71
  assets = load_all_assets(ds_name)
72
  df = assets["df"]
73
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
@@ -75,18 +83,20 @@ def sync_ml_metrics(ds_name: str):
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
  return gr.update(choices=valid_features, value=defaults)
77
 
78
- def train_classifier(ds_name: str, features: List[str]):
79
- """Train a classifier to detect circuit families."""
 
 
80
  if not features:
81
- return None, "### ❌ Error: Select features first."
82
 
83
  assets = load_all_assets(ds_name)
84
  df = assets["df"]
85
 
86
- # Logic: use 'resolved' if 'requested' contains 'mixed' tags
87
  target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
88
 
89
- # Filter 'mixed' out if other classes exist
90
  train_df = df.dropna(subset=features + [target_col])
91
  if 'mixed' in train_df[target_col].unique() and len(train_df[target_col].unique()) > 1:
92
  train_df = train_df[train_df[target_col] != 'mixed']
@@ -95,38 +105,49 @@ def train_classifier(ds_name: str, features: List[str]):
95
  y = train_df[target_col]
96
 
97
  if len(y.unique()) < 2:
98
- return None, f"### ❌ Error: At least 2 classes needed. Found only: {y.unique()}"
99
 
 
100
  le = LabelEncoder()
101
  y_encoded = le.fit_transform(y)
102
 
103
  try:
104
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded)
105
- except:
106
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
107
 
108
- clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
 
 
109
  preds = clf.predict(X_test)
110
 
111
- # Plotting
112
  sns.set_theme(style="whitegrid")
113
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
114
 
 
115
  cm = confusion_matrix(y_test, preds)
116
  sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0], cbar=False)
117
- axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
118
 
 
119
  importances = clf.feature_importances_
120
  idx = np.argsort(importances)[-10:]
121
  axes[1].barh([features[i] for i in idx], importances[idx], color='#2ecc71')
122
- axes[1].set_title("Top-10 Discriminative Features")
123
 
124
  plt.tight_layout()
125
- report = classification_report(y_test, preds, target_names=le.classes_)
126
- return fig, f"### πŸ† Results\n**Target Column:** `{target_col}`\n```\n{report}\n```"
 
 
 
 
127
 
128
- def update_explorer(ds_name: str, split_name: str):
129
- """Manage the Explorer tab data view."""
 
 
130
  assets = load_all_assets(ds_name)
131
  df = assets["df"]
132
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
@@ -137,56 +158,56 @@ def update_explorer(ds_name: str, split_name: str):
137
  filtered = df[df["split"] == split_name] if "split" in df.columns else df
138
  display_df = filtered.head(10)
139
 
140
- raw = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
141
- tr = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
142
 
143
  return (
144
  gr.update(choices=splits, value=split_name),
145
  display_df,
146
- raw,
147
- tr,
148
  f"### πŸ“‹ {ds_name} Explorer"
149
  )
150
 
151
- # --- GRADIO INTERFACE ---
152
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
153
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
154
 
155
  with gr.Tabs():
156
  with gr.TabItem("πŸ”Ž Explorer"):
157
- meta_txt = gr.Markdown("### Initializing...")
158
  with gr.Row():
159
- ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
160
- sp_sel = gr.Dropdown(["train"], value="train", label="Split")
161
- data_view = gr.Dataframe(interactive=False)
162
  with gr.Row():
163
- c_raw = gr.Code(label="Source QASM", language="python")
164
- c_tr = gr.Code(label="Transpiled QASM", language="python")
165
 
166
  with gr.TabItem("🧠 Classification"):
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
- ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
170
- ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics", choices=[])
171
- train_btn = gr.Button("Train Classifier", variant="primary")
172
  with gr.Column(scale=2):
173
- p_out = gr.Plot()
174
- t_out = gr.Markdown()
175
 
176
  with gr.TabItem("πŸ“– Guide"):
177
  gr.Markdown(load_guide_content())
178
 
179
  gr.Markdown("--- \n ### πŸ”— [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
180
 
181
- # Event Mapping
182
- ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
183
- sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
184
- ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
185
- train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
186
 
187
- # Startup Load
188
- demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
189
- demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
190
 
191
  if __name__ == "__main__":
192
  demo.launch()
 
12
  from sklearn.model_selection import train_test_split
13
  from sklearn.preprocessing import LabelEncoder
14
 
15
+ # Logging configuration
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Dataset repository configuration
20
  REPO_CONFIG = {
21
  "Core (Clean)": {
22
  "repo": "QSBench/QSBench-Core-v1.0.0-demo",
 
40
  }
41
  }
42
 
43
+ # Define non-feature columns to exclude from training
44
  NON_FEATURE_COLS = {
45
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
46
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
 
51
  _ASSET_CACHE = {}
52
 
53
  def load_all_assets(key: str) -> Dict:
54
+ """
55
+ Fetch and cache dataset and metadata from Hugging Face.
56
+ """
57
  if key not in _ASSET_CACHE:
58
+ logger.info(f"Fetching {key} assets...")
59
  ds = load_dataset(REPO_CONFIG[key]["repo"])
60
  meta = requests.get(REPO_CONFIG[key]["meta_url"]).json()
61
  report = requests.get(REPO_CONFIG[key]["report_url"]).json()
62
  _ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
63
  return _ASSET_CACHE[key]
64
 
65
+ def load_guide_content() -> str:
66
+ """
67
+ Load Markdown content for the Methodology/Guide tab.
68
+ """
69
  try:
70
  with open("GUIDE.md", "r", encoding="utf-8") as f:
71
  return f.read()
72
+ except FileNotFoundError:
73
+ return "### ⚠️ GUIDE.md not found."
74
 
75
+ def sync_ml_metrics(ds_name: str) -> gr.update:
76
+ """
77
+ Filter and return available numerical features for the selected dataset.
78
+ """
79
  assets = load_all_assets(ds_name)
80
  df = assets["df"]
81
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
 
83
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
84
  return gr.update(choices=valid_features, value=defaults)
85
 
86
+ def train_classifier(ds_name: str, features: List[str]) -> Tuple[Optional[plt.Figure], str]:
87
+ """
88
+ Perform multi-class classification on circuit families and return metrics/plots.
89
+ """
90
  if not features:
91
+ return None, "### ❌ Error: No features selected."
92
 
93
  assets = load_all_assets(ds_name)
94
  df = assets["df"]
95
 
96
+ # Target column selection fallback logic
97
  target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
98
 
99
+ # Data preprocessing and cleaning
100
  train_df = df.dropna(subset=features + [target_col])
101
  if 'mixed' in train_df[target_col].unique() and len(train_df[target_col].unique()) > 1:
102
  train_df = train_df[train_df[target_col] != 'mixed']
 
105
  y = train_df[target_col]
106
 
107
  if len(y.unique()) < 2:
108
+ return None, f"### ❌ Error: Dataset contains insufficient classes for training ({y.unique()})."
109
 
110
+ # Label encoding and dataset splitting
111
  le = LabelEncoder()
112
  y_encoded = le.fit_transform(y)
113
 
114
  try:
115
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded)
116
+ except (ValueError, TypeError):
117
  X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
118
 
119
+ # Model initialization and training
120
+ clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
121
+ clf.fit(X_train, y_train)
122
  preds = clf.predict(X_test)
123
 
124
+ # Visualization generation
125
  sns.set_theme(style="whitegrid")
126
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
127
 
128
+ # Confusion Matrix Plot
129
  cm = confusion_matrix(y_test, preds)
130
  sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0], cbar=False)
131
+ axes[0].set_title(f"Confusion Matrix (Accuracy: {accuracy_score(y_test, preds):.2%})")
132
 
133
+ # Feature Importance Plot
134
  importances = clf.feature_importances_
135
  idx = np.argsort(importances)[-10:]
136
  axes[1].barh([features[i] for i in idx], importances[idx], color='#2ecc71')
137
+ axes[1].set_title("Top-10 Predictive Features")
138
 
139
  plt.tight_layout()
140
+
141
+ # Performance metrics string generation
142
+ cls_report = classification_report(y_test, preds, target_names=le.classes_, output_dict=False)
143
+ results_md = f"### πŸ† Classification Results\n**Target:** `{target_col}`\n**Accuracy:** {accuracy_score(y_test, preds):.2%}\n\n**Metrics:**\n```text\n{cls_report}\n```"
144
+
145
+ return fig, results_md
146
 
147
+ def update_explorer(ds_name: str, split_name: str) -> Tuple[gr.update, pd.DataFrame, str, str, str]:
148
+ """
149
+ Refresh the Explorer view based on dataset and split selection.
150
+ """
151
  assets = load_all_assets(ds_name)
152
  df = assets["df"]
153
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
 
158
  filtered = df[df["split"] == split_name] if "split" in df.columns else df
159
  display_df = filtered.head(10)
160
 
161
+ raw_qasm = display_df["qasm_raw"].iloc[0] if "qasm_raw" in display_df.columns and not display_df.empty else "// N/A"
162
+ transpiled_qasm = display_df["qasm_transpiled"].iloc[0] if "qasm_transpiled" in display_df.columns and not display_df.empty else "// N/A"
163
 
164
  return (
165
  gr.update(choices=splits, value=split_name),
166
  display_df,
167
+ raw_qasm,
168
+ transpiled_qasm,
169
  f"### πŸ“‹ {ds_name} Explorer"
170
  )
171
 
172
+ # Gradio interface definition
173
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
174
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
175
 
176
  with gr.Tabs():
177
  with gr.TabItem("πŸ”Ž Explorer"):
178
+ meta_label = gr.Markdown("### Initializing...")
179
  with gr.Row():
180
+ ds_dropdown = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset Type")
181
+ split_dropdown = gr.Dropdown(["train"], value="train", label="Split")
182
+ explorer_df = gr.Dataframe(interactive=False)
183
  with gr.Row():
184
+ raw_qasm_code = gr.Code(label="Logical QASM", language="python")
185
+ tr_qasm_code = gr.Code(label="Transpiled QASM", language="python")
186
 
187
  with gr.TabItem("🧠 Classification"):
188
  with gr.Row():
189
  with gr.Column(scale=1):
190
+ ml_ds_dropdown = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Noise Environment")
191
+ ml_feature_checks = gr.CheckboxGroup(label="Input Metrics", choices=[])
192
+ run_btn = gr.Button("Train & Evaluate", variant="primary")
193
  with gr.Column(scale=2):
194
+ plot_output = gr.Plot()
195
+ results_output = gr.Markdown()
196
 
197
  with gr.TabItem("πŸ“– Guide"):
198
  gr.Markdown(load_guide_content())
199
 
200
  gr.Markdown("--- \n ### πŸ”— [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
201
 
202
+ # UI Event bindings
203
+ ds_dropdown.change(update_explorer, [ds_dropdown, split_dropdown], [split_dropdown, explorer_df, raw_qasm_code, tr_qasm_code, meta_label])
204
+ split_dropdown.change(update_explorer, [ds_dropdown, split_dropdown], [split_dropdown, explorer_df, raw_qasm_code, tr_qasm_code, meta_label])
205
+ ml_ds_dropdown.change(sync_ml_metrics, [ml_ds_dropdown], [ml_feature_checks])
206
+ run_btn.click(train_classifier, [ml_ds_dropdown, ml_feature_checks], [plot_output, results_output])
207
 
208
+ # Application startup triggers
209
+ demo.load(update_explorer, [ds_dropdown, split_dropdown], [split_dropdown, explorer_df, raw_qasm_code, tr_qasm_code, meta_label])
210
+ demo.load(sync_ml_metrics, [ml_ds_dropdown], [ml_feature_checks])
211
 
212
  if __name__ == "__main__":
213
  demo.launch()