# phishguard-api / train_cnn.py
# prashanth135's picture
# Upload 38 files
# bebe233 verified
# ============================================================
# PhishGuard AI - cnn/train_cnn.py
# CNN fine-tuning script for phishing screenshot detection.
#
# Loads data/screenshots/ with ImageFolder structure
# Augmentation: RandomHorizontalFlip, ColorJitter, RandomRotation
# Epoch count is the EPOCHS constant in main() (currently 2);
# AdamW on head only (backbone stays frozen)
# Saves cnn_weights.pt + cnn_replay_buffer.pt
# Works with as few as 100 images per class
# ============================================================
from __future__ import annotations
import logging
import sys
from pathlib import Path
from typing import List
# Console logging for the script: "<time> | <LEVEL> | <message>", INFO and up.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("phishguard.cnn.train")

# Filesystem layout, all derived from the location of this file.
CNN_DIR = Path(__file__).parent
BACKEND_DIR = CNN_DIR.parent
SCREENSHOTS_DIR = BACKEND_DIR.joinpath("data", "screenshots")
REPLAY_BUFFER_PATH = BACKEND_DIR.joinpath("data", "cnn_replay_buffer.pt")
WEIGHTS_PATH = CNN_DIR.joinpath("cnn_weights.pt")

# Prepend both directories to the import path. The resulting order is
# BACKEND_DIR first, then CNN_DIR - the same order two successive
# insert(0, ...) calls (CNN_DIR, then BACKEND_DIR) would produce.
sys.path[:0] = [str(BACKEND_DIR), str(CNN_DIR)]
def main() -> None:
    """Fine-tune the PhishCNN head on phishing vs. legitimate screenshots.

    Steps:
      1. Locate images under ``SCREENSHOTS_DIR/phishing`` and
         ``SCREENSHOTS_DIR/legitimate``; synthesize placeholders when the
         directories are missing or hold fewer than 10 images per class.
      2. Shuffle each class and split it 80/20 into train/validation sets.
      3. Train with AdamW + BCELoss, optimizing only the parameters of
         ``model.backbone.fc`` that require gradients.
      4. Save the state dict of the best-validation epoch to ``WEIGHTS_PATH``.
      5. Save up to 50 paths per class, with aligned labels, to
         ``REPLAY_BUFFER_PATH``.
    """
    print("=" * 60)
    print("PhishGuard AI — CNN Training")
    print("=" * 60)

    # Third-party and project imports are local to main(), as in the
    # original script.
    import random

    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.utils.data import DataLoader, Dataset
    import torchvision.transforms as T
    from PIL import Image
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    from cnn_model import PhishCNN

    def list_images(directory: Path) -> List[Path]:
        """Return the *.png and *.jpg files directly inside *directory*."""
        return list(directory.glob("*.png")) + list(directory.glob("*.jpg"))

    # ── Check data ───────────────────────────────────────────────
    phishing_dir = SCREENSHOTS_DIR / "phishing"
    legitimate_dir = SCREENSHOTS_DIR / "legitimate"
    if not phishing_dir.exists() or not legitimate_dir.exists():
        print("\n⚠️ Screenshot directories not found:")
        print(f" Expected: {phishing_dir}")
        print(f" Expected: {legitimate_dir}")
        print("\n Run: python screenshot_collector.py")
        # Create dirs and generate placeholder images for testing
        phishing_dir.mkdir(parents=True, exist_ok=True)
        legitimate_dir.mkdir(parents=True, exist_ok=True)
        print(" Generating synthetic training images...")
        _generate_synthetic_screenshots(phishing_dir, legitimate_dir)

    phishing_files = list_images(phishing_dir)
    legit_files = list_images(legitimate_dir)
    print("\n📊 Dataset:")
    print(f" Phishing screenshots: {len(phishing_files)}")
    print(f" Legitimate screenshots: {len(legit_files)}")
    if len(phishing_files) < 10 or len(legit_files) < 10:
        print("⚠️ Too few screenshots. Generating synthetic images...")
        _generate_synthetic_screenshots(phishing_dir, legitimate_dir, count=100)
        # Re-list with the same patterns as above so that pre-existing
        # .jpg screenshots are not dropped after the synthetic top-up.
        phishing_files = list_images(phishing_dir)
        legit_files = list_images(legitimate_dir)
        print(f" Phishing: {len(phishing_files)}, Legitimate: {len(legit_files)}")

    # ── Dataset ──────────────────────────────────────────────────
    imagenet_norm = T.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.RandomRotation(5),
        T.ToTensor(),
        imagenet_norm,
    ])
    val_transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        imagenet_norm,
    ])

    class ScreenshotDataset(Dataset):
        """Single-class image dataset: every file receives the same label."""

        def __init__(self, files: List[Path], label: int, transform) -> None:
            self.files = files
            self.label = label
            self.transform = transform

        def __len__(self) -> int:
            return len(self.files)

        def __getitem__(self, idx: int):
            path = self.files[idx]
            try:
                # The context manager closes the file handle promptly;
                # convert() returns a fully loaded copy.
                with Image.open(path) as raw:
                    img = raw.convert("RGB")
                tensor = self.transform(img)
            except Exception as exc:
                # Deliberate best-effort: keep training on unreadable files,
                # but say so instead of silently feeding a black image.
                logger.warning(
                    "Could not load %s (%s); substituting a black image",
                    path, exc,
                )
                tensor = torch.zeros(3, 224, 224)
            return tensor, self.label

    # Split: 80% train, 20% val (per class, after an in-place shuffle)
    random.shuffle(phishing_files)
    random.shuffle(legit_files)
    phish_split = int(len(phishing_files) * 0.8)
    legit_split = int(len(legit_files) * 0.8)
    train_phish, val_phish = phishing_files[:phish_split], phishing_files[phish_split:]
    train_legit, val_legit = legit_files[:legit_split], legit_files[legit_split:]

    # Dataset + Dataset concatenates the two single-class datasets.
    train_dataset = (
        ScreenshotDataset(train_phish, 1, train_transform)
        + ScreenshotDataset(train_legit, 0, train_transform)
    )
    val_dataset = (
        ScreenshotDataset(val_phish, 1, val_transform)
        + ScreenshotDataset(val_legit, 0, val_transform)
    )
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

    # ── Model ────────────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n🤖 Device: {device}")
    model = PhishCNN(pretrained=True).to(device)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {total:,} total, {trainable:,} trainable")

    # Only optimize head parameters
    head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
    optimizer = AdamW(head_params, lr=1e-3, weight_decay=1e-4)
    # BCELoss expects probabilities in [0, 1].
    # NOTE(review): presumably PhishCNN ends in a sigmoid - confirm in cnn_model.
    loss_fn = nn.BCELoss()

    # ── Training ─────────────────────────────────────────────────
    # NOTE(review): the file header advertises 15 epochs; the value here is
    # left at 2 as found - confirm which one is intended.
    EPOCHS = 2
    # Start below any reachable accuracy so the first epoch always produces
    # a checkpoint, even if its validation accuracy is 0.0.
    best_val_acc = -1.0
    best_val_preds: List[int] = []
    val_labels: List[int] = []
    print(f"\n🏋️ Training for {EPOCHS} epochs...")
    print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7}")
    print(f" {'─'*5} | {'─'*8} | {'─'*9} | {'─'*7}")
    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0
        train_preds, train_labels = [], []
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.float().to(device)
            optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() would also drop the
            # batch dimension of a final single-sample batch, producing a 0-d
            # tensor that no longer matches batch_y's shape.
            output = model(batch_x).reshape(-1)
            loss = loss_fn(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            preds = (output >= 0.5).int()
            train_preds.extend(preds.cpu().tolist())
            train_labels.extend(batch_y.int().cpu().tolist())
        avg_loss = total_loss / max(len(train_loader), 1)
        train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0

        # Validate
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.float().to(device)
                output = model(batch_x).reshape(-1)
                preds = (output >= 0.5).int()
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(batch_y.int().cpu().tolist())
        val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0

        # Log the first epoch, every third epoch, and always the last one.
        if epoch % 3 == 0 or epoch == 1 or epoch == EPOCHS:
            print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Keep this epoch's predictions so the final report describes the
            # checkpoint that is actually on disk.
            best_val_preds = list(val_preds)
            torch.save(model.state_dict(), WEIGHTS_PATH)

    # ── Final metrics ────────────────────────────────────────────
    # val_loader is not shuffled, so val_labels has the same order every
    # epoch and lines up with best_val_preds.
    if val_labels:
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, best_val_preds, average="binary", zero_division=0,
        )
        print("\n📊 Final Validation:")
        print(f" Accuracy: {best_val_acc:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")
        print(f" F1 Score: {f1:.4f}")

    # ── Save replay buffer ───────────────────────────────────────
    # Take up to 50 files per class and derive the labels from those same
    # slices, so paths[i] and labels[i] always refer to the same image.
    replay_phish = phishing_files[:50]
    replay_legit = legit_files[:50]
    replay_paths = [str(p) for p in replay_phish + replay_legit]
    replay_labels = [1] * len(replay_phish) + [0] * len(replay_legit)
    REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"paths": replay_paths, "labels": replay_labels}, REPLAY_BUFFER_PATH)

    print(f"\n✅ CNN weights saved to: {WEIGHTS_PATH}")
    print(f"💾 Replay buffer saved: {len(replay_paths)} paths → {REPLAY_BUFFER_PATH}")
    print("=" * 60)
def _generate_synthetic_screenshots(
    phishing_dir: Path,
    legitimate_dir: Path,
    count: int = 100,
) -> None:
    """Top up each class directory with synthetic placeholder screenshots.

    For each class directory, the existing ``*.png`` and ``*.jpg`` files are
    counted and only the shortfall up to *count* is generated. Each image is
    a 1280x800 RGB PNG: a class-tinted background (red tones for phishing,
    blue tones for legitimate), 5-15 randomly coloured rectangles, and 3-8
    white 20-pixel-high bars standing in for lines of text.

    Args:
        phishing_dir: Directory receiving phishing-class images (created if
            missing).
        legitimate_dir: Directory receiving legitimate-class images (created
            if missing).
        count: Target number of images per class directory.
    """
    import random

    from PIL import Image, ImageDraw

    width, height = 1280, 800
    class_palettes = [
        (phishing_dir, [(200, 50, 50), (180, 30, 30), (220, 80, 60)]),
        (legitimate_dir, [(50, 120, 200), (30, 100, 180), (60, 140, 220)]),
    ]

    for save_dir, colors in class_palettes:
        save_dir.mkdir(parents=True, exist_ok=True)
        existing = sum(
            1 for pattern in ("*.png", "*.jpg") for _ in save_dir.glob(pattern)
        )
        needed = max(0, count - existing)

        next_index = 0
        for _ in range(needed):
            # Skip indices already on disk. Restarting at 0000 every call
            # would overwrite earlier synthetic files and leave the directory
            # short of `count`.
            target = save_dir / f"synthetic_{next_index:04d}.png"
            while target.exists():
                next_index += 1
                target = save_dir / f"synthetic_{next_index:04d}.png"

            # Create varied synthetic images
            img = Image.new("RGB", (width, height), random.choice(colors))
            draw = ImageDraw.Draw(img)

            # Add shapes
            for _ in range(random.randint(5, 15)):
                x1 = random.randint(0, width - 100)
                y1 = random.randint(0, height - 100)
                x2 = x1 + random.randint(50, 300)
                y2 = y1 + random.randint(30, 200)
                color = tuple(random.randint(0, 255) for _ in range(3))
                draw.rectangle([x1, y1, x2, y2], fill=color)

            # Add text-like rectangles
            for _ in range(random.randint(3, 8)):
                x = random.randint(100, width - 400)
                y = random.randint(100, height - 100)
                draw.rectangle(
                    [x, y, x + random.randint(100, 300), y + 20],
                    fill=(255, 255, 255),
                )

            img.save(target)

    logger.info("Generated synthetic screenshots in %s", phishing_dir.parent)
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()