File size: 4,424 Bytes
cdc317a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import json
import os
from datetime import datetime
from typing import List

import joblib
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

from config import MODEL_DIR, META_DIR
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure

# Human-readable (French) classifier labels -> short internal keys.
# build_pipeline() accepts either form: a label is translated through this
# mapping, while a short key is used as-is.
CLF_TYPE_MAP = dict(
    [
        ("SVM", "svm"),
        ("Régression logistique", "logreg"),
        ("k-NN", "knn"),
        ("Forêt aléatoire", "rf"),
        ("LDA", "lda"),
    ]
)


def classifier_path(model_name: str) -> str:
    """Return the path of the serialized pipeline for *model_name*.

    The file is ``<model_name>.joblib`` inside ``MODEL_DIR``.
    """
    file_name = "{}.joblib".format(model_name)
    return os.path.join(MODEL_DIR, file_name)


def meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata file for *model_name*.

    The file is ``<model_name>.json`` inside ``META_DIR``.
    """
    file_name = "{}.json".format(model_name)
    return os.path.join(META_DIR, file_name)


def build_pipeline(clf_type: str, **params) -> Pipeline:
    """Build an unfitted ``StandardScaler`` + classifier pipeline.

    Args:
        clf_type: Either a label from ``CLF_TYPE_MAP`` (e.g. ``"SVM"``) or
            one of its short keys (``"svm"``, ``"logreg"``, ``"knn"``,
            ``"rf"``, ``"lda"``).
        **params: Optional hyperparameters for the chosen classifier. A key
            that is absent *or* set to ``None`` falls back to the default
            shown below; keys unrelated to the chosen classifier are ignored.

    Returns:
        A ``Pipeline`` with steps ``"scaler"`` (``StandardScaler``) and
        ``"clf"`` (the classifier).

    Raises:
        ValueError: If ``clf_type`` is neither a known label nor a known key.
    """

    def pick(name, default):
        # Treat an explicit None like a missing key, so None is never handed
        # to the estimator in place of a real hyperparameter value
        # (train_classical_model likewise drops None-valued params).
        value = params.get(name)
        return default if value is None else value

    key = CLF_TYPE_MAP.get(clf_type, clf_type)

    if key == "svm":
        clf = SVC(
            C=pick("C", 1.0),
            kernel=pick("kernel", "rbf"),
            gamma=pick("gamma", "scale"),
            probability=True,  # enables predict_proba on the fitted model
            random_state=42,  # fixed seed for repeatable results
        )
    elif key == "logreg":
        clf = LogisticRegression(
            C=pick("C", 1.0),
            max_iter=pick("max_iter", 1000),
            random_state=42,
        )
    elif key == "knn":
        clf = KNeighborsClassifier(
            n_neighbors=pick("n_neighbors", 5),
            metric=pick("metric", "euclidean"),
        )
    elif key == "rf":
        # Any falsy max_depth (None, 0) means "no depth limit".
        max_depth = params.get("max_depth") or None
        clf = RandomForestClassifier(
            n_estimators=pick("n_estimators", 100),
            max_depth=max_depth,
            random_state=42,
            n_jobs=-1,  # use all available CPU cores
        )
    elif key == "lda":
        clf = LinearDiscriminantAnalysis(solver=pick("solver", "svd"))
    else:
        raise ValueError(f"Classifieur inconnu : {clf_type}")

    return Pipeline([("scaler", StandardScaler()), ("clf", clf)])


def train_classical_model(
    clf_type: str,
    features_cache: dict,
    class_names: List[str],
    model_tag: str = "",
    **params,
):
    """Train a scaler + classifier pipeline, evaluate it, and persist it.

    The pipeline from ``build_pipeline`` is fitted on the train split of
    ``features_cache`` and evaluated on its test split. Three artifacts are
    then written: the fitted pipeline (at ``classifier_path``), a
    confusion-matrix figure (via ``save_confusion_matrix_figure``) and a JSON
    metadata file (at ``meta_path``).

    Args:
        clf_type: Classifier label or short key (see ``CLF_TYPE_MAP``).
        features_cache: Mapping with ``"train"`` and ``"test"`` entries, each
            holding ``"X"`` (features) and ``"y"`` (labels). ``y`` and the
            predictions must support ``.tolist()`` -- presumably NumPy
            arrays; confirm against the caller.
        class_names: Class labels, forwarded to the metrics helper and
            stored in the metadata.
        model_tag: Optional prefix for the model name. Leading and trailing
            whitespace is stripped, and every character other than a letter,
            digit, ``-``, ``_`` or ``.`` is replaced with ``_``. When empty
            (or None), the classifier's short key is used instead.
        **params: Hyperparameters forwarded to ``build_pipeline``. The
            non-None ones are also recorded in the metadata.

    Returns:
        dict with ``model_name``, ``summary`` (test metrics and split
        sizes), ``classification_report``, ``confusion_matrix`` and
        ``confusion_matrix_path``.

    Raises:
        ValueError: From ``build_pipeline`` if ``clf_type`` is unknown.
    """
    X_train = features_cache["train"]["X"]
    y_train = features_cache["train"]["y"]
    X_test = features_cache["test"]["X"]
    y_test = features_cache["test"]["y"]

    pipeline = build_pipeline(clf_type, **params)
    pipeline.fit(X_train, y_train)

    y_pred = pipeline.predict(X_test)
    metrics = compute_classification_metrics(y_test.tolist(), y_pred.tolist(), class_names)

    # Same short key that build_pipeline() resolved; reused for the default
    # tag and for the metadata so the two always agree.
    model_type = CLF_TYPE_MAP.get(clf_type, clf_type)

    # Read the clock once so the name's timestamp and created_at match.
    now = datetime.now()

    tag = (model_tag or "").strip()
    if tag:
        # The model name becomes a file name under MODEL_DIR / META_DIR, so
        # keep only filename-safe characters (this also neutralises "/" and
        # "\" path separators).
        safe_tag = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in tag)
    else:
        safe_tag = model_type
    # NOTE(review): the timestamp has one-second resolution, so two runs with
    # the same tag within the same second overwrite each other's files.
    model_name = f"{safe_tag}_{now.strftime('%Y%m%d_%H%M%S')}"

    joblib.dump(pipeline, classifier_path(model_name))
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)

    config_dict = {
        "model_type": model_type,
        "clf_type_label": clf_type,
        "class_names": class_names,
        "num_classes": len(class_names),
        # Only hyperparameters that were actually set are recorded.
        **{k: v for k, v in params.items() if v is not None},
    }

    training_summary = {
        "test_accuracy": metrics["accuracy"],
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "train_samples": int(len(X_train)),
        "test_samples": int(len(X_test)),
    }

    with open(meta_path(model_name), "w", encoding="utf-8") as f:
        json.dump(
            {
                "model_name": model_name,
                "config": config_dict,
                "training_summary": training_summary,
                "created_at": now.strftime("%Y-%m-%d %H:%M:%S"),
            },
            f,
            indent=2,
            ensure_ascii=False,
        )

    return {
        "model_name": model_name,
        "summary": training_summary,
        "classification_report": metrics["classification_report"],
        "confusion_matrix": metrics["confusion_matrix"],
        "confusion_matrix_path": cm_path,
    }


def load_classical_pipeline(model_name: str) -> Pipeline:
    """Load a pipeline previously saved under *model_name*.

    Raises:
        FileNotFoundError: If nothing exists at ``classifier_path(model_name)``.
    """
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load files this application wrote itself, and do
    # not pass unvalidated user-supplied names here.
    model_file = classifier_path(model_name)
    if os.path.exists(model_file):
        return joblib.load(model_file)
    raise FileNotFoundError(f"Classifieur introuvable : {model_name}")