# Source: Hugging Face Space by Pant0x — app.py, commit 853e348 ("Update app.py", verified).
import gradio as gr
import torch
import torch.nn as nn
import pickle
import numpy as np
import os
# ---------------------------------------------------------
# 1. Define the Missing Model Architecture
# ---------------------------------------------------------
# The .pt file stores only weights (a state_dict), not the architecture, so the matching class must be defined here before the weights can be loaded.
class PhishingNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    The attribute names ``fc1``, ``relu`` and ``fc2`` must not change. They
    determine the state_dict keys (``"fc1.weight"``, ``"fc2.bias"``, ...) that
    the saved weights are loaded against.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        output_size: number of output logits.
    """

    def __init__(self, input_size=5, hidden_size=10, output_size=2):
        super().__init__()
        # Layers are created in the same order as always (fc1, relu, fc2), so
        # parameter registration and random initialisation order are stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw, unnormalised logits for the batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# ---------------------------------------------------------
# 2. Define a "Dummy" Scaler (Backup Plan)
# ---------------------------------------------------------
# Fallback used when scaler.pkl is missing or fails to load, so the app still starts; features are then passed to the model unscaled.
class DummyScaler:
    """Identity stand-in for a fitted scaler.

    Provides the ``transform(x)`` method the prediction code calls, but
    returns its argument untouched: the very same object, not a copy.
    """

    def transform(self, x):
        # There are no fitted statistics, so there is nothing to scale by.
        unchanged = x
        return unchanged
# ---------------------------------------------------------
# 3. Load Resources (Safely)
# ---------------------------------------------------------
# Artifact locations, relative to the process's working directory.
MODEL_PATH = "models/phishing_rf_model.pt"
SCALER_PATH = "models/scaler.pkl"
# Filled in by the loading code below; None means "not loaded (yet)".
model, scaler = None, None
# Human-readable startup report, later shown in the UI's debug textbox.
status_log = "System Startup Log:\n"
# --- Try Loading Scaler ---
# Load the pickled feature scaler if present. Otherwise (or on any failure)
# fall back to DummyScaler so the app still starts, recording what happened
# in status_log.
if os.path.exists(SCALER_PATH):
    try:
        # NOTE(review): pickle.load can execute arbitrary code. That is acceptable
        # for the scaler bundled with this app; never point SCALER_PATH at a
        # user-supplied file.
        with open(SCALER_PATH, "rb") as f:
            scaler = pickle.load(f)
        # predict_phishing calls scaler.transform(). Reject an unpickled object
        # that lacks it now, instead of failing on every prediction later.
        if not callable(getattr(scaler, "transform", None)):
            raise TypeError(f"{type(scaler).__name__} has no transform() method")
        status_log += "βœ… Scaler loaded successfully.\n"
    except Exception as e:
        status_log += f"⚠️ Scaler is corrupt ({str(e)}). Using raw data instead.\n"
        scaler = DummyScaler()
else:
    status_log += "⚠️ Scaler file not found. Using raw data.\n"
    scaler = DummyScaler()
# --- Try Loading Model ---
# Rebuild PhishingNet from the saved state_dict on CPU. On any failure, leave
# model as None and record the error in status_log.
if os.path.exists(MODEL_PATH):
    try:
        # Load the weights (OrderedDict of tensors).
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # you trust; this one ships with the app.
        state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
        # Size the architecture from the weights themselves, so a checkpoint
        # trained with a different hidden/output width still loads instead of
        # failing in load_state_dict. nn.Linear stores its weight as
        # (out_features, in_features).
        input_shape = 5  # Default
        hidden_shape = 10  # Default
        output_shape = 2  # Default
        if "fc1.weight" in state_dict:
            hidden_shape, input_shape = state_dict["fc1.weight"].shape
        if "fc2.weight" in state_dict:
            output_shape = state_dict["fc2.weight"].shape[0]
        model = PhishingNet(
            input_size=input_shape,
            hidden_size=hidden_shape,
            output_size=output_shape,
        )
        model.load_state_dict(state_dict)
        model.eval()
        status_log += f"βœ… Model weights loaded (Input size: {input_shape}).\n"
        # extract_features always produces 5 features per URL. Warn at startup
        # rather than letting every prediction fail with a shape error.
        if input_shape != 5:
            status_log += (
                f"⚠️ Model expects {input_shape} input features but "
                "extract_features produces 5; predictions will fail.\n"
            )
    except Exception as e:
        status_log += f"❌ MODEL CRASH: {str(e)}\n"
        model = None
else:
    status_log += "❌ Model file not found.\n"
# ---------------------------------------------------------
# 4. Feature Extraction & Prediction
# ---------------------------------------------------------
def extract_features(url: str) -> np.ndarray:
    """Turn a URL string into the single feature row fed to the model.

    The five features, in order, are: total length, number of '.', number
    of '-', number of digit characters, and number of '@'.

    Returns:
        A float array of shape (1, 5), i.e. a batch containing one sample.
    """
    digit_total = sum(1 for ch in url if ch.isdigit())
    row = [len(url), url.count("."), url.count("-"), digit_total, url.count("@")]
    return np.asarray([row], dtype=float)
def predict_phishing(url):
    """Classify a URL as safe or phishing for the Gradio UI.

    Args:
        url: text from the URL textbox. Leading and trailing whitespace is
            ignored.

    Returns:
        A ``(label_value, status_message)`` pair for the Label and debug
        Textbox outputs. On success, ``label_value`` maps the two class names
        to softmax probabilities (index 0 = safe, index 1 = phishing). It is
        ``{"Error": 0}`` when the model is unavailable or prediction fails,
        and ``None`` for empty input.
    """
    if model is None:
        return {"Error": 0}, "Model failed to load. Check status."
    # Pasted URLs often carry stray spaces or newlines. These would inflate the
    # length feature, and whitespace-only input would otherwise be scanned.
    url = (url or "").strip()
    if not url:
        return None, "Please enter a URL."
    try:
        # 1. Extract features: shape (1, n_features).
        features = extract_features(url)
        # 2. Scale (real scaler or the identity fallback).
        features_scaled = scaler.transform(features)
        # 3. Convert to a float32 tensor to match the model's parameters.
        features_tensor = torch.tensor(features_scaled, dtype=torch.float32)
        # 4. Predict without building an autograd graph.
        with torch.no_grad():
            logits = model(features_tensor)
            probs = torch.nn.functional.softmax(logits, dim=1)
        # Map probabilities (Index 0 = Safe, Index 1 = Phishing).
        safe_conf = float(probs[0][0])
        phish_conf = float(probs[0][1])
        return {"βœ… Safe": safe_conf, "🚨 Phishing": phish_conf}, "Success"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return {"Error": 0}, f"Prediction Error: {str(e)}"
# ---------------------------------------------------------
# 5. UI Setup
# ---------------------------------------------------------
# Single-page UI: a URL box with a Scan button, a result label, and a debug
# textbox pre-filled with the startup log.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# πŸ›‘οΈ PhishScope Final")
    with gr.Row():
        url_input = gr.Textbox(label="URL", placeholder="https://example.com")
        btn = gr.Button("Scan", variant="primary")
    with gr.Row():
        label_out = gr.Label(label="Result")
        debug_out = gr.Textbox(label="System Status (Debug)", value=status_log, lines=5)
    # Clicking "Scan" and pressing Enter in the URL box both run a scan.
    btn.click(fn=predict_phishing, inputs=url_input, outputs=[label_out, debug_out])
    url_input.submit(fn=predict_phishing, inputs=url_input, outputs=[label_out, debug_out])

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported by tests or tooling without launching Gradio.
if __name__ == "__main__":
    iface.launch()