import gradio as gr
import torch
import torch.nn as nn
import pickle
import numpy as np
import os

# ---------------------------------------------------------
# 1. Define the Missing Model Architecture
# ---------------------------------------------------------
# The .pt file only stores weights (a state_dict); this class supplies the
# module structure those weights are loaded into.
class PhishingNet(nn.Module):
    """Two-layer feed-forward classifier: input -> hidden (ReLU) -> 2 logits."""

    def __init__(self, input_size=5, hidden_size=10, output_size=2):
        super(PhishingNet, self).__init__()
        # Standard architecture for these types of models
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


# ---------------------------------------------------------
# 2. Define a "Dummy" Scaler (Backup Plan)
# ---------------------------------------------------------
class DummyScaler:
    """Fallback used when scaler.pkl is missing or corrupt."""

    def transform(self, x):
        return x  # Do nothing, just pass the data through


# ---------------------------------------------------------
# 3. Load Resources (Safely)
# ---------------------------------------------------------
MODEL_PATH = "models/phishing_rf_model.pt"
SCALER_PATH = "models/scaler.pkl"

model = None
scaler = None
status_log = "System Startup Log:\n"

# --- Try Loading Scaler ---
if os.path.exists(SCALER_PATH):
    try:
        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file — only ship scaler.pkl from a trusted source.
        with open(SCALER_PATH, "rb") as f:
            scaler = pickle.load(f)
        status_log += "✅ Scaler loaded successfully.\n"
    except Exception as e:
        status_log += f"⚠️ Scaler is corrupt ({str(e)}). Using raw data instead.\n"
        scaler = DummyScaler()
else:
    # BUGFIX: this message was severed by a line break in the original
    # source, leaving an unterminated string literal (a SyntaxError).
    status_log += "⚠️ Scaler file not found. Using raw data.\n"
    scaler = DummyScaler()

# --- Try Loading Model ---
if os.path.exists(MODEL_PATH):
    try:
        # Load the weights (OrderedDict).
        # SECURITY NOTE: torch.load unpickles the file; consider
        # weights_only=True (torch >= 1.13) since only tensors are needed.
        state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

        # Infer the input width from the first layer's weight matrix so the
        # architecture matches the checkpoint; fall back to 5 otherwise.
        input_shape = 5
        if 'fc1.weight' in state_dict:
            input_shape = state_dict['fc1.weight'].shape[1]

        model = PhishingNet(input_size=input_shape)
        model.load_state_dict(state_dict)
        model.eval()  # inference mode: disables dropout/batch-norm updates
        status_log += f"✅ Model weights loaded (Input size: {input_shape}).\n"
    except Exception as e:
        status_log += f"❌ MODEL CRASH: {str(e)}\n"
        model = None
else:
    status_log += "❌ Model file not found.\n"


# ---------------------------------------------------------
# 4. Feature Extraction & Prediction
# ---------------------------------------------------------
def extract_features(url: str) -> np.ndarray:
    """Return a (1, 5) float array:
    [length, '.' count, '-' count, digit count, '@' count]."""
    length = len(url)
    dots = url.count('.')
    hyphens = url.count('-')
    digits = sum(c.isdigit() for c in url)
    at_sign = url.count('@')
    return np.array([[length, dots, hyphens, digits, at_sign]], dtype=float)


def predict_phishing(url):
    """Classify a URL.

    Returns a (label_dict, status_message) pair for the Gradio outputs:
    label_dict maps class names to confidences (or {"Error": 0} / None on
    failure), status_message is a human-readable string.
    """
    if model is None:
        return {"Error": 0}, "Model failed to load. Check status."
    if not url:
        return None, "Please enter a URL."
    try:
        # 1. Extract Features
        features = extract_features(url)
        # 2. Scale (Real or Dummy)
        features_scaled = scaler.transform(features)
        # 3. Convert to Tensor
        features_tensor = torch.tensor(features_scaled, dtype=torch.float32)
        # 4. Predict
        with torch.no_grad():
            logits = model(features_tensor)
            probs = torch.nn.functional.softmax(logits, dim=1)
        # Map probabilities (Index 0 = Safe, Index 1 = Phishing)
        safe_conf = float(probs[0][0])
        phish_conf = float(probs[0][1])
        return {"✅ Safe": safe_conf, "🚨 Phishing": phish_conf}, "Success"
    except Exception as e:
        return {"Error": 0}, f"Prediction Error: {str(e)}"


# ---------------------------------------------------------
# 5. UI Setup
# ---------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🛡️ PhishScope Final")
    with gr.Row():
        url_input = gr.Textbox(label="URL", placeholder="https://example.com")
        btn = gr.Button("Scan", variant="primary")
    with gr.Row():
        label_out = gr.Label(label="Result")
        debug_out = gr.Textbox(label="System Status (Debug)", value=status_log, lines=5)
    btn.click(fn=predict_phishing, inputs=url_input, outputs=[label_out, debug_out])

# Guard the launch so the module can be imported (e.g. for testing) without
# starting the web server; running the file as a script behaves as before.
if __name__ == "__main__":
    iface.launch()