QSBench committed on
Commit
9c99c87
·
verified ·
1 Parent(s): a63cf6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -70
app.py CHANGED
@@ -39,7 +39,6 @@ REPO_CONFIG = {
39
  }
40
  }
41
 
42
- TARGET_FAMILIES = ['QFT', 'HEA', 'RANDOM', 'EFFICIENT', 'REAL_AMPLITUDES']
43
  NON_FEATURE_COLS = {
44
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
45
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
@@ -50,6 +49,7 @@ NON_FEATURE_COLS = {
50
  _ASSET_CACHE = {}
51
 
52
  def load_all_assets(key: str) -> Dict:
 
53
  if key not in _ASSET_CACHE:
54
  logger.info(f"Fetching {key}...")
55
  ds = load_dataset(REPO_CONFIG[key]["repo"])
@@ -58,16 +58,16 @@ def load_all_assets(key: str) -> Dict:
58
  _ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
59
  return _ASSET_CACHE[key]
60
 
61
- # --- UI LOGIC ---
62
-
63
  def load_guide_content():
 
64
  try:
65
  with open("GUIDE.md", "r", encoding="utf-8") as f:
66
  return f.read()
67
  except:
68
- return "### ⚠️ GUIDE.md not found."
69
 
70
  def sync_ml_metrics(ds_name: str):
 
71
  assets = load_all_assets(ds_name)
72
  df = assets["df"]
73
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
@@ -75,99 +75,62 @@ def sync_ml_metrics(ds_name: str):
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
  return gr.update(choices=valid_features, value=defaults)
77
 
78
- Судя по ошибке Found only: ['mixed'], в вашем столбце circuit_type_requested вместо конкретных названий семейств (QFT, HEA и т.д.) записано значение 'mixed'. Это часто случается в демонстрационных подмножествах, где данные уже перемешаны и помечены общим тегом.
79
-
80
- Для классификации нам нужны исходные метки. В датасетах QSBench они обычно находятся в столбце circuit_type_resolved.
81
-
82
- Вот обновленный код функции train_classifier с исправленной логикой выбора столбца и более надежной обработкой ошибок.
83
- Исправленный код (App Code)
84
- Python
85
-
86
  def train_classifier(ds_name: str, features: List[str]):
 
87
  if not features:
88
- return None, "### ❌ Error: No features selected. Please pick structural metrics."
89
 
90
  assets = load_all_assets(ds_name)
91
  df = assets["df"]
92
 
93
- # Try 'resolved' column first as 'requested' might contain 'mixed' in demo shards
94
  target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
95
 
96
- # Clean data: remove NaNs and ensure we have valid target strings
97
  train_df = df.dropna(subset=features + [target_col])
98
-
99
- # Filter out rows where the target might be 'mixed' or generic if others are available
100
- unique_types = train_df[target_col].unique()
101
- if 'mixed' in unique_types and len(unique_types) > 1:
102
  train_df = train_df[train_df[target_col] != 'mixed']
103
-
104
  X = train_df[features]
105
  y = train_df[target_col]
106
 
107
- # Verification: Do we have at least 2 distinct classes to perform classification?
108
- current_classes = y.unique()
109
- if len(current_classes) < 2:
110
- return None, f"### ❌ Classification Error\nFound only one class: `{current_classes}` in column `{target_col}`. " \
111
- "Try a different dataset or check if the source file has labels."
112
 
113
- # Encode labels to integers
114
  le = LabelEncoder()
115
  y_encoded = le.fit_transform(y)
116
- class_names = le.classes_
117
-
118
- # Split dataset
119
  try:
120
- X_train, X_test, y_train, y_test = train_test_split(
121
- X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
122
- )
123
- except ValueError:
124
- # Fallback if stratify fails due to very small class sizes
125
- X_train, X_test, y_train, y_test = train_test_split(
126
- X, y_encoded, test_size=0.2, random_state=42
127
- )
128
-
129
- # Train Random Forest Classifier
130
- clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
131
- clf.fit(X_train, y_train)
132
  preds = clf.predict(X_test)
133
 
134
- # Visuals
135
  sns.set_theme(style="whitegrid")
136
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
137
 
138
- # Plot 1: Confusion Matrix
139
  cm = confusion_matrix(y_test, preds)
140
- sns.heatmap(cm, annot=True, fmt='d', cmap='viridis',
141
- xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
142
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
143
- axes[0].set_xlabel("Predicted Label")
144
- axes[0].set_ylabel("True Label")
145
 
146
- # Plot 2: Feature Importance
147
  importances = clf.feature_importances_
148
- indices = np.argsort(importances)[-10:]
149
- axes[1].barh([features[i] for i in indices], importances[indices], color='#2ecc71')
150
  axes[1].set_title("Top-10 Discriminative Features")
151
 
152
  plt.tight_layout()
153
-
154
- # Generate text report
155
- report_dict = classification_report(y_test, preds, target_names=class_names)
156
- summary = f"### 🏆 Classifier Results: {ds_name}\n" \
157
- f"**Target Column used:** `{target_col}`\n" \
158
- f"**Accuracy:** {accuracy_score(y_test, preds):.2%}\n\n" \
159
- f"**Report:**\n```\n{report_dict}\n```"
160
-
161
- return fig, summary
162
 
163
  def update_explorer(ds_name: str, split_name: str):
 
164
  assets = load_all_assets(ds_name)
165
  df = assets["df"]
166
-
167
- # Identify splits
168
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
169
 
170
- # Ensure current split_name exists in this dataset
171
  if split_name not in splits:
172
  split_name = splits[0]
173
 
@@ -185,27 +148,27 @@ def update_explorer(ds_name: str, split_name: str):
185
  f"### 📋 {ds_name} Explorer"
186
  )
187
 
188
- # --- INTERFACE ---
189
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
190
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
191
 
192
  with gr.Tabs():
193
  with gr.TabItem("🔎 Explorer"):
194
- meta_txt = gr.Markdown("### Loading...")
195
  with gr.Row():
196
  ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
197
  sp_sel = gr.Dropdown(["train"], value="train", label="Split")
198
  data_view = gr.Dataframe(interactive=False)
199
  with gr.Row():
200
- c_raw = gr.Code(label="Logic QASM", language="python")
201
  c_tr = gr.Code(label="Transpiled QASM", language="python")
202
 
203
  with gr.TabItem("🧠 Classification"):
204
  with gr.Row():
205
  with gr.Column(scale=1):
206
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
207
- ml_feat_sel = gr.CheckboxGroup(label="Features", choices=[])
208
- train_btn = gr.Button("Run Analysis", variant="primary")
209
  with gr.Column(scale=2):
210
  p_out = gr.Plot()
211
  t_out = gr.Markdown()
@@ -215,14 +178,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
215
 
216
  gr.Markdown("--- \n ### 🔗 [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
217
 
218
- # --- UPDATED EVENT LOGIC ---
219
- # Triggering the same function for both dropdowns
220
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
221
  sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
222
-
223
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
224
  train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
225
 
 
226
  demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
227
  demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
228
 
 
39
  }
40
  }
41
 
 
42
  NON_FEATURE_COLS = {
43
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
44
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
 
49
  _ASSET_CACHE = {}
50
 
51
  def load_all_assets(key: str) -> Dict:
52
+ """Fetch dataset and metadata from Hugging Face."""
53
  if key not in _ASSET_CACHE:
54
  logger.info(f"Fetching {key}...")
55
  ds = load_dataset(REPO_CONFIG[key]["repo"])
 
58
  _ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
59
  return _ASSET_CACHE[key]
60
 
 
 
61
  def load_guide_content():
62
+ """Load content for the methodology tab."""
63
  try:
64
  with open("GUIDE.md", "r", encoding="utf-8") as f:
65
  return f.read()
66
  except:
67
+ return "### ⚠️ GUIDE.md not found. Please upload it to the root directory."
68
 
69
  def sync_ml_metrics(ds_name: str):
70
+ """Identify numerical features for classification."""
71
  assets = load_all_assets(ds_name)
72
  df = assets["df"]
73
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
 
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
  return gr.update(choices=valid_features, value=defaults)
77
 
 
 
 
 
 
 
 
 
78
  def train_classifier(ds_name: str, features: List[str]):
79
+ """Train a classifier to detect circuit families."""
80
  if not features:
81
+ return None, "### ❌ Error: Select features first."
82
 
83
  assets = load_all_assets(ds_name)
84
  df = assets["df"]
85
 
86
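+ # Prefer 'circuit_type_resolved' whenever that column exists ('circuit_type_requested' may hold only 'mixed' tags); otherwise fall back to 'circuit_type_requested'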
87
  target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
88
 
89
+ # Filter 'mixed' out if other classes exist
90
  train_df = df.dropna(subset=features + [target_col])
91
+ if 'mixed' in train_df[target_col].unique() and len(train_df[target_col].unique()) > 1:
 
 
 
92
  train_df = train_df[train_df[target_col] != 'mixed']
93
+
94
  X = train_df[features]
95
  y = train_df[target_col]
96
 
97
+ if len(y.unique()) < 2:
98
+ return None, f"### ❌ Error: At least 2 classes needed. Found only: {y.unique()}"
 
 
 
99
 
 
100
  le = LabelEncoder()
101
  y_encoded = le.fit_transform(y)
102
+
 
 
103
  try:
104
+ X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded)
105
+ except:
106
+ X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
107
+
108
+ clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
 
 
 
 
 
 
 
109
  preds = clf.predict(X_test)
110
 
111
+ # Plotting
112
  sns.set_theme(style="whitegrid")
113
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
114
 
 
115
  cm = confusion_matrix(y_test, preds)
116
+ sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0], cbar=False)
 
117
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
 
 
118
 
 
119
  importances = clf.feature_importances_
120
+ idx = np.argsort(importances)[-10:]
121
+ axes[1].barh([features[i] for i in idx], importances[idx], color='#2ecc71')
122
  axes[1].set_title("Top-10 Discriminative Features")
123
 
124
  plt.tight_layout()
125
+ report = classification_report(y_test, preds, target_names=le.classes_)
126
+ return fig, f"### 🏆 Results\n**Target Column:** `{target_col}`\n```\n{report}\n```"
 
 
 
 
 
 
 
127
 
128
  def update_explorer(ds_name: str, split_name: str):
129
+ """Manage the Explorer tab data view."""
130
  assets = load_all_assets(ds_name)
131
  df = assets["df"]
 
 
132
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
133
 
 
134
  if split_name not in splits:
135
  split_name = splits[0]
136
 
 
148
  f"### 📋 {ds_name} Explorer"
149
  )
150
 
151
+ # --- GRADIO INTERFACE ---
152
  with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
153
  gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
154
 
155
  with gr.Tabs():
156
  with gr.TabItem("🔎 Explorer"):
157
+ meta_txt = gr.Markdown("### Initializing...")
158
  with gr.Row():
159
  ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
160
  sp_sel = gr.Dropdown(["train"], value="train", label="Split")
161
  data_view = gr.Dataframe(interactive=False)
162
  with gr.Row():
163
+ c_raw = gr.Code(label="Source QASM", language="python")
164
  c_tr = gr.Code(label="Transpiled QASM", language="python")
165
 
166
  with gr.TabItem("🧠 Classification"):
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
  ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
170
+ ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics", choices=[])
171
+ train_btn = gr.Button("Train Classifier", variant="primary")
172
  with gr.Column(scale=2):
173
  p_out = gr.Plot()
174
  t_out = gr.Markdown()
 
178
 
179
  gr.Markdown("--- \n ### 🔗 [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
180
 
181
+ # Event Mapping
 
182
  ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
183
  sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
 
184
  ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
185
  train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
186
 
187
+ # Startup Load
188
  demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
189
  demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
190