Spaces:
Sleeping
Sleeping
Commit ·
441e594
1
Parent(s): a77c071
CM enhance
Browse files
scripts/train_classification_model.py
CHANGED
|
@@ -161,14 +161,40 @@ def main(args):
|
|
| 161 |
|
| 162 |
# Display and save the confusion matrix
|
| 163 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
plt.show()
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
if __name__ == "__main__":
|
| 174 |
parser = argparse.ArgumentParser(description="Train a classification model.")
|
|
|
|
| 161 |
|
| 162 |
# Display and save the confusion matrix
|
| 163 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
| 164 |
+
|
| 165 |
+
# Load the label encoder (if it exists)
|
| 166 |
+
label_encoder_path = os.path.join(args.model_path, "label_encoder.pkl")
|
| 167 |
+
if os.path.exists(label_encoder_path):
|
| 168 |
+
label_encoder = joblib.load(label_encoder_path)
|
| 169 |
+
# Decode the predicted and true labels
|
| 170 |
+
y_test_decoded = label_encoder.inverse_transform(y_test)
|
| 171 |
+
y_pred_decoded = label_encoder.inverse_transform(y_pred)
|
| 172 |
+
display_labels = label_encoder.classes_
|
| 173 |
+
else:
|
| 174 |
+
# If no encoder, use the original numeric labels
|
| 175 |
+
y_test_decoded = y_test
|
| 176 |
+
y_pred_decoded = y_pred
|
| 177 |
+
display_labels = None # Numeric labels will be used by default
|
| 178 |
+
|
| 179 |
+
# Save confusion matrix
|
| 180 |
+
conf_mat = confusion_matrix(y_test_decoded, y_pred_decoded)
|
| 181 |
+
plt.figure(figsize=(10, 8)) # Increased figure size for better spacing
|
| 182 |
+
disp = ConfusionMatrixDisplay(conf_mat, display_labels=display_labels)
|
| 183 |
+
|
| 184 |
+
# Customize the plot
|
| 185 |
+
disp.plot(cmap="Blues", values_format="d", ax=plt.gca())
|
| 186 |
+
plt.title("Confusion Matrix", fontsize=16, pad=20) # Increased font size and added padding
|
| 187 |
+
plt.xticks(rotation=45, ha="right", fontsize=12) # Rotated x-axis labels and increased font size
|
| 188 |
+
plt.yticks(fontsize=12) # Increased font size for y-axis labels
|
| 189 |
+
plt.xlabel("Predicted Label", fontsize=14) # Added font size for x-axis label
|
| 190 |
+
plt.ylabel("True Label", fontsize=14) # Added font size for y-axis label
|
| 191 |
+
|
| 192 |
+
# Save the improved plot
|
| 193 |
+
cm_path = os.path.join(args.results_path, "confusion_matrix.png")
|
| 194 |
+
plt.savefig(cm_path, bbox_inches="tight") # Ensures no clipping of labels
|
| 195 |
plt.show()
|
| 196 |
+
|
| 197 |
+
print(f"Confusion matrix saved to {cm_path}")
|
| 198 |
|
| 199 |
if __name__ == "__main__":
|
| 200 |
parser = argparse.ArgumentParser(description="Train a classification model.")
|