Update app.py
Browse files
app.py
CHANGED
|
@@ -39,7 +39,6 @@ REPO_CONFIG = {
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
|
| 42 |
-
TARGET_FAMILIES = ['QFT', 'HEA', 'RANDOM', 'EFFICIENT', 'REAL_AMPLITUDES']
|
| 43 |
NON_FEATURE_COLS = {
|
| 44 |
"sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
|
| 45 |
"qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
|
|
@@ -50,6 +49,7 @@ NON_FEATURE_COLS = {
|
|
| 50 |
_ASSET_CACHE = {}
|
| 51 |
|
| 52 |
def load_all_assets(key: str) -> Dict:
|
|
|
|
| 53 |
if key not in _ASSET_CACHE:
|
| 54 |
logger.info(f"Fetching {key}...")
|
| 55 |
ds = load_dataset(REPO_CONFIG[key]["repo"])
|
|
@@ -58,16 +58,16 @@ def load_all_assets(key: str) -> Dict:
|
|
| 58 |
_ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
|
| 59 |
return _ASSET_CACHE[key]
|
| 60 |
|
| 61 |
-
# --- UI LOGIC ---
|
| 62 |
-
|
| 63 |
def load_guide_content():
|
|
|
|
| 64 |
try:
|
| 65 |
with open("GUIDE.md", "r", encoding="utf-8") as f:
|
| 66 |
return f.read()
|
| 67 |
except:
|
| 68 |
-
return "### ⚠️ GUIDE.md not found."
|
| 69 |
|
| 70 |
def sync_ml_metrics(ds_name: str):
|
|
|
|
| 71 |
assets = load_all_assets(ds_name)
|
| 72 |
df = assets["df"]
|
| 73 |
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
|
@@ -75,99 +75,62 @@ def sync_ml_metrics(ds_name: str):
|
|
| 75 |
defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
|
| 76 |
return gr.update(choices=valid_features, value=defaults)
|
| 77 |
|
| 78 |
-
Судя по ошибке Found only: ['mixed'], в вашем столбце circuit_type_requested вместо конкретных названий семейств (QFT, HEA и т.д.) записано значение 'mixed'. Это часто случается в демонстрационных подмножествах, где данные уже перемешаны и помечены общим тегом.
|
| 79 |
-
|
| 80 |
-
Для классификации нам нужны исходные метки. В датасетах QSBench они обычно находятся в столбце circuit_type_resolved.
|
| 81 |
-
|
| 82 |
-
Вот обновленный код функции train_classifier с исправленной логикой выбора столбца и более надежной обработкой ошибок.
|
| 83 |
-
Исправленный код (App Code)
|
| 84 |
-
Python
|
| 85 |
-
|
| 86 |
def train_classifier(ds_name: str, features: List[str]):
|
|
|
|
| 87 |
if not features:
|
| 88 |
-
return None, "### ❌ Error:
|
| 89 |
|
| 90 |
assets = load_all_assets(ds_name)
|
| 91 |
df = assets["df"]
|
| 92 |
|
| 93 |
-
#
|
| 94 |
target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
|
| 95 |
|
| 96 |
-
#
|
| 97 |
train_df = df.dropna(subset=features + [target_col])
|
| 98 |
-
|
| 99 |
-
# Filter out rows where the target might be 'mixed' or generic if others are available
|
| 100 |
-
unique_types = train_df[target_col].unique()
|
| 101 |
-
if 'mixed' in unique_types and len(unique_types) > 1:
|
| 102 |
train_df = train_df[train_df[target_col] != 'mixed']
|
| 103 |
-
|
| 104 |
X = train_df[features]
|
| 105 |
y = train_df[target_col]
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
if len(current_classes) < 2:
|
| 110 |
-
return None, f"### ❌ Classification Error\nFound only one class: `{current_classes}` in column `{target_col}`. " \
|
| 111 |
-
"Try a different dataset or check if the source file has labels."
|
| 112 |
|
| 113 |
-
# Encode labels to integers
|
| 114 |
le = LabelEncoder()
|
| 115 |
y_encoded = le.fit_transform(y)
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# Split dataset
|
| 119 |
try:
|
| 120 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
| 121 |
-
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
| 126 |
-
X, y_encoded, test_size=0.2, random_state=42
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
# Train Random Forest Classifier
|
| 130 |
-
clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
|
| 131 |
-
clf.fit(X_train, y_train)
|
| 132 |
preds = clf.predict(X_test)
|
| 133 |
|
| 134 |
-
#
|
| 135 |
sns.set_theme(style="whitegrid")
|
| 136 |
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
|
| 137 |
|
| 138 |
-
# Plot 1: Confusion Matrix
|
| 139 |
cm = confusion_matrix(y_test, preds)
|
| 140 |
-
sns.heatmap(cm, annot=True, fmt='d', cmap='viridis',
|
| 141 |
-
xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
|
| 142 |
axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
|
| 143 |
-
axes[0].set_xlabel("Predicted Label")
|
| 144 |
-
axes[0].set_ylabel("True Label")
|
| 145 |
|
| 146 |
-
# Plot 2: Feature Importance
|
| 147 |
importances = clf.feature_importances_
|
| 148 |
-
|
| 149 |
-
axes[1].barh([features[i] for i in
|
| 150 |
axes[1].set_title("Top-10 Discriminative Features")
|
| 151 |
|
| 152 |
plt.tight_layout()
|
| 153 |
-
|
| 154 |
-
#
|
| 155 |
-
report_dict = classification_report(y_test, preds, target_names=class_names)
|
| 156 |
-
summary = f"### 🏆 Classifier Results: {ds_name}\n" \
|
| 157 |
-
f"**Target Column used:** `{target_col}`\n" \
|
| 158 |
-
f"**Accuracy:** {accuracy_score(y_test, preds):.2%}\n\n" \
|
| 159 |
-
f"**Report:**\n```\n{report_dict}\n```"
|
| 160 |
-
|
| 161 |
-
return fig, summary
|
| 162 |
|
| 163 |
def update_explorer(ds_name: str, split_name: str):
|
|
|
|
| 164 |
assets = load_all_assets(ds_name)
|
| 165 |
df = assets["df"]
|
| 166 |
-
|
| 167 |
-
# Identify splits
|
| 168 |
splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
|
| 169 |
|
| 170 |
-
# Ensure current split_name exists in this dataset
|
| 171 |
if split_name not in splits:
|
| 172 |
split_name = splits[0]
|
| 173 |
|
|
@@ -185,27 +148,27 @@ def update_explorer(ds_name: str, split_name: str):
|
|
| 185 |
f"### 📋 {ds_name} Explorer"
|
| 186 |
)
|
| 187 |
|
| 188 |
-
# --- INTERFACE ---
|
| 189 |
with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
|
| 190 |
gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
|
| 191 |
|
| 192 |
with gr.Tabs():
|
| 193 |
with gr.TabItem("🔎 Explorer"):
|
| 194 |
-
meta_txt = gr.Markdown("###
|
| 195 |
with gr.Row():
|
| 196 |
ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
|
| 197 |
sp_sel = gr.Dropdown(["train"], value="train", label="Split")
|
| 198 |
data_view = gr.Dataframe(interactive=False)
|
| 199 |
with gr.Row():
|
| 200 |
-
c_raw = gr.Code(label="
|
| 201 |
c_tr = gr.Code(label="Transpiled QASM", language="python")
|
| 202 |
|
| 203 |
with gr.TabItem("🧠 Classification"):
|
| 204 |
with gr.Row():
|
| 205 |
with gr.Column(scale=1):
|
| 206 |
ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
|
| 207 |
-
ml_feat_sel = gr.CheckboxGroup(label="
|
| 208 |
-
train_btn = gr.Button("
|
| 209 |
with gr.Column(scale=2):
|
| 210 |
p_out = gr.Plot()
|
| 211 |
t_out = gr.Markdown()
|
|
@@ -215,14 +178,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
|
|
| 215 |
|
| 216 |
gr.Markdown("--- \n ### 🔗 [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
|
| 217 |
|
| 218 |
-
#
|
| 219 |
-
# Triggering the same function for both dropdowns
|
| 220 |
ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 221 |
sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 222 |
-
|
| 223 |
ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 224 |
train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
|
| 225 |
|
|
|
|
| 226 |
demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 227 |
demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 228 |
|
|
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
|
|
|
|
| 42 |
NON_FEATURE_COLS = {
|
| 43 |
"sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
|
| 44 |
"qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
|
|
|
|
| 49 |
_ASSET_CACHE = {}
|
| 50 |
|
| 51 |
def load_all_assets(key: str) -> Dict:
|
| 52 |
+
"""Fetch dataset and metadata from Hugging Face."""
|
| 53 |
if key not in _ASSET_CACHE:
|
| 54 |
logger.info(f"Fetching {key}...")
|
| 55 |
ds = load_dataset(REPO_CONFIG[key]["repo"])
|
|
|
|
| 58 |
_ASSET_CACHE[key] = {"df": pd.DataFrame(ds["train"]), "meta": meta, "report": report}
|
| 59 |
return _ASSET_CACHE[key]
|
| 60 |
|
|
|
|
|
|
|
| 61 |
def load_guide_content():
    """Load content for the methodology tab.

    Reads ``GUIDE.md`` from the current working directory as UTF-8 text.

    Returns:
        The file's text, or a Markdown warning message when the file cannot
        be opened or decoded.
    """
    try:
        with open("GUIDE.md", "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, UnicodeDecodeError):
        # Best-effort load: a missing/unreadable guide must not break the UI,
        # but interrupts and programming errors should still propagate.
        return "### ⚠️ GUIDE.md not found. Please upload it to the root directory."
|
| 68 |
|
| 69 |
def sync_ml_metrics(ds_name: str):
|
| 70 |
+
"""Identify numerical features for classification."""
|
| 71 |
assets = load_all_assets(ds_name)
|
| 72 |
df = assets["df"]
|
| 73 |
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
|
|
|
| 75 |
defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
|
| 76 |
return gr.update(choices=valid_features, value=defaults)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def train_classifier(ds_name: str, features: List[str]):
    """Train a classifier to detect circuit families.

    Args:
        ds_name: Dataset key passed to ``load_all_assets``.
        features: Names of the DataFrame columns used as model inputs.

    Returns:
        A ``(figure, markdown)`` pair. On success the figure holds a
        confusion matrix and a feature-importance bar chart, and the markdown
        holds the classification report. When no features are selected or
        fewer than two classes remain, the figure is ``None`` and the
        markdown is an error message.
    """
    if not features:
        return None, "### ❌ Error: Select features first."

    assets = load_all_assets(ds_name)
    df = assets["df"]

    # Prefer the resolved label whenever that column exists; the requested
    # label is only used as a fallback.
    target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'

    train_df = df.dropna(subset=features + [target_col])

    # Filter the generic 'mixed' tag out only if other classes exist.
    observed = train_df[target_col].unique()
    if 'mixed' in observed and len(observed) > 1:
        train_df = train_df[train_df[target_col] != 'mixed']

    X = train_df[features]
    y = train_df[target_col]

    if len(y.unique()) < 2:
        return None, f"### ❌ Error: At least 2 classes needed. Found only: {y.unique()}"

    le = LabelEncoder()
    y_encoded = le.fit_transform(y)
    class_names = [str(c) for c in le.classes_]
    # Fixed label ids keep the matrix/report aligned with class_names even
    # when a class is absent from the test split.
    label_ids = np.arange(len(class_names))

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
        )
    except ValueError:
        # Stratification is rejected when a class has too few members;
        # fall back to a plain random split.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=0.2, random_state=42
        )

    # Seeded like the split so repeated runs give the same result.
    clf = RandomForestClassifier(
        n_estimators=100, max_depth=12, n_jobs=-1, random_state=42
    ).fit(X_train, y_train)
    preds = clf.predict(X_test)

    # Plotting
    # NOTE(review): the figure is handed to the caller and never closed
    # here - confirm the caller releases it.
    sns.set_theme(style="whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    cm = confusion_matrix(y_test, preds, labels=label_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
    axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")

    importances = clf.feature_importances_
    # argsort is ascending, so the last ten indices are the most important.
    idx = np.argsort(importances)[-10:]
    axes[1].barh([features[i] for i in idx], importances[idx], color='#2ecc71')
    axes[1].set_title("Top-10 Discriminative Features")

    fig.tight_layout()
    report = classification_report(
        y_test, preds, labels=label_ids, target_names=class_names, zero_division=0
    )
    return fig, f"### 🏆 Results\n**Target Column:** `{target_col}`\n```\n{report}\n```"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
def update_explorer(ds_name: str, split_name: str):
|
| 129 |
+
"""Manage the Explorer tab data view."""
|
| 130 |
assets = load_all_assets(ds_name)
|
| 131 |
df = assets["df"]
|
|
|
|
|
|
|
| 132 |
splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
|
| 133 |
|
|
|
|
| 134 |
if split_name not in splits:
|
| 135 |
split_name = splits[0]
|
| 136 |
|
|
|
|
| 148 |
f"### 📋 {ds_name} Explorer"
|
| 149 |
)
|
| 150 |
|
| 151 |
+
# --- GRADIO INTERFACE ---
|
| 152 |
with gr.Blocks(theme=gr.themes.Soft(), title="QSBench Classifier") as demo:
|
| 153 |
gr.Markdown("# 🌌 QSBench: Circuit Family Classifier")
|
| 154 |
|
| 155 |
with gr.Tabs():
|
| 156 |
with gr.TabItem("🔎 Explorer"):
|
| 157 |
+
meta_txt = gr.Markdown("### Initializing...")
|
| 158 |
with gr.Row():
|
| 159 |
ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Dataset")
|
| 160 |
sp_sel = gr.Dropdown(["train"], value="train", label="Split")
|
| 161 |
data_view = gr.Dataframe(interactive=False)
|
| 162 |
with gr.Row():
|
| 163 |
+
c_raw = gr.Code(label="Source QASM", language="python")
|
| 164 |
c_tr = gr.Code(label="Transpiled QASM", language="python")
|
| 165 |
|
| 166 |
with gr.TabItem("🧠 Classification"):
|
| 167 |
with gr.Row():
|
| 168 |
with gr.Column(scale=1):
|
| 169 |
ml_ds_sel = gr.Dropdown(list(REPO_CONFIG.keys()), value="Core (Clean)", label="Environment")
|
| 170 |
+
ml_feat_sel = gr.CheckboxGroup(label="Structural Metrics", choices=[])
|
| 171 |
+
train_btn = gr.Button("Train Classifier", variant="primary")
|
| 172 |
with gr.Column(scale=2):
|
| 173 |
p_out = gr.Plot()
|
| 174 |
t_out = gr.Markdown()
|
|
|
|
| 178 |
|
| 179 |
gr.Markdown("--- \n ### 🔗 [Website](https://qsbench.github.io) | [Hugging Face](https://huggingface.co/QSBench) | [GitHub](https://github.com/QSBench)")
|
| 180 |
|
| 181 |
+
# Event Mapping
|
|
|
|
| 182 |
ds_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 183 |
sp_sel.change(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
|
|
|
| 184 |
ml_ds_sel.change(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 185 |
train_btn.click(train_classifier, [ml_ds_sel, ml_feat_sel], [p_out, t_out])
|
| 186 |
|
| 187 |
+
# Startup Load
|
| 188 |
demo.load(update_explorer, [ds_sel, sp_sel], [sp_sel, data_view, c_raw, c_tr, meta_txt])
|
| 189 |
demo.load(sync_ml_metrics, [ml_ds_sel], [ml_feat_sel])
|
| 190 |
|