File size: 1,882 Bytes
724b321
 
 
 
64ec546
724b321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import joblib
from fastapi import HTTPException

class Supervisor:
    """Wrap a persisted L2 classifier (plus an optional scaler) for inference.

    The artifact loaded from disk is expected to be a mapping with a mandatory
    ``"model"`` entry and an optional ``"scaler"`` entry. Failures are surfaced
    as FastAPI ``HTTPException`` with status 500.
    """

    # Default ordered input columns fed to the model. Kept as an immutable
    # class-level tuple so it is built once and can be inspected without
    # loading a model artifact.
    FEATURES = (
        "Header_Length",
        "Time_To_Live",
        "Rate",
        "Tot sum",
        "Tot size",
        "Min",
        "Max",
        "AVG",
        "Std",
        "Variance",
        "IAT",
        "Number",
        "syn_ratio",
        "ack_ratio",
        "fin_ratio",
        "rst_ratio",
        "mean_pkt_size",
        "pkt_size_range",
        "pkt_size_ratio",
        "mean_iat",
        "pkt_rate",
        "throughput",
        "bytes_per_sec",
        "coef_var",
        "tcp_udp_ratio",
        "flag_entropy",
    )

    def __init__(
        self,
        model_path: str = "src/models/l2_XGBOOST_Supervisor_V0.joblib",
        features=None,
    ):
        """Load the model (and optional scaler) from *model_path*.

        Args:
            model_path: Path to the ``joblib`` artifact to load.
            features: Optional ordered sequence of column names to select from
                the input at prediction time. Defaults to ``FEATURES``; lets
                artifacts trained on a different column set reuse this class.

        Raises:
            HTTPException: status 500 if the file does not exist, or if it
                cannot be loaded or lacks a ``"model"`` entry.
        """
        try:
            saved = joblib.load(model_path)
            self.model = saved["model"]
            # Some models may have been saved without a scaler.
            self.scaler = saved.get("scaler", None)
        except FileNotFoundError as e:
            raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}") from e

        # Stored as a list (not a tuple): pandas treats a tuple key as a single
        # label rather than a selection of several columns.
        self.features = list(self.FEATURES if features is None else features)

    def predict(self, data):
        """Return the model's predictions for *data* as a plain Python list.

        Args:
            data: Input indexable by a list of column names. Presumably a
                pandas DataFrame (``data[self.features]`` selects the columns);
                confirm against callers.

        Returns:
            The result of ``self.model.predict(...)`` converted with ``.tolist()``.

        Raises:
            HTTPException: status 500 wrapping any error raised while selecting
                features, scaling, predicting or converting the result.
        """
        try:
            # Select (and order) the columns the model was trained on.
            X = data[self.features]

            # Apply the scaler only if one was saved alongside the model.
            if self.scaler is not None:
                X = self.scaler.transform(X)

            preds = self.model.predict(X)
            return preds.tolist()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}") from e