File size: 2,002 Bytes
f381be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
src.models.classical.classifiers
================================
Classification models for battery degradation state prediction.

4-class classification:
    0 – Healthy   (SOH ≥ 90%)
    1 – Aging     (80% ≤ SOH < 90%)
    2 – Near-EOL  (70% ≤ SOH < 80%)
    3 – EOL       (SOH < 70%)
"""

from __future__ import annotations

from typing import Any

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

from src.evaluation.metrics import classification_metrics
from src.utils.config import MODELS_DIR, RANDOM_STATE

# Display names for the four degradation classes. List position is the integer
# class index described in the module docstring (0 = Healthy ... 3 = EOL).
DEGRADATION_LABELS = ["Healthy", "Aging", "Near-EOL", "EOL"]


def _save_model(model: Any, name: str) -> None:
    """Serialize *model* with joblib under ``MODELS_DIR / "classical"``.

    Args:
        model: Fitted estimator (or any joblib-serializable object).
        name: File stem; the artifact is written as ``<name>.joblib``.
    """
    path = MODELS_DIR / "classical" / f"{name}.joblib"
    # joblib.dump does not create missing parent directories, so make sure the
    # target folder exists before writing (no-op when it is already there).
    path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, path)


def train_rf_classifier(
    X: np.ndarray, y: np.ndarray,
    n_estimators: int = 500,
) -> RandomForestClassifier:
    """Fit a random-forest classifier, persist it, and return it.

    The forest is built with ``class_weight="balanced"``, the project-wide
    ``RANDOM_STATE`` seed and ``n_jobs=-1``. After fitting, the model is
    saved under the name ``"rf_classifier"``.

    Args:
        X: Training feature matrix.
        y: Training class labels.
        n_estimators: Number of trees in the forest.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    forest_params = {
        "n_estimators": n_estimators,
        "random_state": RANDOM_STATE,
        "class_weight": "balanced",
        "n_jobs": -1,
    }
    forest = RandomForestClassifier(**forest_params)
    forest.fit(X, y)

    _save_model(forest, "rf_classifier")
    return forest


def train_xgb_classifier(
    X: np.ndarray, y: np.ndarray,
    n_estimators: int = 500,
) -> Any:
    """Fit an XGBoost classifier, persist it, and return it.

    ``xgboost`` is imported inside the function, so the module itself can be
    imported without it. After fitting, the model is saved under the name
    ``"xgb_classifier"``.

    Args:
        X: Training feature matrix.
        y: Training class labels.
        n_estimators: Number of boosting rounds.

    Returns:
        The fitted ``XGBClassifier``.
    """
    from xgboost import XGBClassifier

    booster_params = {
        "n_estimators": n_estimators,
        "max_depth": 6,
        "learning_rate": 0.1,
        "tree_method": "hist",
        "random_state": RANDOM_STATE,
        "verbosity": 0,
        "n_jobs": -1,
        "eval_metric": "mlogloss",
    }
    booster = XGBClassifier(**booster_params)
    booster.fit(X, y)

    _save_model(booster, "xgb_classifier")
    return booster


def evaluate_classifier(
    model: Any,
    X_test: np.ndarray,
    y_test: np.ndarray,
    model_name: str = "",
) -> dict[str, Any]:
    """Score a fitted classifier on a held-out set.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_test: Test feature matrix.
        y_test: True class labels for ``X_test``.
        model_name: Currently unused; kept so existing callers keep working.

    Returns:
        The dict produced by ``classification_metrics`` with an extra
        ``"classification_report"`` entry holding sklearn's text report.
    """
    # One class index per display name, i.e. [0, 1, 2, 3].
    class_ids = list(range(len(DEGRADATION_LABELS)))

    y_pred = model.predict(X_test)
    metrics = classification_metrics(y_test, y_pred, labels=class_ids)
    # Pass ``labels`` explicitly: without it sklearn infers the classes from the
    # data and raises ValueError when fewer than four of them appear in
    # y_test/y_pred, because the count no longer matches ``target_names``.
    metrics["classification_report"] = classification_report(
        y_test, y_pred, labels=class_ids,
        target_names=DEGRADATION_LABELS, zero_division=0,
    )
    return metrics