Spaces:
Runtime error
Runtime error
feat(supervisor): all
Browse files
src/agents/l3_classifier.py
CHANGED
|
@@ -2,48 +2,55 @@ import joblib
|
|
| 2 |
from fastapi import HTTPException
|
| 3 |
|
| 4 |
class Classifier:
|
| 5 |
-
def __init__(self, model_path: str = "src/models/
|
| 6 |
try:
|
| 7 |
saved = joblib.load(model_path)
|
| 8 |
self.model = saved["model"]
|
| 9 |
-
self.scaler = saved.get("scaler", None)
|
| 10 |
-
self.features =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
except FileNotFoundError:
|
| 12 |
raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable")
|
| 13 |
except Exception as e:
|
| 14 |
raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
# Sélectionner uniquement les features utilisées lors de l'entraînement
|
| 22 |
-
X = data[self.features]
|
| 23 |
-
|
| 24 |
-
# Standardisation si scaler présent
|
| 25 |
-
if self.scaler is not None:
|
| 26 |
-
X = self.scaler.transform(X)
|
| 27 |
-
|
| 28 |
-
# Prédictions
|
| 29 |
-
preds = self.model.predict(X)
|
| 30 |
-
return preds.tolist()
|
| 31 |
-
except KeyError as e:
|
| 32 |
-
raise HTTPException(status_code=500, detail=f"Colonne manquante dans les données: {e}")
|
| 33 |
-
except Exception as e:
|
| 34 |
-
raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
|
| 35 |
-
|
| 36 |
-
# def predict(self, data):
|
| 37 |
-
# try:
|
| 38 |
-
# # Préparer les features
|
| 39 |
-
# X = data[self.features]
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 2 |
from fastapi import HTTPException
|
| 3 |
|
| 4 |
class Classifier:
|
| 5 |
+
def __init__(self, model_path: str = "src/models/l2_XGBOOST_Supervisor_V5.joblib"):
|
| 6 |
try:
|
| 7 |
saved = joblib.load(model_path)
|
| 8 |
self.model = saved["model"]
|
| 9 |
+
self.scaler = saved.get("scaler", None) # certains modèles peuvent ne pas avoir de scaler
|
| 10 |
+
self.features = [
|
| 11 |
+
"Header_Length",
|
| 12 |
+
"Time_To_Live",
|
| 13 |
+
"Rate",
|
| 14 |
+
"Tot sum",
|
| 15 |
+
"Tot size",
|
| 16 |
+
"Min",
|
| 17 |
+
"Max",
|
| 18 |
+
"AVG",
|
| 19 |
+
"Std",
|
| 20 |
+
"Variance",
|
| 21 |
+
"IAT",
|
| 22 |
+
"Number",
|
| 23 |
+
"syn_ratio",
|
| 24 |
+
"ack_ratio",
|
| 25 |
+
"fin_ratio",
|
| 26 |
+
"rst_ratio",
|
| 27 |
+
"mean_pkt_size",
|
| 28 |
+
"pkt_size_range",
|
| 29 |
+
"pkt_size_ratio",
|
| 30 |
+
"mean_iat",
|
| 31 |
+
"pkt_rate",
|
| 32 |
+
"throughput",
|
| 33 |
+
"bytes_per_sec",
|
| 34 |
+
"coef_var",
|
| 35 |
+
"tcp_udp_ratio",
|
| 36 |
+
"flag_entropy",
|
| 37 |
+
]
|
| 38 |
except FileNotFoundError:
|
| 39 |
raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable")
|
| 40 |
except Exception as e:
|
| 41 |
raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")
|
| 42 |
|
| 43 |
+
def predict(self, data):
|
| 44 |
+
try:
|
| 45 |
+
# Préparer les features
|
| 46 |
+
X = data[self.features]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
# Appliquer le scaler si existant
|
| 49 |
+
if self.scaler is not None:
|
| 50 |
+
X = self.scaler.transform(X)
|
| 51 |
|
| 52 |
+
# Prédictions
|
| 53 |
+
preds = self.model.predict(X)
|
| 54 |
+
return preds.tolist()
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
|
src/models/{new_l3.joblib → l2_XGBOOST_Supervisor_V5.joblib}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8caea236ddbefbb6c23910847204d2020a5be9209ce52be0da4f00ad2e49239
|
| 3 |
+
size 8368
|
src/services/report.py
CHANGED
|
@@ -37,3 +37,4 @@ def summarize_predictions(pred_func: Callable, data) -> Dict:
|
|
| 37 |
"normal_ratio": normal_ratio,
|
| 38 |
"predictions": predictions
|
| 39 |
}
|
|
|
|
|
|
| 37 |
"normal_ratio": normal_ratio,
|
| 38 |
"predictions": predictions
|
| 39 |
}
|
| 40 |
+
|