Mekam commited on
Commit
d1119e8
·
1 Parent(s): 7bc2636

feat(prediction): add single prediction

Browse files
app.py CHANGED
@@ -1,12 +1,11 @@
1
  from fastapi import FastAPI
2
  from src.routes import user_routes,organization_routes, prediction_routes
3
- from src.routes import user_routes
4
 
5
  app = FastAPI(title="NetSentinel Backend")
6
 
7
  # Register routers
8
- app.include_router(user_routes.router)
9
- app.include_router(organization_routes.router)
10
  app.include_router(prediction_routes.router)
11
 
12
 
 
1
  from fastapi import FastAPI
2
  from src.routes import user_routes,organization_routes, prediction_routes
 
3
 
4
  app = FastAPI(title="NetSentinel Backend")
5
 
6
  # Register routers
7
+ # app.include_router(user_routes.router)
8
+ # app.include_router(organization_routes.router)
9
  app.include_router(prediction_routes.router)
10
 
11
 
src/controllers/prediction_controller.py CHANGED
@@ -7,6 +7,15 @@ from src.agents.l1_screener import Screener
7
  from src.agents.l2_supervisor import Supervisor
8
  from src.agents.l3_classifier import Classifier
9
 
 
 
 
 
 
 
 
 
 
10
  def global_prediction_on_csv(file: UploadFile):
11
  try:
12
  # Vérifier l'extension et les colonnes du fichier
@@ -24,17 +33,40 @@ def global_prediction_on_csv(file: UploadFile):
24
  supervisor = Supervisor()
25
  l2_summary = summarize_predictions(supervisor.predict, data)
26
 
27
- classifier = Classifier()
28
- l3_summary = summarize_predictions(classifier.predict, data)
29
 
30
 
31
  return {
32
  "l1": l1_summary,
33
  "l2": l2_summary,
34
- "l3": l3_summary
35
  }
36
 
37
  except HTTPException:
38
  raise
39
  except Exception as e:
40
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from src.agents.l2_supervisor import Supervisor
8
  from src.agents.l3_classifier import Classifier
9
 
10
+ required_columns = [
11
+ "Header_Length", "Protocol Type", "Time_To_Live", "Rate",
12
+ "fin_flag_number", "syn_flag_number", "rst_flag_number",
13
+ "psh_flag_number", "ack_flag_number", "ece_flag_number",
14
+ "cwr_flag_number", "ack_count", "syn_count", "fin_count",
15
+ "rst_count", "TCP", "UDP", "Tot sum", "Min", "Max", "AVG",
16
+ "Std", "Tot size", "IAT", "Number", "Variance"
17
+ ]
18
+
19
  def global_prediction_on_csv(file: UploadFile):
20
  try:
21
  # Vérifier l'extension et les colonnes du fichier
 
33
  supervisor = Supervisor()
34
  l2_summary = summarize_predictions(supervisor.predict, data)
35
 
36
+ # classifier = Classifier()
37
+ # l3_summary = summarize_predictions(classifier.predict, data)
38
 
39
 
40
  return {
41
  "l1": l1_summary,
42
  "l2": l2_summary,
43
+ # "l3": l3_summary
44
  }
45
 
46
  except HTTPException:
47
  raise
48
  except Exception as e:
49
  raise HTTPException(status_code=500, detail=str(e))
50
+
51
+ def single_prediction(data: dict):
52
+ try:
53
+ # log dans la console
54
+ if not isinstance(data, dict):
55
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
56
+
57
+ missing_columns = [col for col in required_columns if col not in data]
58
+ if missing_columns:
59
+ raise HTTPException(
60
+ status_code=422,
61
+ detail=f"Missing required columns: {missing_columns}"
62
+ )
63
+
64
+ # Ici tu peux faire la prédiction avec data
65
+ print("Received data for single prediction:", data)
66
+
67
+ return {"message": "All required columns present", "to_do": "Not yet implemented"}
68
+ except HTTPException:
69
+ raise
70
+ except Exception as e:
71
+ raise HTTPException(status_code=500, detail=str(e))
72
+
src/routes/prediction_routes.py CHANGED
@@ -9,3 +9,8 @@ def predict_from_csv(file: UploadFile = File(...)):
9
  Route pour uploader un CSV, le prétraiter et obtenir les prédictions.
10
  """
11
  return prediction_controller.global_prediction_on_csv(file)
 
 
 
 
 
 
9
  Route pour uploader un CSV, le prétraiter et obtenir les prédictions.
10
  """
11
  return prediction_controller.global_prediction_on_csv(file)
12
+
13
+ @router.post("/single-prediction")
14
+ def list_users():
15
+ return prediction_controller.single_prediction()
16
+