# ============================================================ # PhishGuard AI - cnn/train_cnn.py # CNN fine-tuning script for phishing screenshot detection. # # Loads data/screenshots/ with ImageFolder structure # Augmentation: RandomHorizontalFlip, ColorJitter, RandomRotation # 15 epochs, AdamW on head only (backbone stays frozen) # Saves cnn_weights.pt + cnn_replay_buffer.pt # Works with as few as 100 images per class # ============================================================ from __future__ import annotations import logging import sys from pathlib import Path from typing import List logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-7s | %(message)s", ) logger = logging.getLogger("phishguard.cnn.train") CNN_DIR = Path(__file__).parent BACKEND_DIR = CNN_DIR.parent WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt" REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt" SCREENSHOTS_DIR = BACKEND_DIR / "data" / "screenshots" sys.path.insert(0, str(CNN_DIR)) sys.path.insert(0, str(BACKEND_DIR)) def main() -> None: print("=" * 60) print("PhishGuard AI — CNN Training") print("=" * 60) import torch import torch.nn as nn from torch.optim import AdamW from torch.utils.data import DataLoader, Dataset, random_split import torchvision.transforms as T from PIL import Image from sklearn.metrics import accuracy_score, precision_recall_fscore_support from cnn_model import PhishCNN # ── Check data ─────────────────────────────────────────────── phishing_dir = SCREENSHOTS_DIR / "phishing" legitimate_dir = SCREENSHOTS_DIR / "legitimate" if not phishing_dir.exists() or not legitimate_dir.exists(): print(f"\n⚠️ Screenshot directories not found:") print(f" Expected: {phishing_dir}") print(f" Expected: {legitimate_dir}") print(f"\n Run: python screenshot_collector.py") # Create dirs and generate placeholder images for testing phishing_dir.mkdir(parents=True, exist_ok=True) legitimate_dir.mkdir(parents=True, exist_ok=True) print(" Generating synthetic training images...") _generate_synthetic_screenshots(phishing_dir, legitimate_dir) phishing_files = list(phishing_dir.glob("*.png")) + list(phishing_dir.glob("*.jpg")) legit_files = list(legitimate_dir.glob("*.png")) + list(legitimate_dir.glob("*.jpg")) print(f"\n📊 Dataset:") print(f" Phishing screenshots: {len(phishing_files)}") print(f" Legitimate screenshots: {len(legit_files)}") if len(phishing_files) < 10 or len(legit_files) < 10: print("⚠️ Too few screenshots. Generating synthetic images...") _generate_synthetic_screenshots(phishing_dir, legitimate_dir, count=100) phishing_files = list(phishing_dir.glob("*.png")) legit_files = list(legitimate_dir.glob("*.png")) print(f" Phishing: {len(phishing_files)}, Legitimate: {len(legit_files)}") # ── Dataset ────────────────────────────────────────────────── train_transform = T.Compose([ T.Resize((224, 224)), T.RandomHorizontalFlip(), T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), T.RandomRotation(5), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) val_transform = T.Compose([ T.Resize((224, 224)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) class ScreenshotDataset(Dataset): def __init__(self, files: List[Path], label: int, transform): self.files = files self.label = label self.transform = transform def __len__(self) -> int: return len(self.files) def __getitem__(self, idx: int): try: img = Image.open(self.files[idx]).convert("RGB") tensor = self.transform(img) return tensor, self.label except Exception: # Return black image on error tensor = torch.zeros(3, 224, 224) return tensor, self.label # Split: 80% train, 20% val import random random.shuffle(phishing_files) random.shuffle(legit_files) phish_split = int(len(phishing_files) * 0.8) legit_split = int(len(legit_files) * 0.8) train_phish = phishing_files[:phish_split] val_phish = phishing_files[phish_split:] train_legit = legit_files[:legit_split] val_legit = legit_files[legit_split:] train_dataset = ( ScreenshotDataset(train_phish, 1, train_transform) + ScreenshotDataset(train_legit, 0, train_transform) ) val_dataset = ( ScreenshotDataset(val_phish, 1, val_transform) + ScreenshotDataset(val_legit, 0, val_transform) ) train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0) # ── Model ──────────────────────────────────────────────────── device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\n🤖 Device: {device}") model = PhishCNN(pretrained=True).to(device) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) print(f" Parameters: {total:,} total, {trainable:,} trainable") # Only optimize head parameters head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad] optimizer = AdamW(head_params, lr=1e-3, weight_decay=1e-4) loss_fn = nn.BCELoss() # ── Training ───────────────────────────────────────────────── EPOCHS = 2 best_val_acc = 0.0 print(f"\n🏋️ Training for {EPOCHS} epochs...") print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7}") print(f" {'─'*5} | {'─'*8} | {'─'*9} | {'─'*7}") for epoch in range(1, EPOCHS + 1): # Train model.train() total_loss = 0.0 train_preds, train_labels = [], [] for batch_x, batch_y in train_loader: batch_x = batch_x.to(device) batch_y = batch_y.float().to(device) optimizer.zero_grad() output = model(batch_x).squeeze() loss = loss_fn(output, batch_y) loss.backward() optimizer.step() total_loss += loss.item() preds = (output >= 0.5).int() train_preds.extend(preds.cpu().tolist()) train_labels.extend(batch_y.int().cpu().tolist()) avg_loss = total_loss / max(len(train_loader), 1) train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0 # Validate model.eval() val_preds, val_labels = [], [] with torch.no_grad(): for batch_x, batch_y in val_loader: batch_x = batch_x.to(device) batch_y = batch_y.float().to(device) output = model(batch_x).squeeze() preds = (output >= 0.5).int() val_preds.extend(preds.cpu().tolist()) val_labels.extend(batch_y.int().cpu().tolist()) val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0 if epoch % 3 == 0 or epoch == 1: print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f}") if val_acc > best_val_acc: best_val_acc = val_acc torch.save(model.state_dict(), WEIGHTS_PATH) # ── Final metrics ──────────────────────────────────────────── if val_labels: precision, recall, f1, _ = precision_recall_fscore_support( val_labels, val_preds, average="binary", zero_division=0, ) print(f"\n📊 Final Validation:") print(f" Accuracy: {best_val_acc:.4f}") print(f" Precision: {precision:.4f}") print(f" Recall: {recall:.4f}") print(f" F1 Score: {f1:.4f}") # ── Save replay buffer ─────────────────────────────────────── all_paths = phishing_files + legit_files replay_paths = [str(p) for p in all_paths[:100]] replay_labels = [1] * min(len(phishing_files), 50) + [0] * min(len(legit_files), 50) REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True) torch.save({"paths": replay_paths, "labels": replay_labels}, REPLAY_BUFFER_PATH) print(f"\n✅ CNN weights saved to: {WEIGHTS_PATH}") print(f"💾 Replay buffer saved: {len(replay_paths)} paths → {REPLAY_BUFFER_PATH}") print("=" * 60) def _generate_synthetic_screenshots( phishing_dir: Path, legitimate_dir: Path, count: int = 100, ) -> None: """Generate synthetic screenshots for training when real data unavailable.""" import random from PIL import Image, ImageDraw, ImageFont for label, save_dir, colors in [ ("phishing", phishing_dir, [(200, 50, 50), (180, 30, 30), (220, 80, 60)]), ("legitimate", legitimate_dir, [(50, 120, 200), (30, 100, 180), (60, 140, 220)]), ]: save_dir.mkdir(parents=True, exist_ok=True) existing = len(list(save_dir.glob("*.png"))) needed = max(0, count - existing) for i in range(needed): # Create varied synthetic images w, h = 1280, 800 bg = random.choice(colors) img = Image.new("RGB", (w, h), bg) draw = ImageDraw.Draw(img) # Add shapes for _ in range(random.randint(5, 15)): x1 = random.randint(0, w - 100) y1 = random.randint(0, h - 100) x2 = x1 + random.randint(50, 300) y2 = y1 + random.randint(30, 200) color = tuple(random.randint(0, 255) for _ in range(3)) draw.rectangle([x1, y1, x2, y2], fill=color) # Add text-like rectangles for _ in range(random.randint(3, 8)): x = random.randint(100, w - 400) y = random.randint(100, h - 100) draw.rectangle([x, y, x + random.randint(100, 300), y + 20], fill=(255, 255, 255)) img.save(save_dir / f"synthetic_{i:04d}.png") logger.info(f"Generated synthetic screenshots in {phishing_dir.parent}") if __name__ == "__main__": main()