Update app.py
Browse files
app.py
CHANGED
|
@@ -75,55 +75,90 @@ def sync_ml_metrics(ds_name: str):
|
|
| 75 |
defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
|
| 76 |
return gr.update(choices=valid_features, value=defaults)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def train_classifier(ds_name: str, features: List[str]):
|
| 79 |
-
if not features:
|
|
|
|
|
|
|
| 80 |
assets = load_all_assets(ds_name)
|
| 81 |
df = assets["df"]
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
#
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
|
|
|
|
| 99 |
le = LabelEncoder()
|
| 100 |
y_encoded = le.fit_transform(y)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
|
| 104 |
-
except ValueError as e:
|
| 105 |
-
return None, f"### ❌ Split Error: {str(e)}"
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
preds = clf.predict(X_test)
|
| 109 |
|
|
|
|
| 110 |
sns.set_theme(style="whitegrid")
|
| 111 |
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
|
| 112 |
|
|
|
|
| 113 |
cm = confusion_matrix(y_test, preds)
|
| 114 |
-
sns.heatmap(cm, annot=True, fmt='d', cmap='
|
| 115 |
-
xticklabels=
|
| 116 |
-
ax=axes[0], cbar=False)
|
| 117 |
axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
|
|
|
|
|
|
|
| 118 |
|
|
|
|
| 119 |
importances = clf.feature_importances_
|
| 120 |
-
|
| 121 |
-
axes[1].barh([features[i] for i in
|
| 122 |
-
axes[1].set_title("
|
| 123 |
|
| 124 |
plt.tight_layout()
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
def update_explorer(ds_name: str, split_name: str):
|
| 129 |
assets = load_all_assets(ds_name)
|
|
|
|
| 75 |
defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
|
| 76 |
return gr.update(choices=valid_features, value=defaults)
|
| 77 |
|
| 78 |
+
Судя по ошибке `Found only: ['mixed']`, в вашем столбце `circuit_type_requested` вместо конкретных названий семейств (QFT, HEA и т. д.) записано значение `'mixed'`. Так часто бывает в демонстрационных подмножествах, где данные уже перемешаны и помечены общим тегом.
|
| 79 |
+
|
| 80 |
+
Для классификации нам нужны исходные метки. В датасетах QSBench они обычно находятся в столбце `circuit_type_resolved`.
|
| 81 |
+
|
| 82 |
+
Вот обновлённый код функции `train_classifier` с исправленной логикой выбора столбца и более надёжной обработкой ошибок.
|
| 83 |
+
Исправленный код (App Code)
|
| 84 |
+
Python
|
| 85 |
+
|
| 86 |
def train_classifier(ds_name: str, features: List[str]):
    """Train a random-forest classifier on the selected features and plot diagnostics.

    The target is read from ``circuit_type_resolved`` when that column exists,
    otherwise from ``circuit_type_requested``. Rows with missing values in any
    selected feature or in the target are dropped, and rows labelled
    ``'mixed'`` are removed whenever at least one other label is present.

    Args:
        ds_name: Dataset identifier passed to ``load_all_assets``.
        features: Column names of ``assets["df"]`` used as model inputs.

    Returns:
        A ``(figure, markdown)`` pair. On success the figure holds a confusion
        matrix and a feature-importance bar chart, and the markdown holds the
        accuracy and the classification report. On any input problem the
        figure is ``None`` and the markdown describes the error.
    """
    if not features:
        return None, "### ❌ Error: No features selected. Please pick structural metrics."

    assets = load_all_assets(ds_name)
    df = assets["df"]

    # A feature selection can outlive a dataset switch; report it instead of
    # letting pandas raise a bare KeyError.
    missing_features = [name for name in features if name not in df.columns]
    if missing_features:
        return None, f"### ❌ Error: Selected features not found in dataset: `{missing_features}`"

    # Try 'resolved' column first as 'requested' might contain 'mixed' in demo shards.
    target_col = next(
        (col for col in ("circuit_type_resolved", "circuit_type_requested") if col in df.columns),
        None,
    )
    if target_col is None:
        return None, (
            "### ❌ Error: No label column found. Expected `circuit_type_resolved` "
            "or `circuit_type_requested`."
        )

    # Clean data: drop rows with NaNs in any model input or in the target.
    train_df = df.dropna(subset=[*features, target_col])
    if train_df.empty:
        return None, (
            "### ❌ Classification Error\n"
            f"No rows left after dropping missing values in the selected features and `{target_col}`."
        )

    # Drop the generic 'mixed' tag only when real family labels are available.
    unique_types = train_df[target_col].unique()
    if 'mixed' in unique_types and len(unique_types) > 1:
        train_df = train_df[train_df[target_col] != 'mixed']

    X = train_df[features]
    y = train_df[target_col]

    # Classification needs at least two distinct classes.
    current_classes = y.unique()
    if len(current_classes) < 2:
        return None, (
            f"### ❌ Classification Error\nFound only one class: `{current_classes}` in column `{target_col}`. "
            "Try a different dataset or check if the source file has labels."
        )

    # Encode labels to integers; LabelEncoder assigns codes 0..n_classes-1.
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)
    class_names = le.classes_
    label_ids = np.arange(len(class_names))

    # Split dataset, preferring a stratified split.
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
        )
    except ValueError:
        # Fallback if stratify fails due to very small class sizes.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=0.2, random_state=42
        )

    # Train Random Forest Classifier.
    clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    accuracy = accuracy_score(y_test, preds)

    # Visuals.
    sns.set_theme(style="whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # Plot 1: Confusion Matrix. Passing `labels` keeps the matrix aligned with
    # `class_names` even when a class is absent from the test split.
    cm = confusion_matrix(y_test, preds, labels=label_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='viridis',
                xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
    axes[0].set_title(f"Confusion Matrix (Acc: {accuracy:.2%})")
    axes[0].set_xlabel("Predicted Label")
    axes[0].set_ylabel("True Label")

    # Plot 2: Feature Importance (the ten largest, ascending so the top bar is the strongest).
    importances = clf.feature_importances_
    indices = np.argsort(importances)[-10:]
    axes[1].barh([features[i] for i in indices], importances[indices], color='#2ecc71')
    axes[1].set_title("Top-10 Discriminative Features")

    plt.tight_layout()

    # Generate text report. `labels` prevents a ValueError when the test split
    # lacks a class; `zero_division=0` silences warnings for such empty classes.
    report_text = classification_report(
        y_test, preds, labels=label_ids, target_names=class_names, zero_division=0
    )
    summary = (
        f"### 🏆 Classifier Results: {ds_name}\n"
        f"**Target Column used:** `{target_col}`\n"
        f"**Accuracy:** {accuracy:.2%}\n\n"
        f"**Report:**\n```\n{report_text}\n```"
    )

    return fig, summary
|
| 162 |
|
| 163 |
def update_explorer(ds_name: str, split_name: str):
|
| 164 |
assets = load_all_assets(ds_name)
|