hieu3636 commited on
Commit
c39c97b
·
verified ·
1 Parent(s): 73d4cd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -1,28 +1,43 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  import joblib
4
  import tensorflow as tf
5
 
6
- # Load local files
7
  model = tf.keras.models.load_model("mlp_model.keras")
8
  scaler = joblib.load("scaler.pkl")
9
 
10
  N_FEATURES = model.input_shape[1]
11
 
12
- def predict(*features):
13
- x = np.array(features).reshape(1, -1)
14
- x_scaled = scaler.transform(x)
15
- prob = model.predict(x_scaled)[0][0]
16
- label = int(prob > 0.5)
17
- return {
18
- "probability": float(prob),
19
- "prediction": label
20
- }
21
-
22
- inputs = [gr.Number(label=f"Feature {i+1}") for i in range(N_FEATURES)]
23
-
24
- gr.Interface(
25
- fn=predict,
26
- inputs=inputs,
27
- outputs="json"
28
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import numpy as np
4
  import joblib
5
  import tensorflow as tf
6
 
7
+ # Load model & scaler
8
  model = tf.keras.models.load_model("mlp_model.keras")
9
  scaler = joblib.load("scaler.pkl")
10
 
11
  N_FEATURES = model.input_shape[1]
12
 
13
+ def predict_csv(file):
14
+ df = pd.read_csv(file)
15
+
16
+ # Check number of features
17
+ if df.shape[1] != N_FEATURES:
18
+ return f"Expected {N_FEATURES} features, but got {df.shape[1]} columns."
19
+
20
+ X = df.values.astype(float)
21
+ X_scaled = scaler.transform(X)
22
+ probs = model.predict(X_scaled).reshape(-1)
23
+ preds = (probs > 0.5).astype(int)
24
+ # Build result dataframe
25
+ result = df.copy()
26
+ result["probability_malware"] = probs
27
+ result["prediction"] = preds
28
+ result["prediction_label"] = result["prediction"].map(
29
+ {1: "malware", 0: "benign"}
30
+ )
31
+
32
+ return result
33
+
34
+
35
+ demo = gr.Interface(
36
+ fn=predict_csv,
37
+ inputs=gr.File(label="Upload CSV file"),
38
+ outputs=gr.Dataframe(label="Prediction Result"),
39
+ title="Malware Detection MLP Model",
40
+ description="Upload a CSV file with features to predict malware or benign."
41
+ )
42
+
43
+ demo.launch()