"""Metrics utilities for image-classification models.

Provides classification metrics (accuracy, F1, per-class report, confusion
matrix) and a helper to save the confusion matrix as a PNG figure.
"""
import os
from typing import Dict, List
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
)
from config import FIGURE_DIR
def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
    """Compute accuracy, F1 scores, a per-class report and a confusion matrix.

    Args:
        y_true: ground-truth label indices.
        y_pred: predicted label indices.
        class_names: ordered class names; label ``i`` maps to ``class_names[i]``.

    Returns:
        Dict with rounded scalar scores (``accuracy``, ``f1_macro``,
        ``f1_weighted``), a ``classification_report`` DataFrame whose first
        column is "classe", and the ``confusion_matrix`` as a labelled
        DataFrame (rows = true classes, columns = predicted classes).
    """
    label_ids = list(range(len(class_names)))

    # Scalar scores; zero_division=0 avoids warnings for classes never predicted.
    scores = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(
            y_true, y_pred, labels=label_ids, average="macro", zero_division=0
        ),
        "f1_weighted": f1_score(
            y_true, y_pred, labels=label_ids, average="weighted", zero_division=0
        ),
    }

    # Per-class precision/recall/F1 as a tidy DataFrame, one row per class
    # (plus the aggregate rows sklearn appends).
    per_class = classification_report(
        y_true,
        y_pred,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
        output_dict=True,
    )
    report = (
        pd.DataFrame(per_class)
        .transpose()
        .reset_index()
        .rename(columns={"index": "classe"})
    )

    matrix = pd.DataFrame(
        confusion_matrix(y_true, y_pred, labels=label_ids),
        index=class_names,
        columns=class_names,
    )

    result = {name: round(float(value), 4) for name, value in scores.items()}
    result["classification_report"] = report
    result["confusion_matrix"] = matrix
    return result
def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
    """Render a confusion-matrix heatmap to a PNG file and return its path.

    Args:
        cm_df: square confusion matrix; index = true classes,
            columns = predicted classes (as built by
            ``compute_classification_metrics``).
        model_name: used to build the output filename
            ``{model_name}_confusion_matrix.png``.

    Returns:
        Path of the saved PNG under ``FIGURE_DIR``.
    """
    # Fix: ensure the output directory exists — plt.savefig raises
    # FileNotFoundError if FIGURE_DIR has not been created yet.
    os.makedirs(FIGURE_DIR, exist_ok=True)
    fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")

    # Scale the figure with the class count, clamped to a readable range.
    fig_width = max(8, min(24, 0.45 * len(cm_df.columns)))
    fig_height = max(6, min(24, 0.45 * len(cm_df.index)))

    plt.figure(figsize=(fig_width, fig_height))
    plt.imshow(cm_df.values, interpolation="nearest")
    plt.title("Matrice de confusion")
    plt.colorbar()
    tick_marks = range(len(cm_df.columns))
    plt.xticks(tick_marks, cm_df.columns, rotation=90, fontsize=7)
    plt.yticks(tick_marks, cm_df.index, fontsize=7)
    plt.xlabel("Classe prédite")
    plt.ylabel("Classe réelle")
    plt.tight_layout()
    plt.savefig(fig_path, dpi=200)
    # Close explicitly so repeated calls don't accumulate open figures.
    plt.close()
    return fig_path