Pant0x committed on
Commit
ef93a22
·
verified ·
1 Parent(s): 06c0e38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -92
app.py CHANGED
@@ -3,119 +3,104 @@ import torch
3
  import numpy as np
4
 
5
  # -----------------------------
6
- # 1. Load Your Specific Model
7
  # -----------------------------
8
  MODEL_PATH = "models/phishing_rf_model.pt"
9
 
10
- print(f"Loading model from {MODEL_PATH}...")
11
-
12
- # We use torch.load because the file extension is .pt
13
- # map_location='cpu' ensures it works on servers without massive GPUs
14
  try:
 
15
  model = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
16
- print("✅ Model loaded successfully!")
17
  except Exception as e:
18
  print(f"❌ Failed to load model: {e}")
19
- raise e
20
-
21
- # -----------------------------
22
- # 2. Feature Extraction (No Scaler)
23
- # -----------------------------
24
- def extract_features(url: str) -> np.ndarray:
25
- """
26
- Extracts the features the model expects.
27
- Since we are skipping the scaler, we feed these raw numbers directly.
28
- """
29
- length = len(url)
30
- dots = url.count('.')
31
- hyphens = url.count('-')
32
- digits = sum(c.isdigit() for c in url)
33
- at_sign = url.count('@')
34
-
35
- # Create the array shape [1, 5] (1 sample, 5 features)
36
- return np.array([[length, dots, hyphens, digits, at_sign]], dtype=float)
37
 
38
  # -----------------------------
39
- # 3. Prediction Logic
40
  # -----------------------------
41
- def predict_phishing(url: str):
 
 
 
42
  if not url:
43
- return None
44
 
45
- # 1. Extract features
46
- features = extract_features(url)
47
-
48
- # 2. Predict
49
- # We assume the model inside the .pt file is a standard sklearn model
50
- # (RandomForest) that supports .predict_proba()
51
  try:
52
- pred_prob = model.predict_proba(features)[0]
53
- except AttributeError:
54
- # Fallback if the model doesn't support probabilities
55
- pred = model.predict(features)[0]
56
- # Mock probabilities if exact confidence isn't available
57
- pred_prob = [1.0, 0.0] if pred == 0 else [0.0, 1.0]
 
 
 
58
 
59
- # 3. Format Output
60
- # Assuming Index 0 = Safe, Index 1 = Phishing
61
- label_index = pred_prob.argmax()
62
- confidence = float(pred_prob[label_index])
63
-
64
- if label_index == 1:
65
- label = "🚨 Phishing"
66
- else:
67
- label = "✅ Safe"
68
-
69
- return {label: confidence}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # -----------------------------
72
- # 4. Unique Professional UI
73
  # -----------------------------
74
- theme = gr.themes.Soft(
75
- primary_hue="blue",
76
- secondary_hue="slate",
77
- )
78
-
79
- with gr.Blocks(theme=theme, title="PhishGuard Local") as iface:
80
 
81
- # Header
82
  with gr.Row():
83
- gr.Markdown(
84
- """
85
- # 🛡️ PhishGuard (Local Model)
86
- ### Custom Random Forest Detector
87
- Running locally using your `phishing_rf_model.pt` file.
88
- """
89
- )
90
 
91
- # Main Interface
92
  with gr.Row():
93
- with gr.Column(scale=1):
94
- url_input = gr.Textbox(
95
- lines=3,
96
- placeholder="https://example.com",
97
- label="Check URL",
98
- info="Paste the link you want to test."
99
- )
100
- submit_btn = gr.Button("Scan URL 🔍", variant="primary")
101
-
102
- gr.Examples(
103
- examples=[
104
- ["https://google.com"],
105
- ["http://fake-login-secure.com/update"]
106
- ],
107
- inputs=url_input
108
- )
109
-
110
- with gr.Column(scale=1):
111
- output_label = gr.Label(label="Result")
112
- gr.Markdown("> **Note:** Running without feature scaler. Results depend on raw feature interpretation.")
113
 
114
- # Actions
115
- submit_btn.click(
116
- fn=predict_phishing,
117
- inputs=url_input,
118
- outputs=output_label
119
  )
120
 
121
- iface.launch(share=True)
 
3
  import numpy as np
4
 
5
  # -----------------------------
6
+ # 1. Load Model (Robust)
7
  # -----------------------------
8
  MODEL_PATH = "models/phishing_rf_model.pt"
9
 
10
+ print(f"Attempting to load model from {MODEL_PATH}...")
 
 
 
11
  try:
12
+ # Load the model file
13
  model = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
14
+ print(f"✅ Model loaded successfully! Type: {type(model)}")
15
  except Exception as e:
16
  print(f"❌ Failed to load model: {e}")
17
+ model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # -----------------------------
20
+ # 2. Prediction Logic (Universal)
21
  # -----------------------------
22
+ def predict_phishing(url):
23
+ # Safety checks
24
+ if model is None:
25
+ return {"Error": 0.0}, "Model failed to load. Check logs."
26
  if not url:
27
+ return None, "Please enter a URL."
28
 
 
 
 
 
 
 
29
  try:
30
+ # --- A. Extract Features ---
31
+ length = len(url)
32
+ dots = url.count('.')
33
+ hyphens = url.count('-')
34
+ digits = sum(c.isdigit() for c in url)
35
+ at_sign = url.count('@')
36
+
37
+ # Raw features list
38
+ features_list = [length, dots, hyphens, digits, at_sign]
39
 
40
+ # --- B. Smart Detection & Prediction ---
41
+
42
+ # CASE 1: It is a Scikit-Learn Model (Random Forest, etc.)
43
+ if hasattr(model, "predict_proba"):
44
+ # Sklearn expects a Numpy Array
45
+ input_data = np.array([features_list], dtype=float)
46
+
47
+ pred_prob = model.predict_proba(input_data)[0]
48
+ # Usually: Index 0 = Safe, Index 1 = Phishing
49
+ safe_score = float(pred_prob[0])
50
+ phish_score = float(pred_prob[1])
51
+
52
+ # CASE 2: It is a PyTorch Neural Network
53
+ elif isinstance(model, torch.nn.Module):
54
+ model.eval() # Set to evaluation mode
55
+
56
+ # PyTorch expects a Tensor
57
+ input_tensor = torch.tensor([features_list], dtype=torch.float32)
58
+
59
+ with torch.no_grad():
60
+ logits = model(input_tensor)
61
+
62
+ # Check output shape to decide between Softmax or Sigmoid
63
+ if logits.shape[1] == 1:
64
+ # Binary output (Sigmoid)
65
+ phish_score = torch.sigmoid(logits).item()
66
+ safe_score = 1.0 - phish_score
67
+ else:
68
+ # Multi-class output (Softmax)
69
+ probs = torch.nn.functional.softmax(logits, dim=1)
70
+ safe_score = float(probs[0][0])
71
+ phish_score = float(probs[0][1])
72
+
73
+ else:
74
+ return {"Error": 0}, f"Unknown model type: {type(model)}"
75
+
76
+ # Return results
77
+ return {"✅ Safe": safe_score, "🚨 Phishing": phish_score}, "Success"
78
+
79
+ except Exception as e:
80
+ # This catches the specific error and shows it in the UI
81
+ error_msg = f"Crash Error: {str(e)}"
82
+ print(error_msg)
83
+ return {"Error": 0}, error_msg
84
 
85
  # -----------------------------
86
+ # 3. UI Setup
87
  # -----------------------------
88
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
89
+ gr.Markdown("# 🛡️ PhishGuard Debugger")
 
 
 
 
90
 
 
91
  with gr.Row():
92
+ input_box = gr.Textbox(label="URL", placeholder="https://google.com")
93
+ predict_btn = gr.Button("Scan", variant="primary")
 
 
 
 
 
94
 
 
95
  with gr.Row():
96
+ # We use two outputs: one for the label, one for the error message
97
+ output_label = gr.Label(label="Prediction")
98
+ status_box = gr.Textbox(label="Debug Status (Read this if error)", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ predict_btn.click(
101
+ fn=predict_phishing,
102
+ inputs=input_box,
103
+ outputs=[output_label, status_box]
 
104
  )
105
 
106
+ iface.launch()