hieu3636 commited on
Commit
b55c69c
·
verified ·
1 Parent(s): 74592ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -82
app.py CHANGED
@@ -4,101 +4,49 @@ import numpy as np
4
  import joblib
5
  import tensorflow as tf
6
 
7
- # =========================
8
  # LOAD MODEL & SCALER
9
- # =========================
10
  model = tf.keras.models.load_model("mlp_malware.keras")
11
  scaler = joblib.load("scaler.pkl")
12
 
13
- # =========================
14
- # 30 SELECTED FEATURES
15
 
16
- SELECTED_FEATURES = [
17
- "filesize",
18
- "E_file",
19
- "E_text",
20
- "E_data",
21
- "AddressOfEntryPoint",
22
- "NumberOfSections",
23
- "SizeOfInitializedData",
24
- "SizeOfImage",
25
- "SizeOfOptionalHeader",
26
- "SizeOfCode",
27
- "DirectoryEntryImportSize",
28
- "ImageBase",
29
- "CheckSum",
30
- "Magic",
31
- "MinorLinkerVersion",
32
- "MajorSubsystemVersion",
33
- "e_lfanew",
34
- "sus_sections",
35
- "PointerToSymbolTable",
36
- "SectionsLength",
37
- "SizeOfStackReserve",
38
- "MajorOperatingSystemVersion",
39
- "non_sus_sections",
40
- "Characteristics",
41
- "NumberOfSymbols",
42
- "BaseOfData",
43
- "MajorImageVersion",
44
- "FH_char5",
45
- "FH_char8",
46
- "OH_DLLchar5"
47
- ]
48
 
49
- N_FEATURES = len(SELECTED_FEATURES)
50
 
51
- # =========================
52
  # PREDICTION FUNCTION
53
- # =========================
54
- def predict_csv(file):
55
- df = pd.read_csv(file)
56
-
57
- # Drop label columns if exist
58
- df = df.drop(columns=["Label", "label", "class", "Class"], errors="ignore")
59
-
60
- # Check missing features
61
- missing_features = [f for f in SELECTED_FEATURES if f not in df.columns]
62
- if missing_features:
63
- return (
64
- f"Missing required features: {missing_features}"
65
- )
66
-
67
- # Keep only selected features & correct order
68
- feature_df = df[SELECTED_FEATURES].copy()
69
 
70
- # Convert to float
71
- X = feature_df.values.astype(float)
72
-
73
- # Scale
74
  X_scaled = scaler.transform(X)
75
 
76
- # Predict
77
- probs = model.predict(X_scaled).reshape(-1)
78
- preds = (probs > 0.5).astype(int)
 
 
 
79
 
80
- # Build output dataframe
81
- result = df.copy()
82
- result.insert(0, "row_id", range(1, len(df) + 1))
83
- result["probability_malware"] = probs
84
- result["prediction"] = preds
85
- result["prediction_label"] = result["prediction"].map(
86
- {1: "malware", 0: "benign"}
87
- )
88
 
89
- return result
 
 
 
90
 
91
- # =========================
92
- # GRADIO INTERFACE
93
- # =========================
94
- demo = gr.Interface(
95
- fn=predict_csv,
96
- inputs=gr.File(label="Upload CSV file"),
97
- outputs=gr.Dataframe(label="Prediction Result"),
98
- title="Malware Detection",
99
- description=(
100
- "Upload a CSV file containing PE features. "
101
- )
102
  )
103
 
104
- demo.launch()
 
 
4
  import joblib
5
  import tensorflow as tf
6
 
7
+
8
  # LOAD MODEL & SCALER
 
9
  model = tf.keras.models.load_model("mlp_malware.keras")
10
  scaler = joblib.load("scaler.pkl")
11
 
12
+ N_FEATURES = model.input_shape[1]
 
13
 
14
+ feature_names = [f"feature_{i+1}" for i in range(N_FEATURES)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
16
 
 
17
  # PREDICTION FUNCTION
18
+ def predict_malware(*inputs):
19
+ # inputs → DataFrame
20
+ X = pd.DataFrame([inputs], columns=feature_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # scale
 
 
 
23
  X_scaled = scaler.transform(X)
24
 
25
+ # predict
26
+ prob = model.predict(X_scaled, verbose=0)[0][0]
27
+ pred = int(prob >= 0.5)
28
+
29
+ label = "Malware" if pred == 1 else "Benign"
30
+ return label, float(prob)
31
 
32
+ # UI
33
+ inputs = [
34
+ gr.Number(label=feat, value=0.0)
35
+ for feat in feature_names
36
+ ]
 
 
 
37
 
38
+ outputs = [
39
+ gr.Textbox(label="Prediction"),
40
+ gr.Number(label="Malware Probability")
41
+ ]
42
 
43
+ app = gr.Interface(
44
+ fn=predict_malware,
45
+ inputs=inputs,
46
+ outputs=outputs,
47
+ title="MLP-based Malware Detection",
48
+ description="Malware detection using MLP neural network + StandardScaler"
 
 
 
 
 
49
  )
50
 
51
+ if __name__ == "__main__":
52
+ app.launch()