QSBench committed on
Commit
170aab6
·
verified ·
1 Parent(s): 30add96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -80,13 +80,29 @@ def train_classifier(ds_name: str, features: List[str]):
80
  assets = load_all_assets(ds_name)
81
  df = assets["df"]
82
 
83
- # Filter for the 5 target families only
84
- train_df = df[df['circuit_type_requested'].isin(TARGET_FAMILIES)].dropna(subset=features)
 
 
 
 
 
 
 
 
85
  X, y = train_df[features], train_df['circuit_type_requested']
86
 
 
 
 
 
87
  le = LabelEncoder()
88
  y_encoded = le.fit_transform(y)
89
- X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
 
 
 
 
90
 
91
  clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
92
  preds = clf.predict(X_test)
@@ -95,7 +111,9 @@ def train_classifier(ds_name: str, features: List[str]):
95
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
96
 
97
  cm = confusion_matrix(y_test, preds)
98
- sns.heatmap(cm, annot=True, fmt='d', cmap='magma', xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0], cbar=False)
 
 
99
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
100
 
101
  importances = clf.feature_importances_
@@ -105,7 +123,7 @@ def train_classifier(ds_name: str, features: List[str]):
105
 
106
  plt.tight_layout()
107
  report = classification_report(y_test, preds, target_names=le.classes_)
108
- return fig, f"### 🏆 Results\n```\n{report}\n```"
109
 
110
  def update_explorer(ds_name: str, split_name: str):
111
  assets = load_all_assets(ds_name)
 
80
  assets = load_all_assets(ds_name)
81
  df = assets["df"]
82
 
83
+ # Automatically determine available classes in the dataset, excluding empty values
84
+ available_in_df = df['circuit_type_requested'].dropna().unique()
85
+
86
+ # Filter: keep only those that are in our list of interests (case-insensitive)
87
+ # Or simply take all available types if we want universality
88
+ train_df = df[df['circuit_type_requested'].isin(available_in_df)].dropna(subset=features)
89
+
90
+ if train_df.empty:
91
+ return None, f"### ❌ Error: No data found for features {features}. Check if these columns are empty in the dataset."
92
+
93
  X, y = train_df[features], train_df['circuit_type_requested']
94
 
95
+ # Check number of classes
96
+ if len(y.unique()) < 2:
97
+ return None, f"### ❌ Error: Need at least 2 classes to train. Found only: {y.unique()}"
98
+
99
  le = LabelEncoder()
100
  y_encoded = le.fit_transform(y)
101
+
102
+ try:
103
+ X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
104
+ except ValueError as e:
105
+ return None, f"### ❌ Split Error: {str(e)}"
106
 
107
  clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
108
  preds = clf.predict(X_test)
 
111
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
112
 
113
  cm = confusion_matrix(y_test, preds)
114
+ sns.heatmap(cm, annot=True, fmt='d', cmap='magma',
115
+ xticklabels=le.classes_, yticklabels=le.classes_,
116
+ ax=axes[0], cbar=False)
117
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
118
 
119
  importances = clf.feature_importances_
 
123
 
124
  plt.tight_layout()
125
  report = classification_report(y_test, preds, target_names=le.classes_)
126
+ return fig, f"### 🏆 Results for {ds_name}\n```\n{report}\n```"
127
 
128
  def update_explorer(ds_name: str, split_name: str):
129
  assets = load_all_assets(ds_name)