Spaces:
Runtime error
Runtime error
feat(prediction): add prediction model for l2 and improve report
Browse files- src/agents/l2_supervisor.py +56 -0
- src/controllers/prediction_controller.py +11 -10
- src/models/L1_Logistic_v0.joblib +2 -2
- src/services/report.py +39 -0
src/agents/l2_supervisor.py
CHANGED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from fastapi import HTTPException
|
| 3 |
+
|
| 4 |
+
class Supervisor:
    """Wrap a saved model bundle and expose a single ``predict`` entry point.

    The bundle is loaded with ``joblib`` and is expected to be a mapping that
    holds the estimator under ``"model"`` and, optionally, a fitted scaler
    under ``"scaler"``.

    Attributes:
        model: estimator taken from the bundle; must provide ``predict``.
        scaler: optional transformer taken from the bundle, or ``None``.
        features: ordered list of input column names fed to the scaler/model.
    """

    # Column names selected from the input, in the order they are handed to
    # the scaler and the model.
    DEFAULT_FEATURES = (
        "Header_Length",
        "Time_To_Live",
        "Rate",
        "Tot sum",
        "Tot size",
        "Min",
        "Max",
        "AVG",
        "Std",
        "Variance",
        "IAT",
        "Number",
        "syn_ratio",
        "ack_ratio",
        "fin_ratio",
        "rst_ratio",
        "mean_pkt_size",
        "pkt_size_range",
        "pkt_size_ratio",
        "mean_iat",
        "pkt_rate",
        "throughput",
        "bytes_per_sec",
        "coef_var",
        "tcp_udp_ratio",
        "flag_entropy",
    )

    def __init__(self, model_path: str = "src/models/L2_XGBOOST_Supervisor_V0.joblib"):
        """Load the model bundle stored at ``model_path``.

        Args:
            model_path: path of the joblib file holding the bundle.

        Raises:
            HTTPException: status 500 when the file is missing or when the
                bundle cannot be loaded/unpacked; the underlying error is
                chained as ``__cause__``.
        """
        try:
            # NOTE(review): joblib.load unpickles the file, which can execute
            # arbitrary code; model_path must only point at trusted artifacts.
            saved = joblib.load(model_path)
            self.model = saved["model"]
            # Some bundles are saved without a scaler.
            self.scaler = saved.get("scaler", None)
        except FileNotFoundError as e:
            raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}") from e

        # Per-instance copy so callers may adjust it without touching the
        # class-level default.
        self.features = list(self.DEFAULT_FEATURES)

    def predict(self, data):
        """Return the model's predictions for ``data`` as a plain list.

        Args:
            data: container indexable by the list ``self.features``
                (presumably a pandas DataFrame; verify against the caller).

        Returns:
            The result of ``model.predict`` converted with ``.tolist()``.

        Raises:
            HTTPException: status 500 wrapping any error raised while
                selecting features, scaling or predicting; the original
                exception is chained as ``__cause__``.
        """
        try:
            # Keep only the expected columns, in the expected order.
            X = data[self.features]

            # Apply the scaler only when the bundle provided one.
            if self.scaler is not None:
                X = self.scaler.transform(X)

            return self.model.predict(X).tolist()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}") from e
|
src/controllers/prediction_controller.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
| 1 |
from fastapi import HTTPException, UploadFile
|
| 2 |
from src.services.preprocess_features import preprocess_features
|
| 3 |
from src.services.file_verification import file_verification
|
|
|
|
|
|
|
| 4 |
from src.agents.l1_screener import Screener
|
|
|
|
| 5 |
|
| 6 |
def global_prediction_on_csv(file: UploadFile):
|
| 7 |
try:
|
|
@@ -15,18 +18,16 @@ def global_prediction_on_csv(file: UploadFile):
|
|
| 15 |
print("start loading model and predicting")
|
| 16 |
|
| 17 |
screener = Screener()
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
total_samples = len(predictions)
|
| 21 |
-
total_attacks = sum(predictions) # 1 = attaque
|
| 22 |
-
total_normal = total_samples - total_attacks # 0 = requête normale
|
| 23 |
return {
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
"total_normal_requests": total_normal,
|
| 28 |
-
"predictions": predictions
|
| 29 |
-
}
|
| 30 |
|
| 31 |
except HTTPException:
|
| 32 |
raise
|
|
|
|
| 1 |
from fastapi import HTTPException, UploadFile
|
| 2 |
from src.services.preprocess_features import preprocess_features
|
| 3 |
from src.services.file_verification import file_verification
|
| 4 |
+
from src.services.report import summarize_predictions
|
| 5 |
+
|
| 6 |
from src.agents.l1_screener import Screener
|
| 7 |
+
from src.agents.l2_supervisor import Supervisor
|
| 8 |
|
| 9 |
def global_prediction_on_csv(file: UploadFile):
|
| 10 |
try:
|
|
|
|
| 18 |
print("start loading model and predicting")
|
| 19 |
|
| 20 |
screener = Screener()
|
| 21 |
+
l1_summary = summarize_predictions(screener.predict, data)
|
| 22 |
+
|
| 23 |
+
supervisor = Supervisor()
|
| 24 |
+
l2_summary = summarize_predictions(supervisor.predict, data)
|
| 25 |
+
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
return {
|
| 28 |
+
"l1": l1_summary,
|
| 29 |
+
"l2": l2_summary
|
| 30 |
+
}
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
except HTTPException:
|
| 33 |
raise
|
src/models/L1_Logistic_v0.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a84679401d8c2c9fb18f98cac36aba8fe6638f5baee9c10584199adfe504c0b
|
| 3 |
+
size 2605
|
src/services/report.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import List, Dict, Callable
|
| 3 |
+
|
| 4 |
+
def summarize_predictions(pred_func: Callable, data) -> Dict:
    """Run a prediction function on ``data`` and summarize its binary output.

    Args:
        pred_func: callable that accepts ``data`` and returns a list or array
            of binary predictions (0 = normal, 1 = attack).
        data: input passed unchanged to ``pred_func``.

    Returns:
        dict with the keys ``total_samples``, ``total_attacks``,
        ``total_normal``, ``processing_time_sec``, ``avg_pred_time_sec``,
        ``attack_ratio``, ``normal_ratio`` and ``predictions``. The time and
        ratio values are ``0`` when there are no predictions.
    """
    # perf_counter is a monotonic, high-resolution clock; unlike time.time()
    # it is not affected by system clock adjustments during the call.
    start_time = time.perf_counter()
    predictions = pred_func(data)
    processing_time = time.perf_counter() - start_time

    # Array-like results are converted to a plain list so that the counts and
    # the returned predictions are built-in Python types (NumPy arrays and
    # scalars are not handled by the standard json encoder).
    if hasattr(predictions, "tolist"):
        predictions = predictions.tolist()

    total_samples = len(predictions)
    total_attacks = sum(predictions)  # label 1 = attack
    total_normal = total_samples - total_attacks

    if total_samples > 0:
        avg_pred_time = processing_time / total_samples
        attack_ratio = total_attacks / total_samples
        normal_ratio = total_normal / total_samples
    else:
        # Nothing was predicted: avoid dividing by zero.
        avg_pred_time = attack_ratio = normal_ratio = 0

    return {
        "total_samples": total_samples,
        "total_attacks": total_attacks,
        "total_normal": total_normal,
        "processing_time_sec": processing_time,
        "avg_pred_time_sec": avg_pred_time,
        "attack_ratio": attack_ratio,
        "normal_ratio": normal_ratio,
        "predictions": predictions,
    }
|