File size: 3,763 Bytes
bebe233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# ============================================================
# PhishGuard AI - cnn/cnn_model.py
# ResNet50 visual classifier for phishing screenshot detection.
#
# Architecture (from spec):
#   Backbone: ResNet50 fully frozen
#   Custom head: Linear(2048→512) → ReLU → Dropout(0.5) →
#                Linear(512→1) → Sigmoid
#   Input: 224×224 screenshot tensor
#   Output: P_cnn ∈ [0,1]
# ============================================================

from __future__ import annotations

import io
import logging
from typing import Optional

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

logger = logging.getLogger("phishguard.cnn.model")


class PhishCNN(nn.Module):
    """ResNet50 screenshot classifier with a frozen backbone and a small head.

    A torchvision ResNet50 is created, all of its parameters are frozen, and
    its final ``fc`` layer is replaced by a trainable head::

        Linear(2048, 512) -> ReLU -> Dropout(0.5) -> Linear(512, 1)

    ``forward`` applies a sigmoid to the head's single logit, so the model
    outputs P_cnn in [0, 1].
    """

    def __init__(self, pretrained: bool = True) -> None:
        """Build the backbone, freeze it and attach the classification head.

        Args:
            pretrained: When true, initialise the backbone from torchvision's
                default ResNet50 weights; otherwise use random initialisation.
        """
        super().__init__()

        # Load the ResNet50 backbone (optionally with pretrained weights).
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)

        # Freeze every backbone parameter so the optimizer never updates them.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Replace fc with the custom head: 2048 -> 512 -> 1 (sigmoid is
        # applied in forward(), not inside the Sequential).
        in_features = self.backbone.fc.in_features  # 2048
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

        # Make it explicit that the custom head is the only trainable part.
        for param in self.backbone.fc.parameters():
            param.requires_grad = True

        # Apply the train-mode policy below to the freshly built module, which
        # nn.Module otherwise leaves entirely in training mode.
        self.train()

    def train(self, mode: bool = True) -> "PhishCNN":
        """Switch train/eval mode while keeping the frozen backbone in eval.

        ``requires_grad = False`` alone does not freeze BatchNorm layers: in
        training mode they still update their running statistics and
        normalise with batch statistics. To keep the backbone truly frozen,
        only the custom head is ever put into training mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if mode:
            self.backbone.eval()       # frozen layers (incl. BatchNorm) stay in eval
            self.backbone.fc.train()   # head keeps dropout active while training
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and squash the logit to a probability.

        Args:
            x: Image batch; the pipeline in this file produces
                ``(batch, 3, 224, 224)`` tensors.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        logits = self.backbone(x)
        return torch.sigmoid(logits)

    def predict_proba(self, x: torch.Tensor) -> float:
        """Return P_cnn in [0, 1], the predicted probability of phishing.

        Inference runs in eval mode without gradient tracking. The mode the
        module was in before the call is restored afterwards, so calling this
        during training does not leave dropout disabled.

        Args:
            x: A single-sample batch. The output is reduced with
                ``squeeze().item()``, which only works when the model yields
                exactly one value, so larger batches make the call fail.

        Returns:
            The probability as a Python float.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self(x)
            return output.squeeze().item()
        finally:
            if was_training:
                self.train()


# ── Preprocessing pipeline (matches ImageNet normalization) ──────────
# Per-channel ImageNet statistics shared by both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# Size every screenshot is resized to before entering the network.
_INPUT_SIZE = (224, 224)


def _tensor_and_normalize() -> list:
    """Return fresh ToTensor + ImageNet Normalize steps that end a pipeline."""
    return [
        T.ToTensor(),
        T.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
    ]


# Deterministic pipeline: resize, convert to tensor, normalize.
TRANSFORM = T.Compose([T.Resize(_INPUT_SIZE), *_tensor_and_normalize()])

# Training pipeline: same as TRANSFORM plus random flip, colour jitter and a
# small rotation applied after the resize.
TRAIN_TRANSFORM = T.Compose(
    [
        T.Resize(_INPUT_SIZE),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.RandomRotation(5),
        *_tensor_and_normalize(),
    ]
)


def preprocess_screenshot(screenshot_bytes: bytes) -> torch.Tensor:
    """Decode raw screenshot bytes into a batched, model-ready tensor.

    The bytes are opened with PIL, converted to RGB, passed through
    ``TRANSFORM`` and given a leading batch dimension of size 1.
    """
    buffer = io.BytesIO(screenshot_bytes)
    rgb_image = Image.open(buffer).convert("RGB")
    tensor = TRANSFORM(rgb_image)
    return tensor.unsqueeze(0)


def load_cnn(weights_path: Optional[str] = None) -> PhishCNN:
    """Build a ``PhishCNN`` in eval mode, optionally restoring trained weights.

    Args:
        weights_path: Path to a state-dict checkpoint. When falsy, or when the
            checkpoint cannot be loaded, a model with a pretrained backbone
            and a freshly initialised (untrained) head is returned instead.

    Returns:
        The model, already switched to evaluation mode.
    """
    if weights_path:
        try:
            state = torch.load(weights_path, map_location="cpu", weights_only=True)
            # A strict load replaces every parameter and buffer, so fetching
            # the ImageNet weights first would be wasted work.
            model = PhishCNN(pretrained=False)
            model.load_state_dict(state)
        except Exception as e:
            # Deliberately best-effort: any failure falls back to the baseline.
            # The attempted model is discarded because a strict load copies
            # the matching tensors before it raises, leaving it half-loaded.
            logger.warning("Could not load CNN weights: %s", e)
            logger.info("Using ImageNet features only (baseline)")
        else:
            logger.info("CNN weights loaded from %s", weights_path)
            model.eval()
            return model

    # Baseline: pretrained backbone; the classification head is untrained.
    model = PhishCNN(pretrained=True)
    model.eval()
    return model