hieu3636 commited on
Commit
83cb951
·
verified ·
1 Parent(s): adfcc36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -12,19 +12,30 @@ N_FEATURES = model.input_shape[1]
12
 
13
  def predict_csv(file):
14
  df = pd.read_csv(file)
15
- # Drop label column
16
- df = df.drop(columns=["Label", "label", "Class", "class"], errors="ignore")
17
- # Add row index column
 
 
18
  df.insert(0, "row_id", range(1, len(df) + 1))
19
- # Check number of features
20
- if df.shape[1] != N_FEATURES:
21
- return f"Expected {N_FEATURES} features, but got {df.shape[1]} columns."
22
 
23
- X = df.drop(columns=["row_id"]).values.astype(float)
 
 
 
 
 
 
 
 
 
 
24
  X_scaled = scaler.transform(X)
 
25
  probs = model.predict(X_scaled).reshape(-1)
26
  preds = (probs > 0.5).astype(int)
27
- # Build result dataframe
 
28
  result = df.copy()
29
  result["probability_malware"] = probs
30
  result["prediction"] = preds
@@ -35,6 +46,7 @@ def predict_csv(file):
35
  return result
36
 
37
 
 
38
  demo = gr.Interface(
39
  fn=predict_csv,
40
  inputs=gr.File(label="Upload CSV file"),
 
12
 
13
  def predict_csv(file):
14
  df = pd.read_csv(file)
15
+
16
+ # Drop label column if it exists
17
+ df = df.drop(columns=["Label", "label"], errors="ignore")
18
+
19
+ # Add row index for display only
20
  df.insert(0, "row_id", range(1, len(df) + 1))
 
 
 
21
 
22
+ # Separate features for model
23
+ feature_df = df.drop(columns=["row_id"])
24
+
25
+ # Check feature count
26
+ if feature_df.shape[1] != N_FEATURES:
27
+ return (
28
+ f"Expected {N_FEATURES} features, "
29
+ f"but got {feature_df.shape[1]} columns."
30
+ )
31
+
32
+ X = feature_df.values.astype(float)
33
  X_scaled = scaler.transform(X)
34
+
35
  probs = model.predict(X_scaled).reshape(-1)
36
  preds = (probs > 0.5).astype(int)
37
+
38
+ # Build result table (row_id kept)
39
  result = df.copy()
40
  result["probability_malware"] = probs
41
  result["prediction"] = preds
 
46
  return result
47
 
48
 
49
+
50
  demo = gr.Interface(
51
  fn=predict_csv,
52
  inputs=gr.File(label="Upload CSV file"),