Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Dict, List | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| f1_score, | |
| classification_report, | |
| confusion_matrix, | |
| ) | |
| from config import FIGURE_DIR | |
def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
    """Compute summary scores, a per-class report and a confusion matrix.

    Labels are matched to names by position: label ``i`` is reported as
    ``class_names[i]`` (the ``labels`` list is ``range(len(class_names))``).
    NOTE(review): this assumes ``y_true`` / ``y_pred`` hold integer class
    indices; string labels would match none of them — confirm with callers.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        class_names: Display name for each label index.

    Returns:
        A dict with:
          - ``"accuracy"``, ``"f1_macro"``, ``"f1_weighted"``: floats rounded
            to 4 decimals.
          - ``"classification_report"``: DataFrame with one row per class plus
            scikit-learn's summary rows; the row name is in column ``"classe"``.
          - ``"confusion_matrix"``: DataFrame whose rows are true classes and
            whose columns are predicted classes, both labelled by
            ``class_names``.
    """
    labels = list(range(len(class_names)))

    acc = accuracy_score(y_true, y_pred)
    # Passing the full ``labels`` list keeps classes that are absent from this
    # split in the macro average (scored 0 thanks to zero_division=0) instead
    # of silently dropping them.
    f1_macro = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, labels=labels, average="weighted", zero_division=0)

    report_dict = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=class_names,
        zero_division=0,
        output_dict=True,
    )

    # scikit-learn stores "accuracy" as a bare float rather than a per-metric
    # dict. pd.DataFrame would broadcast that scalar into every metric row,
    # so the accuracy row's "support" would show the accuracy value instead
    # of a sample count. When "accuracy" is emitted, micro-averaged precision,
    # recall and F1 all equal accuracy, so those cells keep that value. Only
    # "support" is replaced, with the total support, as in the text report.
    accuracy_entry = report_dict.get("accuracy")
    if accuracy_entry is not None and not isinstance(accuracy_entry, dict):
        score = float(accuracy_entry)
        report_dict["accuracy"] = {
            "precision": score,
            "recall": score,
            "f1-score": score,
            "support": report_dict["macro avg"]["support"],
        }

    # Reassigning an existing key keeps its position, so row order is unchanged.
    report_df = (
        pd.DataFrame(report_dict)
        .transpose()
        .reset_index()
        .rename(columns={"index": "classe"})
    )

    cm_df = pd.DataFrame(
        confusion_matrix(y_true, y_pred, labels=labels),
        index=class_names,
        columns=class_names,
    )

    return {
        "accuracy": round(float(acc), 4),
        "f1_macro": round(float(f1_macro), 4),
        "f1_weighted": round(float(f1_weighted), 4),
        "classification_report": report_df,
        "confusion_matrix": cm_df,
    }
def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
    """Render a confusion matrix as a heat map and save it as a PNG.

    Args:
        cm_df: Confusion matrix whose index labels the true classes (y axis)
            and whose columns label the predicted classes (x axis).
        model_name: Prefix for the output file name.

    Returns:
        Path of the written file,
        ``os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")``.
    """
    fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")

    # Make sure the target directory exists so savefig cannot fail with
    # FileNotFoundError. dirname may be "" (save into the CWD), and
    # os.makedirs("") raises, hence the guard.
    out_dir = os.path.dirname(fig_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_cols = len(cm_df.columns)
    n_rows = len(cm_df.index)
    # Roughly 0.45 inch per class keeps the tick labels legible. Width is
    # clamped to [8, 24] inches and height to [6, 24].
    fig_width = max(8, min(24, 0.45 * n_cols))
    fig_height = max(6, min(24, 0.45 * n_rows))

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        image = ax.imshow(cm_df.values, interpolation="nearest")
        ax.set_title("Matrice de confusion")
        fig.colorbar(image, ax=ax)
        # Each axis gets tick positions sized to its own labels.
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(cm_df.columns, rotation=90, fontsize=7)
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels(cm_df.index, fontsize=7)
        ax.set_xlabel("Classe prédite")
        ax.set_ylabel("Classe réelle")
        fig.tight_layout()
        fig.savefig(fig_path, dpi=200)
    finally:
        # Release the figure even if rendering or saving fails; pyplot keeps
        # every open figure alive until it is explicitly closed.
        plt.close(fig)

    return fig_path