Mekam committed on
Commit
724b321
·
1 Parent(s): 2b11736

feat(prediction): add prediction model for l2 and improve report

Browse files
src/agents/l2_supervisor.py CHANGED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ from fastapi import HTTPException
3
+
4
class Supervisor:
    """L2 supervisor agent.

    Loads a pre-trained XGBoost model bundle from disk and scores
    network-flow feature rows as attack (1) or normal (0).
    """

    # Column names the model was trained on, in training order.
    # Class-level constant so the list is not rebuilt per instance.
    FEATURES = [
        "Header_Length",
        "Time_To_Live",
        "Rate",
        "Tot sum",
        "Tot size",
        "Min",
        "Max",
        "AVG",
        "Std",
        "Variance",
        "IAT",
        "Number",
        "syn_ratio",
        "ack_ratio",
        "fin_ratio",
        "rst_ratio",
        "mean_pkt_size",
        "pkt_size_range",
        "pkt_size_ratio",
        "mean_iat",
        "pkt_rate",
        "throughput",
        "bytes_per_sec",
        "coef_var",
        "tcp_udp_ratio",
        "flag_entropy",
    ]

    def __init__(self, model_path: str = "src/models/L2_XGBOOST_Supervisor_V0.joblib"):
        """Load the serialized model bundle from *model_path*.

        The bundle is a dict with a mandatory ``"model"`` entry and an
        optional ``"scaler"`` entry.

        Raises:
            HTTPException: 500 if the file is missing or loading fails
                for any other reason.
        """
        try:
            saved = joblib.load(model_path)
            self.model = saved["model"]
            # Some bundles are saved without a scaler; predict() handles both.
            self.scaler = saved.get("scaler", None)
            # Kept as an instance attribute for backward compatibility.
            self.features = list(self.FEATURES)
        except FileNotFoundError:
            raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")

    def predict(self, data):
        """Score *data* and return binary predictions as a plain list.

        Args:
            data: pandas-DataFrame-like object containing at least the
                columns listed in ``self.features``.

        Returns:
            list: one 0/1 prediction per input row.

        Raises:
            HTTPException: 500 on any feature-selection, scaling, or
                model error.
        """
        try:
            # Select the training columns in the exact training order.
            X = data[self.features]

            # Apply the scaler only when the bundle includes one.
            if self.scaler is not None:
                X = self.scaler.transform(X)

            preds = self.model.predict(X)
            return preds.tolist()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
src/controllers/prediction_controller.py CHANGED
@@ -1,7 +1,10 @@
1
  from fastapi import HTTPException, UploadFile
2
  from src.services.preprocess_features import preprocess_features
3
  from src.services.file_verification import file_verification
 
 
4
  from src.agents.l1_screener import Screener
 
5
 
6
  def global_prediction_on_csv(file: UploadFile):
7
  try:
@@ -15,18 +18,16 @@ def global_prediction_on_csv(file: UploadFile):
15
  print("start loading model and predicting")
16
 
17
  screener = Screener()
18
- predictions = screener.predict(data)
 
 
 
 
19
 
20
- total_samples = len(predictions)
21
- total_attacks = sum(predictions) # 1 = attaque
22
- total_normal = total_samples - total_attacks # 0 = requête normale
23
  return {
24
- "message": "Prédictions effectuées avec succès",
25
- "total_samples": len(predictions),
26
- "total_attacks": total_attacks,
27
- "total_normal_requests": total_normal,
28
- "predictions": predictions
29
- }
30
 
31
  except HTTPException:
32
  raise
 
1
  from fastapi import HTTPException, UploadFile
2
  from src.services.preprocess_features import preprocess_features
3
  from src.services.file_verification import file_verification
4
+ from src.services.report import summarize_predictions
5
+
6
  from src.agents.l1_screener import Screener
7
+ from src.agents.l2_supervisor import Supervisor
8
 
9
  def global_prediction_on_csv(file: UploadFile):
10
  try:
 
18
  print("start loading model and predicting")
19
 
20
  screener = Screener()
21
+ l1_summary = summarize_predictions(screener.predict, data)
22
+
23
+ supervisor = Supervisor()
24
+ l2_summary = summarize_predictions(supervisor.predict, data)
25
+
26
 
 
 
 
27
  return {
28
+ "l1": l1_summary,
29
+ "l2": l2_summary
30
+ }
 
 
 
31
 
32
  except HTTPException:
33
  raise
src/models/L1_Logistic_v0.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:79d5f16e998f61fdf1faff94e137865437cf932fc6a1fa5c6b01fa3cf4302c47
3
- size 2617
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a84679401d8c2c9fb18f98cac36aba8fe6638f5baee9c10584199adfe504c0b
3
+ size 2605
src/services/report.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import List, Dict, Callable
3
+
4
def summarize_predictions(pred_func: Callable, data) -> Dict:
    """Run a prediction function over *data* and summarize the results.

    Args:
        pred_func: prediction callable; takes ``data`` and returns a
            list/array of binary predictions (1 = attack, 0 = normal).
        data: input passed straight through to ``pred_func``.

    Returns:
        dict with: total_samples, total_attacks, total_normal,
        processing_time_sec, avg_pred_time_sec, attack_ratio,
        normal_ratio, and the raw predictions.
    """
    # perf_counter() is monotonic: the measured interval cannot go
    # negative or jump if the wall clock is adjusted mid-run, unlike
    # time.time().
    start = time.perf_counter()
    predictions = pred_func(data)
    elapsed = time.perf_counter() - start

    total_samples = len(predictions)
    total_attacks = sum(predictions)  # 1 = attack
    total_normal = total_samples - total_attacks

    # Guard every per-sample ratio against an empty prediction set.
    avg_pred_time = elapsed / total_samples if total_samples > 0 else 0
    attack_ratio = total_attacks / total_samples if total_samples > 0 else 0
    normal_ratio = total_normal / total_samples if total_samples > 0 else 0

    return {
        "total_samples": total_samples,
        "total_attacks": total_attacks,
        "total_normal": total_normal,
        "processing_time_sec": elapsed,
        "avg_pred_time_sec": avg_pred_time,
        "attack_ratio": attack_ratio,
        "normal_ratio": normal_ratio,
        "predictions": predictions,
    }