mboukabous commited on
Commit
441e594
·
1 Parent(s): a77c071

CM enhance

Browse files
scripts/train_classification_model.py CHANGED
@@ -161,14 +161,40 @@ def main(args):
161
 
162
  # Display and save the confusion matrix
163
  from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
164
- conf_matrix = confusion_matrix(y_test, y_pred)
165
- disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix)
166
- disp.plot(cmap=plt.cm.Blues, values_format='d')
167
- plt.title(f'{model_name} Confusion Matrix')
168
- conf_matrix_path = os.path.join(args.results_path, 'confusion_matrix.png')
169
- plt.savefig(conf_matrix_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  plt.show()
171
- print(f"Confusion matrix saved to {conf_matrix_path}")
 
172
 
173
  if __name__ == "__main__":
174
  parser = argparse.ArgumentParser(description="Train a classification model.")
 
161
 
162
  # Display and save the confusion matrix
163
  from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
164
+
165
+ # Load the label encoder (if it exists)
166
+ label_encoder_path = os.path.join(args.model_path, "label_encoder.pkl")
167
+ if os.path.exists(label_encoder_path):
168
+ label_encoder = joblib.load(label_encoder_path)
169
+ # Decode the predicted and true labels
170
+ y_test_decoded = label_encoder.inverse_transform(y_test)
171
+ y_pred_decoded = label_encoder.inverse_transform(y_pred)
172
+ display_labels = label_encoder.classes_
173
+ else:
174
+ # If no encoder, use the original numeric labels
175
+ y_test_decoded = y_test
176
+ y_pred_decoded = y_pred
177
+ display_labels = None # Numeric labels will be used by default
178
+
179
+ # Save confusion matrix
180
+ conf_mat = confusion_matrix(y_test_decoded, y_pred_decoded)
181
+ plt.figure(figsize=(10, 8)) # Increased figure size for better spacing
182
+ disp = ConfusionMatrixDisplay(conf_mat, display_labels=display_labels)
183
+
184
+ # Customize the plot
185
+ disp.plot(cmap="Blues", values_format="d", ax=plt.gca())
186
+ plt.title("Confusion Matrix", fontsize=16, pad=20) # Increased font size and added padding
187
+ plt.xticks(rotation=45, ha="right", fontsize=12) # Rotated x-axis labels and increased font size
188
+ plt.yticks(fontsize=12) # Increased font size for y-axis labels
189
+ plt.xlabel("Predicted Label", fontsize=14) # Added font size for x-axis label
190
+ plt.ylabel("True Label", fontsize=14) # Added font size for y-axis label
191
+
192
+ # Save the improved plot
193
+ cm_path = os.path.join(args.results_path, "confusion_matrix.png")
194
+ plt.savefig(cm_path, bbox_inches="tight") # Ensures no clipping of labels
195
  plt.show()
196
+
197
+ print(f"Confusion matrix saved to {cm_path}")
198
 
199
  if __name__ == "__main__":
200
  parser = argparse.ArgumentParser(description="Train a classification model.")