File size: 2,417 Bytes
a65325a
70201a5
 
724b321
 
70201a5
724b321
5e89f69
a65325a
d1119e8
 
 
 
 
 
 
 
 
a65325a
 
70201a5
 
a65325a
70201a5
 
a65325a
 
 
 
70201a5
724b321
 
 
 
 
d1119e8
 
5e89f69
a65325a
 
724b321
5e89f69
d1119e8
724b321
a65325a
 
 
 
 
d1119e8
0fa688c
d1119e8
 
 
 
 
 
 
 
 
 
 
0fa688c
d1119e8
 
0fa688c
d1119e8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from fastapi import HTTPException, UploadFile
from src.services.preprocess_features import preprocess_features
from src.services.file_verification import file_verification
from src.services.report import summarize_predictions

from src.agents.l1_screener import Screener
from src.agents.l2_supervisor import Supervisor
from src.agents.l3_classifier import Classifier

# Feature names that a single-prediction payload must contain
# (checked by key membership in single_prediction_controller).
required_columns = [
    # Packet header fields
    "Header_Length",
    "Protocol Type",
    "Time_To_Live",
    "Rate",
    # TCP flag indicators
    "fin_flag_number",
    "syn_flag_number",
    "rst_flag_number",
    "psh_flag_number",
    "ack_flag_number",
    "ece_flag_number",
    "cwr_flag_number",
    # Flag counters
    "ack_count",
    "syn_count",
    "fin_count",
    "rst_count",
    # Protocol indicators
    "TCP",
    "UDP",
    # Size / timing statistics
    "Tot sum",
    "Min",
    "Max",
    "AVG",
    "Std",
    "Tot size",
    "IAT",
    "Number",
    "Variance",
]

def global_prediction_on_csv(file: UploadFile):
    """Run the L1 and L2 agents over an uploaded CSV and summarize the results.

    The upload is checked by ``file_verification`` (extension and columns),
    passed through ``preprocess_features``, and then each agent's
    ``predict`` is summarized with ``summarize_predictions``.

    Args:
        file: The uploaded CSV file.

    Returns:
        A dict with key ``"l1"`` (Screener summary) and key ``"l2"``
        (Supervisor summary).

    Raises:
        HTTPException: Any ``HTTPException`` raised by the helpers is
            propagated unchanged. Any other exception is wrapped in a 500
            response whose detail is the exception message, or the
            exception type name when the message is empty. The original
            exception is chained as ``__cause__``.
    """
    try:
        # Check the file extension and columns.
        correct_csv = file_verification(file)

        # Increase the number of features.
        data = preprocess_features(correct_csv)

        print("start loading model and predicting")

        # NOTE(review): both agents are constructed on every request. If
        # construction loads model weights (as the log line suggests),
        # consider caching the instances; confirm in the agent classes.
        screener = Screener()
        l1_summary = summarize_predictions(screener.predict, data)

        supervisor = Supervisor()
        l2_summary = summarize_predictions(supervisor.predict, data)

        # TODO: the L3 stage is currently disabled. To enable it, build
        # ``Classifier()``, summarize ``classifier.predict`` over ``data``
        # the same way as above, and add the result under key "l3".

        return {
            "l1": l1_summary,
            "l2": l2_summary,
        }

    except HTTPException:
        # Already an HTTP error (e.g. from file validation): propagate as-is.
        raise
    except Exception as e:
        # Chain explicitly so server-side tracebacks show the real cause, and
        # never send an empty detail (e.g. for ``ValueError()``).
        raise HTTPException(
            status_code=500, detail=str(e) or type(e).__name__
        ) from e

def single_prediction_controller(data: dict):
    """Validate a single-record payload intended for prediction.

    Prediction itself is not implemented yet. After validation, the payload
    is printed and a placeholder response is returned.

    Args:
        data: Parsed request body. It must be a dict containing every name
            listed in ``required_columns``.

    Returns:
        A placeholder dict with ``"message"`` and ``"to_do"`` keys.

    Raises:
        HTTPException: 400 if ``data`` is not a dict; 422 listing the
            missing columns; 500 wrapping any other unexpected error. The
            500 detail is the error message, or the type name if the
            message is empty, and the original error is chained as
            ``__cause__``.
    """
    try:
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        # Preserve the order of required_columns in the reported list.
        missing_columns = [col for col in required_columns if col not in data]
        if missing_columns:
            raise HTTPException(
                status_code=422,
                detail=f"Missing required columns: {missing_columns}",
            )

        # TODO: run the prediction here.
        print("Received data for single prediction:", data)
        return {"message": "All required columns present", "to_do": "Not yet implemented"}

    except HTTPException:
        # Validation errors above are already HTTP errors: propagate as-is.
        raise
    except Exception as e:
        # Chain explicitly so the real cause is kept, and never send an
        # empty detail to the client.
        raise HTTPException(
            status_code=500, detail=str(e) or type(e).__name__
        ) from e