# ============================================================
# PhishGuard AI - cnn/cnn_inference.py
# CNN inference wrapper for Tier 4 visual analysis.
# Supports: predict, hot-reload, incremental_update.
# ============================================================

from __future__ import annotations

import copy
import io
import logging
import random
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch
from PIL import Image

if TYPE_CHECKING:
    from torch.utils.data import DataLoader

logger = logging.getLogger("phishguard.cnn.inference")

CNN_DIR = Path(__file__).parent
BACKEND_DIR = CNN_DIR.parent
WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"


class CNNInference:
    """CNN inference wrapper with hot-reload and incremental update.

    The model is created by ``cnn_model.load_cnn`` lazily on the first
    :meth:`predict` call, or explicitly via :meth:`load` / :meth:`reload`.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        self._weights_path = weights_path or WEIGHTS_PATH
        self._model = None
        self._loaded = False

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load the CNN model.

        Args:
            weights_path: Weights file to load; defaults to the path given at
                construction. When the file does not exist, ``load_cnn`` is
                called with ``None`` instead of a path.

        Returns:
            True if ``load_cnn`` returned a model, False otherwise.
        """
        from cnn_model import load_cnn

        path = weights_path or self._weights_path
        self._model = load_cnn(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        return self._loaded

    def predict(self, screenshot_bytes: bytes) -> float:
        """Predict the phishing probability from screenshot bytes.

        The model is loaded on first use. The neutral value 0.5 is returned
        when no model is available or when preprocessing/inference raises.

        Args:
            screenshot_bytes: Encoded image bytes, decoded by
                ``cnn_model.preprocess_screenshot``.

        Returns:
            P_cnn in [0, 1], as produced by the model's ``predict_proba``.
        """
        if not self._loaded:
            self.load()
        if self._model is None:
            return 0.5

        from cnn_model import preprocess_screenshot

        try:
            tensor = preprocess_screenshot(screenshot_bytes)
            return self._model.predict_proba(tensor)
        except Exception as e:
            logger.error(f"CNN predict failed: {e}")
            return 0.5

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload the model from new weights.

        The current model is replaced only when ``load_cnn`` returns a model;
        otherwise the existing model is kept and False is returned.
        """
        from cnn_model import load_cnn

        path = weights_path or self._weights_path
        new_model = load_cnn(str(path))
        if new_model is not None:
            self._model = new_model
            self._loaded = True
            logger.info(f"CNN hot-reloaded from {path}")
            return True
        return False

    async def incremental_update(
        self,
        tier4_samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 1e-4,
        epochs: int = 3,
    ) -> Optional[float]:
        """Fine-tune the classifier head on Tier 4 feedback samples.

        Screenshots for each ``(url, label)`` pair are re-captured with
        Playwright and mixed with a random ~20% sample of the replay buffer.
        Training runs on a deep copy of the live model. The copy replaces the
        live model only after training and evaluation succeed, and is then
        written to the weights path.

        Args:
            tier4_samples: ``(url, label)`` pairs; labels are cast to float
                (presumably 1 = phishing, 0 = benign — confirm with caller).
            replay_buffer_path: Replay buffer file; defaults to
                ``REPLAY_BUFFER_PATH``.
            lr: AdamW learning rate for the head parameters.
            epochs: Number of passes over the combined samples.

        Returns:
            ``post_acc - pre_acc`` rounded to 4 places. Both accuracies are
            measured on the same samples used for training. ``None`` when
            there are no Tier 4 samples, no model is loaded, fewer than 5
            usable samples were collected, the head has no trainable
            parameters, or any step raises.
        """
        if not tier4_samples:
            logger.info("No Tier 4 samples — skipping CNN update")
            return None

        if self._model is None:
            logger.warning("CNN not loaded, cannot update")
            return None

        try:
            from torch.utils.data import DataLoader, TensorDataset

            transform = self._build_transform()

            tensors, labels = await self._collect_tier4_tensors(tier4_samples, transform)
            replay_tensors, replay_labels = self._collect_replay_tensors(
                replay_buffer_path or REPLAY_BUFFER_PATH, transform
            )
            tensors.extend(replay_tensors)
            labels.extend(replay_labels)

            if len(tensors) < 5:
                logger.warning(f"Too few CNN samples ({len(tensors)}), skipping update")
                return None

            dataset = TensorDataset(torch.stack(tensors), torch.tensor(labels, dtype=torch.float))
            loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # NOTE(review): everything below runs synchronously (no await), so
            # it blocks the event loop for the duration of training.

            # Train a private copy: the live model used by predict() must never
            # be left half-trained or stuck in train() mode if a step raises.
            device = torch.device("cpu")
            model = copy.deepcopy(self._model).to(device)

            # Train (head only — backbone stays frozen)
            head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
            if not head_params:
                logger.warning("CNN head has no trainable parameters, skipping update")
                return None

            pre_acc = self._count_correct(model, loader, device) / len(dataset)
            self._train_head(model, loader, head_params, device, lr=lr, epochs=epochs)
            post_acc = self._count_correct(model, loader, device) / len(dataset)
            delta = post_acc - pre_acc

            self._model = model
            self._save_weights(model)

            logger.info(f"CNN incremental: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")
            return round(delta, 4)

        except Exception as e:
            logger.exception(f"CNN incremental update failed: {e}")
            return None

    @staticmethod
    def _build_transform() -> Callable[[Image.Image], torch.Tensor]:
        """Return the torchvision pipeline applied to training images."""
        import torchvision.transforms as T

        # NOTE(review): this re-implements preprocessing independently of
        # cnn_model.preprocess_screenshot; presumably the two must match — confirm.
        return T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    async def _collect_tier4_tensors(
        self,
        samples: List[Tuple[str, int]],
        transform: Callable[[Image.Image], torch.Tensor],
    ) -> Tuple[List[torch.Tensor], List[float]]:
        """Capture and transform one screenshot per ``(url, label)``; skip failures."""
        tensors: List[torch.Tensor] = []
        labels: List[float] = []
        for url, label in samples:
            try:
                screenshot_bytes = await self._capture_screenshot(url)
                if not screenshot_bytes:
                    continue
                target = float(label)
                img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")
                tensor = transform(img)
            except Exception as e:
                logger.warning(f"Screenshot capture failed for {url}: {e}")
                continue
            # Append both only after everything succeeded so the lists stay aligned.
            tensors.append(tensor)
            labels.append(target)
        return tensors, labels

    @staticmethod
    def _collect_replay_tensors(
        buf_path: Path,
        transform: Callable[[Image.Image], torch.Tensor],
    ) -> Tuple[List[torch.Tensor], List[float]]:
        """Load a random ~20% sample (at least one entry) of the replay buffer.

        The buffer file is expected to hold a dict with parallel ``"paths"``
        (image file paths) and ``"labels"`` lists. Entries whose image cannot
        be read or whose label is missing are skipped. A missing or unreadable
        buffer yields empty lists.
        """
        tensors: List[torch.Tensor] = []
        labels: List[float] = []
        if not buf_path.exists():
            return tensors, labels

        try:
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects; this is only safe if the buffer file is trusted.
            buf_data = torch.load(buf_path, map_location="cpu", weights_only=False)
            buf_paths = buf_data.get("paths", [])
            buf_labels = buf_data.get("labels", [])

            replay_count = max(1, len(buf_paths) // 5)
            indices = random.sample(range(len(buf_paths)), min(replay_count, len(buf_paths)))
            for idx in indices:
                try:
                    # Resolve the label before the image so a short labels list
                    # cannot leave an unlabelled tensor behind.
                    target = float(buf_labels[idx])
                    with Image.open(buf_paths[idx]) as raw:
                        tensor = transform(raw.convert("RGB"))
                except Exception as e:
                    logger.debug(f"Skipping replay entry {idx}: {e}")
                    continue
                tensors.append(tensor)
                labels.append(target)
        except Exception as e:
            logger.warning(f"CNN replay buffer load failed: {e}")
        return tensors, labels

    @staticmethod
    def _count_correct(
        model: torch.nn.Module,
        loader: DataLoader,
        device: torch.device,
    ) -> int:
        """Count samples classified correctly at a 0.5 threshold (model in eval mode).

        Assumes the model outputs probabilities, one per sample, consistent
        with the BCELoss used in training.
        """
        model.eval()
        correct = 0
        with torch.no_grad():
            for bx, by in loader:
                bx, by = bx.to(device), by.to(device)
                # reshape(-1), not squeeze(): squeeze() collapses a one-sample
                # batch to a 0-d tensor.
                out = model(bx).reshape(-1)
                preds = (out >= 0.5).float()
                correct += int((preds == by).sum().item())
        return correct

    @staticmethod
    def _train_head(
        model: torch.nn.Module,
        loader: DataLoader,
        head_params: List[torch.nn.Parameter],
        device: torch.device,
        lr: float,
        epochs: int,
    ) -> None:
        """Optimize ``head_params`` with AdamW + BCELoss for ``epochs`` passes."""
        import torch.nn as nn
        from torch.optim import AdamW

        optimizer = AdamW(head_params, lr=lr)
        loss_fn = nn.BCELoss()

        model.train()
        # Only the head is optimized. Keep the rest of the backbone in eval mode
        # so any BatchNorm running statistics (buffers, updated on every
        # train-mode forward pass) stay frozen as well.
        model.backbone.eval()
        model.backbone.fc.train()

        for epoch in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                bx, by = bx.to(device), by.to(device)
                optimizer.zero_grad()
                # reshape(-1) keeps a one-sample batch shaped (1,), matching the
                # target; squeeze() would make it 0-d and BCELoss would raise.
                out = model(bx).reshape(-1)
                loss = loss_fn(out, by)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            logger.info(f"CNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")

    def _save_weights(self, model: torch.nn.Module) -> None:
        """Write the model's state_dict to the weights path via temp file + rename."""
        weights_path = Path(self._weights_path)
        tmp_path = weights_path.with_name(weights_path.name + ".tmp")
        torch.save(model.state_dict(), tmp_path)
        # Path.replace (os.replace) is an atomic rename on POSIX, so readers
        # never observe a half-written weights file.
        tmp_path.replace(weights_path)

    async def _capture_screenshot(self, url: str) -> Optional[bytes]:
        """Capture a PNG screenshot of ``url`` with headless Chromium (Playwright).

        Returns None on any failure (including a 10 s navigation timeout).
        """
        try:
            from playwright.async_api import async_playwright

            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=True)
                try:
                    page = await browser.new_page(viewport={"width": 1280, "height": 800})
                    # Block heavy resources.
                    # NOTE(review): image requests are aborted too, so logos and
                    # other raster content are missing from the screenshot —
                    # confirm this matches how inference screenshots are taken.
                    await page.route(
                        "**/*.{png,jpg,jpeg,gif,svg,mp4,webm,ogg,woff,woff2,ttf,eot}",
                        lambda route: route.abort(),
                    )
                    await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                    return await page.screenshot(type="png")
                finally:
                    # Close even when navigation/screenshot raised.
                    await browser.close()
        except Exception as e:
            logger.warning(f"Screenshot capture failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        """True once a model has been loaded successfully."""
        return self._loaded