# ============================================================
# PhishGuard AI - cnn/train_cnn.py
# CNN fine-tuning script for phishing screenshot detection.
#
# Loads data/screenshots/ with an ImageFolder-style layout
# Augmentation: RandomHorizontalFlip, ColorJitter, RandomRotation
# 15 epochs, AdamW on head only (backbone stays frozen)
# Saves cnn_weights.pt + cnn_replay_buffer.pt
# Works with as few as 100 images per class
# ============================================================
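# Expected layout: data/screenshots/{phishing,legitimate}/*.png (plus *.jpg);
# the directories are created and filled with synthetic images if missing.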
from __future__ import annotations

import logging
import sys
from pathlib import Path
from typing import List

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
)
logger = logging.getLogger("phishguard.cnn.train")

CNN_DIR = Path(__file__).parent
BACKEND_DIR = CNN_DIR.parent
WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"
SCREENSHOTS_DIR = BACKEND_DIR / "data" / "screenshots"

sys.path.insert(0, str(CNN_DIR))
sys.path.insert(0, str(BACKEND_DIR))
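

# torch/torchvision/sklearn are imported inside main() below, so importing
# this module (e.g. for the path constants above) stays cheap and does not
# require the heavy ML dependencies.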
def main() -> None:
    print("=" * 60)
    print("PhishGuard AI — CNN Training")
    print("=" * 60)

    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.utils.data import DataLoader, Dataset
    import torchvision.transforms as T
    from PIL import Image
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    from cnn_model import PhishCNN
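
    # NOTE: the code below assumes PhishCNN wraps a pretrained torchvision
    # backbone whose body is frozen, exposes its classification head as
    # model.backbone.fc, and returns a sigmoid probability in [0, 1]
    # (required by the BCELoss used further down).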

    # ── Check data ───────────────────────────────────────────────
    phishing_dir = SCREENSHOTS_DIR / "phishing"
    legitimate_dir = SCREENSHOTS_DIR / "legitimate"
    if not phishing_dir.exists() or not legitimate_dir.exists():
        print("\n⚠️  Screenshot directories not found:")
        print(f"   Expected: {phishing_dir}")
        print(f"   Expected: {legitimate_dir}")
        print("\n   Run: python screenshot_collector.py")
        # Create dirs and generate placeholder images for testing
        phishing_dir.mkdir(parents=True, exist_ok=True)
        legitimate_dir.mkdir(parents=True, exist_ok=True)
        print("   Generating synthetic training images...")
        _generate_synthetic_screenshots(phishing_dir, legitimate_dir)

    phishing_files = list(phishing_dir.glob("*.png")) + list(phishing_dir.glob("*.jpg"))
    legit_files = list(legitimate_dir.glob("*.png")) + list(legitimate_dir.glob("*.jpg"))
    print("\n📊 Dataset:")
    print(f"   Phishing screenshots: {len(phishing_files)}")
    print(f"   Legitimate screenshots: {len(legit_files)}")
    if len(phishing_files) < 10 or len(legit_files) < 10:
        print("⚠️  Too few screenshots. Generating synthetic images...")
        _generate_synthetic_screenshots(phishing_dir, legitimate_dir, count=100)
        phishing_files = list(phishing_dir.glob("*.png"))
        legit_files = list(legitimate_dir.glob("*.png"))
        print(f"   Phishing: {len(phishing_files)}, Legitimate: {len(legit_files)}")

    # ── Dataset ──────────────────────────────────────────────────
    train_transform = T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.RandomRotation(5),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
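    # The mean/std above are the standard ImageNet statistics that torchvision's
    # pretrained backbones expect; validation skips the augmentations so metrics
    # are computed on undistorted screenshots.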

    class ScreenshotDataset(Dataset):
        def __init__(self, files: List[Path], label: int, transform):
            self.files = files
            self.label = label
            self.transform = transform

        def __len__(self) -> int:
            return len(self.files)

        def __getitem__(self, idx: int):
            try:
                img = Image.open(self.files[idx]).convert("RGB")
                tensor = self.transform(img)
                return tensor, self.label
            except Exception:
                # Return a black image on error so the batch shape stays intact,
                # but log the unreadable file instead of failing silently.
                logger.warning("Unreadable image, substituting black frame: %s", self.files[idx])
                tensor = torch.zeros(3, 224, 224)
                return tensor, self.label

    # Split: 80% train, 20% val
    import random
    random.shuffle(phishing_files)
    random.shuffle(legit_files)
    phish_split = int(len(phishing_files) * 0.8)
    legit_split = int(len(legit_files) * 0.8)
    train_phish = phishing_files[:phish_split]
    val_phish = phishing_files[phish_split:]
    train_legit = legit_files[:legit_split]
    val_legit = legit_files[legit_split:]
    train_dataset = (
        ScreenshotDataset(train_phish, 1, train_transform)
        + ScreenshotDataset(train_legit, 0, train_transform)
    )
    val_dataset = (
        ScreenshotDataset(val_phish, 1, val_transform)
        + ScreenshotDataset(val_legit, 0, val_transform)
    )
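    # `+` on torch Datasets builds a ConcatDataset, merging the two per-class
    # datasets into a single indexable one.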
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
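    # num_workers=0 keeps image loading in the main process, which sidesteps
    # multiprocessing pitfalls on Windows/macOS and is fine at this dataset size.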

    # ── Model ────────────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n🤖 Device: {device}")
    model = PhishCNN(pretrained=True).to(device)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"   Parameters: {total:,} total, {trainable:,} trainable")

    # Only optimize head parameters
    head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
    optimizer = AdamW(head_params, lr=1e-3, weight_decay=1e-4)
    loss_fn = nn.BCELoss()
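    # BCELoss expects probabilities, i.e. the model must apply sigmoid itself;
    # if PhishCNN ever returns raw logits, swap in nn.BCEWithLogitsLoss instead.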

    # ── Training ─────────────────────────────────────────────────
    EPOCHS = 15  # head-only fine-tuning, per the header; backbone stays frozen
    best_val_acc = 0.0
    print(f"\n🏋️ Training for {EPOCHS} epochs...")
    print(f"   {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7}")
    print(f"   {'─'*5} | {'─'*8} | {'─'*9} | {'─'*7}")
    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0
        train_preds, train_labels = [], []
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.float().to(device)
            optimizer.zero_grad()
            # squeeze(-1) maps (B, 1) -> (B,) without collapsing a batch of
            # size 1 to a 0-d tensor, which would break BCELoss's shape check
            output = model(batch_x).squeeze(-1)
            loss = loss_fn(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            preds = (output >= 0.5).int()
            train_preds.extend(preds.cpu().tolist())
            train_labels.extend(batch_y.int().cpu().tolist())
        avg_loss = total_loss / max(len(train_loader), 1)
        train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0

        # Validate
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.float().to(device)
                output = model(batch_x).squeeze(-1)
                preds = (output >= 0.5).int()
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(batch_y.int().cpu().tolist())
        val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0

        if epoch % 3 == 0 or epoch == 1:
            print(f"   {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f}")

        # Checkpoint whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), WEIGHTS_PATH)

    # ── Final metrics ────────────────────────────────────────────
    # Note: precision/recall/F1 come from the last epoch's predictions, while
    # the reported accuracy is the best checkpoint's.
    if val_labels:
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average="binary", zero_division=0,
        )
        print("\n📊 Final Validation:")
        print(f"   Accuracy: {best_val_acc:.4f}")
        print(f"   Precision: {precision:.4f}")
        print(f"   Recall: {recall:.4f}")
        print(f"   F1 Score: {f1:.4f}")

    # ── Save replay buffer ───────────────────────────────────────
    # Keep paths and labels aligned: up to 50 examples per class
    replay_phish = phishing_files[:50]
    replay_legit = legit_files[:50]
    replay_paths = [str(p) for p in replay_phish + replay_legit]
    replay_labels = [1] * len(replay_phish) + [0] * len(replay_legit)
    REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"paths": replay_paths, "labels": replay_labels}, REPLAY_BUFFER_PATH)
| print(f"\nβ CNN weights saved to: {WEIGHTS_PATH}") | |
| print(f"πΎ Replay buffer saved: {len(replay_paths)} paths β {REPLAY_BUFFER_PATH}") | |
| print("=" * 60) | |


def _generate_synthetic_screenshots(
    phishing_dir: Path,
    legitimate_dir: Path,
    count: int = 100,
) -> None:
    """Generate synthetic screenshots for training when real data is unavailable."""
    import random
    from PIL import Image, ImageDraw

    for label, save_dir, colors in [
        ("phishing", phishing_dir, [(200, 50, 50), (180, 30, 30), (220, 80, 60)]),
        ("legitimate", legitimate_dir, [(50, 120, 200), (30, 100, 180), (60, 140, 220)]),
    ]:
        save_dir.mkdir(parents=True, exist_ok=True)
        existing = len(list(save_dir.glob("*.png")))
        needed = max(0, count - existing)
        for i in range(needed):
            # Create varied synthetic images
            w, h = 1280, 800
            bg = random.choice(colors)
            img = Image.new("RGB", (w, h), bg)
            draw = ImageDraw.Draw(img)
            # Add shapes
            for _ in range(random.randint(5, 15)):
                x1 = random.randint(0, w - 100)
                y1 = random.randint(0, h - 100)
                x2 = x1 + random.randint(50, 300)
                y2 = y1 + random.randint(30, 200)
                color = tuple(random.randint(0, 255) for _ in range(3))
                draw.rectangle([x1, y1, x2, y2], fill=color)
            # Add text-like rectangles
            for _ in range(random.randint(3, 8)):
                x = random.randint(100, w - 400)
                y = random.randint(100, h - 100)
                draw.rectangle([x, y, x + random.randint(100, 300), y + 20],
                               fill=(255, 255, 255))
            # Offset filenames by `existing` so re-runs append instead of
            # overwriting previously generated images
            img.save(save_dir / f"synthetic_{existing + i:04d}.png")
    logger.info(f"Generated synthetic screenshots in {phishing_dir.parent}")
if __name__ == "__main__":
    main()