Mekam committed on
Commit
70201a5
·
1 Parent(s): 66907e4

refactor(prediction): refactor the controller code

Browse files
src/agents/l1_screener.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ from fastapi import HTTPException
3
+
4
class Screener:
    """Level-1 traffic screener.

    Loads a persisted joblib bundle containing a fitted model, its scaler
    and (optionally) a PCA transform, then predicts a label for each
    incoming feature row. Judging by the caller's aggregation, 1 = attack
    and 0 = normal request — NOTE(review): confirm against the training
    pipeline.
    """

    def __init__(self, model_path: str = "src/models/L1_Logistic_v0.joblib"):
        """Load the serialized model bundle from ``model_path``.

        Raises:
            HTTPException(500): if the bundle file is missing or any other
                error occurs while loading/unpacking it.
        """
        try:
            saved = joblib.load(model_path)
            self.model = saved["model"]
            self.scaler = saved["scaler"]
            # Columns the model was trained on, in training order.
            self.features = [
                "Header_Length",
                "Time_To_Live",
                "Rate",
                "Tot sum",
                "Tot size",
                "Min",
                "Max",
                "AVG",
                "Std",
                "Variance",
                "IAT",
                "Number"
            ]
            # PCA is optional in the bundle; None means "not used".
            self.pca = saved.get("pca")
        except FileNotFoundError:
            raise HTTPException(status_code=500, detail=f"Modèle '{model_path}' introuvable")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors du chargement du modèle: {e}")

    def predict(self, data):
        """Return model predictions for ``data`` as a plain Python list.

        Args:
            data: DataFrame containing at least the columns listed in
                ``self.features``; any extra columns (e.g. a "label"
                column) are simply ignored by the selection below.

        Raises:
            HTTPException(500): on any preprocessing or prediction error.
        """
        try:
            # Select the training columns. (The previous version branched
            # on the presence of "label", but both branches were
            # identical, so the check was dead code.)
            X = data[self.features]

            # Apply the same standardisation used at training time.
            X = self.scaler.transform(X)

            # Apply PCA only if the bundle was trained with one.
            if self.pca is not None:
                X = self.pca.transform(X)

            preds = self.model.predict(X)
            return preds.tolist()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de la prédiction: {e}")
src/agents/l2_supervisor.py ADDED
File without changes
src/agents/l3_classifier.py ADDED
File without changes
src/controllers/prediction_controller.py CHANGED
@@ -1,133 +1,31 @@
1
- import pandas as pd
2
- import numpy as np
3
- import joblib
4
- import io
5
  from fastapi import HTTPException, UploadFile
6
-
7
- # Colonnes requises
8
- REQUIRED_COLUMNS = [
9
- "Header_Length", "Protocol Type", "Time_To_Live", "Rate",
10
- "fin_flag_number", "syn_flag_number", "rst_flag_number",
11
- "psh_flag_number", "ack_flag_number", "ece_flag_number",
12
- "cwr_flag_number", "ack_count", "syn_count", "fin_count",
13
- "rst_count", "TCP", "UDP", "Tot sum", "Min", "Max", "AVG",
14
- "Std", "Tot size", "IAT", "Number", "Variance"
15
- ]
16
 
17
  def global_prediction_on_csv(file: UploadFile):
18
  try:
19
- # Vérifier l'extension
20
- print("start verifying file extension")
21
- if not file.filename.endswith(".csv"):
22
- raise HTTPException(status_code=400, detail="Le fichier doit être un CSV")
23
-
24
- print("Correct file extension")
25
-
26
- # Lecture du CSV depuis le contenu binaire
27
- contents = file.file.read()
28
- data = pd.read_csv(io.BytesIO(contents))
29
-
30
- print("start verifying required columns")
31
- # Vérification des colonnes
32
- missing_cols = [col for col in REQUIRED_COLUMNS if col not in data.columns]
33
- if missing_cols:
34
- raise HTTPException(
35
- status_code=400,
36
- detail=f"Colonnes manquantes dans le CSV: {', '.join(missing_cols)}"
37
- )
38
-
39
- print("start data preprocessing")
40
- # Nettoyage
41
- data = data.replace([np.inf, -np.inf], np.nan).dropna()
42
-
43
- eps = 1e-6
44
- number = data["Number"] + eps
45
- iat = data["IAT"] + eps
46
- min_val = data["Min"] + 1
47
-
48
- # Ratios de flags
49
- data["syn_ratio"] = data["syn_count"] / number
50
- data["ack_ratio"] = data["ack_count"] / number
51
- data["fin_ratio"] = data["fin_count"] / number
52
- data["rst_ratio"] = data["rst_count"] / number
53
-
54
- # Tailles de paquets
55
- data["mean_pkt_size"] = data["Tot size"] / number
56
- data["pkt_size_range"] = data["Max"] - data["Min"]
57
- data["pkt_size_ratio"] = data["Max"] / min_val
58
 
59
- # Timing
60
- data["mean_iat"] = data["IAT"] / number
61
- data["pkt_rate"] = data["Number"] / iat
62
 
63
- # Débit
64
- data["throughput"] = data["Tot size"] / iat
65
- data["bytes_per_sec"] = data["Rate"] * data["Tot size"]
66
-
67
- # Variation
68
- data["coef_var"] = data["Std"] / (data["AVG"] + eps)
69
-
70
- # Indicateurs combinés
71
- data["tcp_udp_ratio"] = data["TCP"] / (data["UDP"] + 1)
72
- data["flag_entropy"] = (
73
- data["syn_ratio"] + data["ack_ratio"] + data["fin_ratio"] + data["rst_ratio"]
74
- )
75
- print("finish data preprocessing")
76
 
77
  print("start loading model and predicting")
78
 
79
- # Chargement du modèle
80
- try:
81
- saved = joblib.load("src/models/L1_Logistic_v0.joblib")
82
- model = saved["model"]
83
- scaler = saved["scaler"]
84
- pca = saved.get("pca", None)
85
- # features = saved["features"]
86
- except FileNotFoundError:
87
- raise HTTPException(status_code=500, detail="Modèle 'L1_Logistic_v0.joblib' introuvable")
88
- # try:
89
- # model = joblib.load("src/models/L1_Logistic_v0.joblib")
90
- # except FileNotFoundError:
91
- # raise HTTPException(status_code=500, detail="Modèle 'L1_Logistic_v0.joblib' introuvable")
92
-
93
- print("drop label column if exists")
94
-
95
- # Préparation des features
96
- features = [
97
- "Header_Length",
98
- "Time_To_Live",
99
- "Rate",
100
- "Tot sum",
101
- "Tot size",
102
- "Min",
103
- "Max",
104
- "AVG",
105
- "Std",
106
- "Variance",
107
- "IAT",
108
- "Number"
109
- ]
110
-
111
- if "label" in data.columns:
112
- X = data[features]
113
- else:
114
- X = data[features]
115
-
116
- print("features prepared, starting prediction")
117
-
118
- # Standardisation
119
- X = scaler.transform(X)
120
-
121
- # PCA si utilisé
122
- if pca is not None:
123
- X = pca.transform(X)
124
- # Prédictions
125
- preds = model.predict(X)
126
 
 
 
 
127
  return {
128
  "message": "Prédictions effectuées avec succès",
129
- "total_samples": len(preds),
130
- "predictions": preds.tolist()
 
 
131
  }
132
 
133
  except HTTPException:
 
 
 
 
 
1
  from fastapi import HTTPException, UploadFile
2
+ from src.services.preprocess_features import preprocess_features
3
+ from src.services.file_verification import file_verification
4
+ from src.agents.l1_screener import Screener
 
 
 
 
 
 
 
5
 
6
  def global_prediction_on_csv(file: UploadFile):
7
  try:
8
+ # Vérifier l'extension et les colonnes du fichier
9
+ correct_csv = file_verification(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # augmenter le nombre de features
12
+ data = preprocess_features(correct_csv)
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  print("start loading model and predicting")
16
 
17
+ screener = Screener()
18
+ predictions = screener.predict(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ total_samples = len(predictions)
21
+ total_attacks = sum(predictions) # 1 = attaque
22
+ total_normal = total_samples - total_attacks # 0 = requête normale
23
  return {
24
  "message": "Prédictions effectuées avec succès",
25
+ "total_samples": len(predictions),
26
+ "total_attacks": total_attacks,
27
+ "total_normal_requests": total_normal,
28
+ "predictions": predictions.tolist()
29
  }
30
 
31
  except HTTPException:
src/services/file_verification.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, UploadFile
2
+ import pandas as pd
3
+ import io
4
+
5
+
6
# Columns that every uploaded CSV must contain before prediction.
required_columns = [
    "Header_Length", "Protocol Type", "Time_To_Live", "Rate",
    "fin_flag_number", "syn_flag_number", "rst_flag_number",
    "psh_flag_number", "ack_flag_number", "ece_flag_number",
    "cwr_flag_number", "ack_count", "syn_count", "fin_count",
    "rst_count", "TCP", "UDP", "Tot sum", "Min", "Max", "AVG",
    "Std", "Tot size", "IAT", "Number", "Variance"
]

def file_verification(file: UploadFile) -> pd.DataFrame:
    """Validate an uploaded CSV file and return its contents as a DataFrame.

    Checks the ".csv" extension first, then parses the payload and
    verifies that every required column is present.

    Raises:
        HTTPException(400): wrong extension or missing columns.
        HTTPException(500): unreadable or unparseable file content.
    """
    try:
        print("start verifying file extension")
        if not file.filename.endswith(".csv"):
            raise HTTPException(status_code=400, detail="Le fichier doit être un CSV")

        print("Correct file extension")

        # Parse the raw upload bytes into a DataFrame.
        raw_bytes = file.file.read()
        frame = pd.read_csv(io.BytesIO(raw_bytes))

        print("start verifying required columns")

        # Set membership for O(1) lookups; iterate required_columns so
        # the error message preserves the canonical column order.
        present = set(frame.columns)
        missing_cols = [name for name in required_columns if name not in present]
        if missing_cols:
            raise HTTPException(
                status_code=400,
                detail=f"Colonnes manquantes dans le CSV: {', '.join(missing_cols)}"
            )

        print("file verification successful")
        return frame

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur de lecture du fichier: {e}")
src/services/preprocess_features.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+
5
def preprocess_features(data: pd.DataFrame) -> pd.DataFrame:
    """Enrich the raw CSV features with derived ratios, rates and
    dispersion indicators used downstream for prediction.

    Rows containing NaN or +/-inf in any column are dropped first.
    Returns the enriched DataFrame (a copy; the caller's frame is not
    modified in place).
    """

    print("start data preprocessing")
    # Turn infinities into NaN, then discard any row with a NaN.
    data = data.replace([np.inf, -np.inf], np.nan).dropna()

    smooth = 1e-6  # guards the divisions below against zero denominators
    pkt_count = data["Number"] + smooth
    safe_iat = data["IAT"] + smooth
    shifted_min = data["Min"] + 1

    # TCP flag ratios, normalised by packet count.
    for flag in ("syn", "ack", "fin", "rst"):
        data[f"{flag}_ratio"] = data[f"{flag}_count"] / pkt_count

    # Packet-size statistics.
    data["mean_pkt_size"] = data["Tot size"] / pkt_count
    data["pkt_size_range"] = data["Max"] - data["Min"]
    data["pkt_size_ratio"] = data["Max"] / shifted_min

    # Inter-arrival timing.
    data["mean_iat"] = data["IAT"] / pkt_count
    data["pkt_rate"] = data["Number"] / safe_iat

    # Throughput figures.
    data["throughput"] = data["Tot size"] / safe_iat
    data["bytes_per_sec"] = data["Rate"] * data["Tot size"]

    # Dispersion: coefficient of variation.
    data["coef_var"] = data["Std"] / (data["AVG"] + smooth)

    # Combined protocol / flag indicators.
    data["tcp_udp_ratio"] = data["TCP"] / (data["UDP"] + 1)
    data["flag_entropy"] = sum(
        data[f"{flag}_ratio"] for flag in ("syn", "ack", "fin", "rst")
    )

    print("finish data preprocessing")

    return data