File size: 2,137 Bytes
e74b30d
958eb86
 
e74b30d
 
 
 
 
 
 
 
 
 
958eb86
 
63e305e
22ca06d
 
63e305e
22ca06d
 
 
 
 
 
 
 
 
 
 
 
 
 
63e305e
 
 
 
22ca06d
63e305e
 
 
 
 
 
 
 
958eb86
63e305e
 
 
 
 
 
 
 
958eb86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from typing import Dict, List

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
)

from config import FIGURE_DIR


def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
    """Compute global and per-class classification metrics.

    Labels are expected to be integer indices ``0 .. len(class_names) - 1``;
    index ``i`` is displayed as ``class_names[i]``. F1 scores, the report and
    the confusion matrix are restricted to those indices, whereas accuracy is
    computed over all samples.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices, aligned with ``y_true``.
        class_names: Display name for each class index.

    Returns:
        A dict with:
        - ``"accuracy"``, ``"f1_macro"``, ``"f1_weighted"``: floats rounded
          to 4 decimals (undefined per-class scores count as 0);
        - ``"classification_report"``: DataFrame with one row per class plus
          scikit-learn's summary rows; the row name is in column ``"classe"``;
        - ``"confusion_matrix"``: DataFrame with true classes as rows and
          predicted classes as columns, labelled with ``class_names``.
    """
    labels = list(range(len(class_names)))

    acc = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(
        y_true,
        y_pred,
        labels=labels,
        average="macro",
        zero_division=0,
    )
    f1_weighted = f1_score(
        y_true,
        y_pred,
        labels=labels,
        average="weighted",
        zero_division=0,
    )

    report_dict = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=class_names,
        zero_division=0,
        output_dict=True,
    )

    report_df = pd.DataFrame(report_dict).transpose()
    # scikit-learn stores "accuracy" as a bare float in the dict, so pandas
    # broadcasts it into every column of that row -- including "support",
    # which should be a sample count. Mirror the text report instead: only
    # f1-score carries the accuracy and support is the total sample count
    # (equal to the macro-average support, i.e. the sum over classes).
    if "accuracy" in report_df.index and "macro avg" in report_df.index:
        report_df.loc["accuracy", ["precision", "recall"]] = float("nan")
        report_df.loc["accuracy", "support"] = report_df.loc["macro avg", "support"]
    report_df = report_df.reset_index().rename(columns={"index": "classe"})

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

    return {
        "accuracy": round(float(acc), 4),
        "f1_macro": round(float(f1_macro), 4),
        "f1_weighted": round(float(f1_weighted), 4),
        "classification_report": report_df,
        "confusion_matrix": cm_df,
    }


def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
    """Render a confusion matrix as a heatmap and save it as a PNG.

    Args:
        cm_df: Confusion matrix whose row labels are the true classes and
            column labels the predicted classes (matching the axis titles).
        model_name: Prefix for the output file name.

    Returns:
        Path of the written image,
        ``os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")``.
    """
    fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")

    # savefig fails if the target directory does not exist yet.
    out_dir = os.path.dirname(fig_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # About 0.45 inch per class, clamped to [8, 24] x [6, 24] inches so the
    # tick labels stay readable without producing oversized images.
    fig_width = max(8, min(24, 0.45 * len(cm_df.columns)))
    fig_height = max(6, min(24, 0.45 * len(cm_df.index)))

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        image = ax.imshow(cm_df.values, interpolation="nearest")
        ax.set_title("Matrice de confusion")
        fig.colorbar(image, ax=ax)

        ax.set_xticks(range(len(cm_df.columns)))
        ax.set_xticklabels(cm_df.columns, rotation=90, fontsize=7)
        ax.set_yticks(range(len(cm_df.index)))
        ax.set_yticklabels(cm_df.index, fontsize=7)

        ax.set_xlabel("Classe prédite")
        ax.set_ylabel("Classe réelle")
        fig.tight_layout()
        fig.savefig(fig_path, dpi=200)
    finally:
        # Release the figure even when rendering or saving fails, so repeated
        # calls do not accumulate open pyplot figures.
        plt.close(fig)

    return fig_path