"""Classification metrics and confusion-matrix plotting helpers."""

import os
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)

from config import FIGURE_DIR


def compute_classification_metrics(
    y_true, y_pred, class_names: List[str]
) -> Dict[str, Any]:
    """Compute summary metrics, a per-class report and a confusion matrix.

    The label set passed to scikit-learn is ``0 .. len(class_names) - 1``,
    so ``y_true`` / ``y_pred`` are expected to hold integer class indices
    into ``class_names``.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        class_names: Display name of each class, ordered by class index.

    Returns:
        A dict with the keys:

        * ``"accuracy"``, ``"f1_macro"``, ``"f1_weighted"``: floats rounded
          to 4 decimals.
        * ``"classification_report"``: DataFrame built from scikit-learn's
          report dict, one row per report entry, with the entry name in the
          ``"classe"`` column.
        * ``"confusion_matrix"``: DataFrame whose index and columns are
          both ``class_names``.
    """
    labels = list(range(len(class_names)))

    acc = accuracy_score(y_true, y_pred)

    # Both F1 variants share everything except the averaging mode.
    # zero_division=0 scores classes without predictions as 0 instead of
    # emitting a warning.
    f1_kwargs = {"labels": labels, "zero_division": 0}
    f1_macro = f1_score(y_true, y_pred, average="macro", **f1_kwargs)
    f1_weighted = f1_score(y_true, y_pred, average="weighted", **f1_kwargs)

    report_dict = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=class_names,
        zero_division=0,
        output_dict=True,
    )
    # The report dict is keyed by entry name; transpose so each entry becomes
    # a row, then expose the entry name as a regular column.
    report_df = pd.DataFrame(report_dict).transpose().reset_index()
    report_df = report_df.rename(columns={"index": "classe"})

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

    return {
        "accuracy": round(float(acc), 4),
        "f1_macro": round(float(f1_macro), 4),
        "f1_weighted": round(float(f1_weighted), 4),
        "classification_report": report_df,
        "confusion_matrix": cm_df,
    }


def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
    """Render a confusion matrix as a PNG heatmap and return its path.

    The image is written to ``FIGURE_DIR`` as
    ``<model_name>_confusion_matrix.png``; the directory is created if it
    does not exist yet.

    Args:
        cm_df: Confusion matrix; the index supplies the y-axis labels and
            the columns supply the x-axis labels.
        model_name: Prefix used for the output file name.

    Returns:
        Path of the saved PNG file.
    """
    # Make sure the target directory exists so savefig cannot fail on a
    # fresh checkout.
    os.makedirs(FIGURE_DIR, exist_ok=True)
    fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")

    # Scale the canvas with the number of classes, clamped to a sane range.
    fig_width = max(8, min(24, 0.45 * len(cm_df.columns)))
    fig_height = max(6, min(24, 0.45 * len(cm_df.index)))

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        image = ax.imshow(cm_df.values, interpolation="nearest")
        ax.set_title("Matrice de confusion")
        fig.colorbar(image, ax=ax)

        # Each axis gets tick positions matching its own label count.
        ax.set_xticks(list(range(len(cm_df.columns))))
        ax.set_xticklabels(cm_df.columns, rotation=90, fontsize=7)
        ax.set_yticks(list(range(len(cm_df.index))))
        ax.set_yticklabels(cm_df.index, fontsize=7)

        ax.set_xlabel("Classe prédite")
        ax.set_ylabel("Classe réelle")

        fig.tight_layout()
        fig.savefig(fig_path, dpi=200)
    finally:
        # Always release the figure, even if layout or saving fails, so it
        # does not accumulate in pyplot's global figure registry.
        plt.close(fig)

    return fig_path