Update app.py
Browse files
app.py
CHANGED
|
@@ -80,13 +80,29 @@ def train_classifier(ds_name: str, features: List[str]):
|
|
| 80 |
assets = load_all_assets(ds_name)
|
| 81 |
df = assets["df"]
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
X, y = train_df[features], train_df['circuit_type_requested']
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
le = LabelEncoder()
|
| 88 |
y_encoded = le.fit_transform(y)
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
|
| 92 |
preds = clf.predict(X_test)
|
|
@@ -95,7 +111,9 @@ def train_classifier(ds_name: str, features: List[str]):
|
|
| 95 |
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
|
| 96 |
|
| 97 |
cm = confusion_matrix(y_test, preds)
|
| 98 |
-
sns.heatmap(cm, annot=True, fmt='d', cmap='magma',
|
|
|
|
|
|
|
| 99 |
axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
|
| 100 |
|
| 101 |
importances = clf.feature_importances_
|
|
@@ -105,7 +123,7 @@ def train_classifier(ds_name: str, features: List[str]):
|
|
| 105 |
|
| 106 |
plt.tight_layout()
|
| 107 |
report = classification_report(y_test, preds, target_names=le.classes_)
|
| 108 |
-
return fig, f"### 📊 Results\n```\n{report}\n```"
|
| 109 |
|
| 110 |
def update_explorer(ds_name: str, split_name: str):
|
| 111 |
assets = load_all_assets(ds_name)
|
|
|
|
| 80 |
assets = load_all_assets(ds_name)
|
| 81 |
df = assets["df"]
|
| 82 |
|
| 83 |
+
# Automatically determine available classes in the dataset, excluding empty values
|
| 84 |
+
available_in_df = df['circuit_type_requested'].dropna().unique()
|
| 85 |
+
|
| 86 |
+
# Keep every row whose class label is present in the dataset (no whitelist of
|
| 87 |
+
# circuit types is applied), then drop rows missing any of the selected features
|
| 88 |
+
train_df = df[df['circuit_type_requested'].isin(available_in_df)].dropna(subset=features)
|
| 89 |
+
|
| 90 |
+
if train_df.empty:
|
| 91 |
+
return None, f"### ❌ Error: No data found for features {features}. Check if these columns are empty in the dataset."
|
| 92 |
+
|
| 93 |
X, y = train_df[features], train_df['circuit_type_requested']
|
| 94 |
|
| 95 |
+
# Check number of classes
|
| 96 |
+
if len(y.unique()) < 2:
|
| 97 |
+
return None, f"### ❌ Error: Need at least 2 classes to train. Found only: {y.unique()}"
|
| 98 |
+
|
| 99 |
le = LabelEncoder()
|
| 100 |
y_encoded = le.fit_transform(y)
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
|
| 104 |
+
except ValueError as e:
|
| 105 |
+
return None, f"### ❌ Split Error: {str(e)}"
|
| 106 |
|
| 107 |
clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
|
| 108 |
preds = clf.predict(X_test)
|
|
|
|
| 111 |
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
|
| 112 |
|
| 113 |
cm = confusion_matrix(y_test, preds)
|
| 114 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='magma',
|
| 115 |
+
xticklabels=le.classes_, yticklabels=le.classes_,
|
| 116 |
+
ax=axes[0], cbar=False)
|
| 117 |
axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
|
| 118 |
|
| 119 |
importances = clf.feature_importances_
|
|
|
|
| 123 |
|
| 124 |
plt.tight_layout()
|
| 125 |
report = classification_report(y_test, preds, target_names=le.classes_)
|
| 126 |
+
return fig, f"### 📊 Results for {ds_name}\n```\n{report}\n```"
|
| 127 |
|
| 128 |
def update_explorer(ds_name: str, split_name: str):
|
| 129 |
assets = load_all_assets(ds_name)
|