Spaces:
Running
Running
| # ============================================================ | |
| # PhishGuard AI - cnn/cnn_inference.py | |
| # CNN inference wrapper for Tier 4 visual analysis. | |
| # Supports: predict, hot-reload, incremental_update. | |
| # ============================================================ | |
| from __future__ import annotations | |
| import io | |
| import random | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import torch | |
| from PIL import Image | |
logger = logging.getLogger("phishguard.cnn.inference")

# Directory holding this module; the backend root is its parent directory.
CNN_DIR: Path = Path(__file__).parent
BACKEND_DIR: Path = CNN_DIR.parent

# Default checkpoint location, stored next to this module.
WEIGHTS_PATH: Path = CNN_DIR.joinpath("cnn_weights.pt")
# Default replay-buffer file under <backend>/data/.
REPLAY_BUFFER_PATH: Path = BACKEND_DIR.joinpath("data", "cnn_replay_buffer.pt")
class CNNInference:
    """CNN inference wrapper with hot-reload and incremental update.

    The underlying model is produced by ``cnn_model.load_cnn``. It is loaded
    lazily by :meth:`predict` when :meth:`load` has not been called yet.
    :meth:`reload` swaps in new weights. :meth:`incremental_update`
    fine-tunes the classifier head (``model.backbone.fc``) on Tier 4 feedback
    and writes the resulting ``state_dict`` to the configured weights path.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        # Accept str or Path; a falsy value falls back to the default checkpoint.
        self._weights_path = Path(weights_path) if weights_path else WEIGHTS_PATH
        self._model = None
        self._loaded = False

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load the CNN model.

        Args:
            weights_path: Checkpoint to load (str or Path). Defaults to the
                path given at construction. If the file does not exist,
                ``load_cnn`` is called with ``None``. What it builds in that
                case is decided by ``cnn_model.load_cnn`` (presumably a model
                without fine-tuned weights; confirm there).

        Returns:
            True when a model was loaded, False otherwise.
        """
        from cnn_model import load_cnn

        path = Path(weights_path) if weights_path else self._weights_path
        self._model = load_cnn(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        return self._loaded

    def predict(self, screenshot_bytes: bytes) -> float:
        """Predict the phishing probability of a page screenshot.

        Args:
            screenshot_bytes: Encoded image bytes, decoded by
                ``cnn_model.preprocess_screenshot``.

        Returns:
            P_cnn in [0, 1]. Returns the neutral value 0.5 when no model is
            available or when loading, preprocessing or inference fails.
        """
        try:
            # The lazy load sits inside the try so that a failing load
            # degrades to the neutral score, like any other inference failure.
            if not self._loaded:
                self.load()
            if self._model is None:
                return 0.5
            from cnn_model import preprocess_screenshot

            tensor = preprocess_screenshot(screenshot_bytes)
            return self._model.predict_proba(tensor)
        except Exception as e:
            logger.error(f"CNN predict failed: {e}")
            return 0.5

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload the model from a checkpoint.

        The current model is replaced only when ``load_cnn`` returns a model.
        Otherwise the existing model keeps serving.

        Args:
            weights_path: Checkpoint to load (str or Path). Defaults to the
                path given at construction.

        Returns:
            True if the new weights were swapped in.
        """
        from cnn_model import load_cnn

        path = Path(weights_path) if weights_path else self._weights_path
        new_model = load_cnn(str(path))
        if new_model is None:
            return False
        self._model = new_model
        self._loaded = True
        logger.info(f"CNN hot-reloaded from {path}")
        return True

    async def incremental_update(
        self,
        tier4_samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 1e-4,
        epochs: int = 3,
    ) -> Optional[float]:
        """Fine-tune the classifier head on Tier 4 feedback samples.

        Each sample URL is re-captured with Playwright. Successful captures
        are combined with a random ~20% draw from the replay buffer (at least
        one entry when the buffer is non-empty).

        Training runs on a deep copy of the current model. The copy is swapped
        in and saved to the weights path only after training finishes, so a
        failure part-way through leaves the serving model untouched.

        Args:
            tier4_samples: ``(url, label)`` pairs. Labels are cast to float
                for ``BCELoss`` (presumably 0 = legitimate, 1 = phishing;
                confirm against the feedback producer).
            replay_buffer_path: Replay buffer file (str or Path). Defaults to
                ``REPLAY_BUFFER_PATH``.
            lr: AdamW learning rate for the head parameters.
            epochs: Number of passes over the combined samples.

        Returns:
            Post-update minus pre-update accuracy, rounded to 4 decimals.
            Returns None if there are no Tier 4 samples, no model, fewer
            than 5 usable samples, or the update fails. Both accuracies are
            measured on the training samples themselves, not on a held-out
            set.
        """
        if not tier4_samples:
            logger.info("No Tier 4 samples — skipping CNN update")
            return None
        if self._model is None:
            logger.warning("CNN not loaded, cannot update")
            return None

        try:
            import copy

            import torchvision.transforms as T
            from torch.utils.data import DataLoader, TensorDataset

            device = torch.device("cpu")
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            tensors, labels = await self._capture_samples(tier4_samples, transform)
            buf_path = Path(replay_buffer_path) if replay_buffer_path else REPLAY_BUFFER_PATH
            replay_tensors, replay_labels = self._replay_samples(buf_path, transform)
            tensors.extend(replay_tensors)
            labels.extend(replay_labels)

            if len(tensors) < 5:
                logger.warning(f"Too few CNN samples ({len(tensors)}), skipping update")
                return None

            dataset = TensorDataset(torch.stack(tensors), torch.tensor(labels, dtype=torch.float))
            loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # NOTE(review): everything below is synchronous CPU work with no
            # await, so it blocks the event loop until training completes.
            # Train a copy so a mid-training failure cannot leave the serving
            # model half-updated or stuck in train() mode.
            model = copy.deepcopy(self._model).to(device)
            pre_acc = self._accuracy(model, loader, device)
            epoch_losses = self._train_head(model, loader, lr, epochs, device)
            for epoch, mean_loss in enumerate(epoch_losses, start=1):
                logger.info(f"CNN incremental epoch {epoch}/{epochs}, loss={mean_loss:.4f}")
            post_acc = self._accuracy(model, loader, device)  # leaves model in eval()
            delta = post_acc - pre_acc

            # NOTE(review): the update is kept and persisted even when delta < 0.
            self._model = model
            torch.save(model.state_dict(), self._weights_path)
            logger.info(f"CNN incremental: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")
            return round(delta, 4)
        except Exception as e:
            logger.error(f"CNN incremental update failed: {e}")
            return None

    async def _capture_samples(self, samples, transform):
        """Screenshot each ``(url, label)`` pair; return ``(tensors, labels)`` for successes."""
        tensors: List[torch.Tensor] = []
        labels: List[float] = []
        for url, label in samples:
            try:
                screenshot_bytes = await self._capture_screenshot(url)
                if screenshot_bytes:
                    img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")
                    tensors.append(transform(img))
                    labels.append(float(label))
            except Exception as e:
                logger.warning(f"Screenshot capture failed for {url}: {e}")
        return tensors, labels

    @staticmethod
    def _replay_samples(buf_path: Path, transform):
        """Draw a random ~20% subset of the replay buffer as ``(tensors, labels)``.

        The buffer file is expected to hold a dict with ``"paths"`` (image
        file paths) and ``"labels"`` lists. A missing file yields empty lists.
        Unreadable images and entries without a label are skipped.
        """
        tensors: List[torch.Tensor] = []
        labels: List[float] = []
        if not buf_path.exists():
            return tensors, labels
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # Only point this at a trusted, locally written buffer file.
            buf_data = torch.load(buf_path, map_location="cpu", weights_only=False)
            buf_paths = buf_data.get("paths", [])
            buf_labels = buf_data.get("labels", [])
            replay_count = max(1, len(buf_paths) // 5)
            indices = random.sample(range(len(buf_paths)), min(replay_count, len(buf_paths)))
            for idx in indices:
                try:
                    with Image.open(buf_paths[idx]) as src:
                        tensor = transform(src.convert("RGB"))
                    # Read the label before appending anything, so a missing
                    # label cannot leave tensors and labels misaligned.
                    label = float(buf_labels[idx])
                except Exception:
                    continue
                tensors.append(tensor)
                labels.append(label)
        except Exception as e:
            logger.warning(f"CNN replay buffer load failed: {e}")
        return tensors, labels

    @staticmethod
    def _accuracy(model: torch.nn.Module, loader, device: torch.device) -> float:
        """Return the fraction of samples whose output, thresholded at 0.5, equals the label.

        Puts ``model`` in eval mode and leaves it there. Assumes the model
        outputs probabilities (one value per sample), as the 0.5 threshold
        implies.
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for bx, by in loader:
                bx, by = bx.to(device), by.to(device)
                # reshape(-1), not squeeze(): a batch of one must stay 1-D.
                preds = (model(bx).reshape(-1) >= 0.5).float()
                correct += (preds == by).sum().item()
                total += by.numel()
        return correct / total if total else 0.0

    @staticmethod
    def _train_head(
        model: torch.nn.Module,
        loader,
        lr: float,
        epochs: int,
        device: torch.device,
    ) -> List[float]:
        """Train only ``model.backbone.fc`` with BCE loss; return each epoch's mean loss.

        Only head parameters with ``requires_grad`` are optimized. AdamW raises
        ``ValueError`` if there are none.
        """
        import torch.nn as nn
        from torch.optim import AdamW

        head = model.backbone.fc
        head_params = [p for p in head.parameters() if p.requires_grad]
        optimizer = AdamW(head_params, lr=lr)
        loss_fn = nn.BCELoss()

        # Keep the frozen backbone in eval mode so its BatchNorm running
        # statistics are not overwritten by these small batches; only the head
        # runs in training mode.
        model.eval()
        head.train()

        epoch_losses: List[float] = []
        for _ in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                bx, by = bx.to(device), by.to(device)
                optimizer.zero_grad()
                # reshape(-1) keeps a size-1 batch as shape (1,), matching `by`.
                # squeeze() would produce a 0-d tensor that BCELoss rejects.
                loss = loss_fn(model(bx).reshape(-1), by)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            epoch_losses.append(total_loss / len(loader))
        return epoch_losses

    async def _capture_screenshot(self, url: str) -> Optional[bytes]:
        """Capture a 1280x800 PNG screenshot of ``url`` with headless Chromium.

        Returns None on any failure. The browser is closed even when
        navigation or capture raises.
        """
        try:
            from playwright.async_api import async_playwright

            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=True)
                try:
                    page = await browser.new_page(viewport={"width": 1280, "height": 800})
                    # Block heavy resources.
                    # NOTE(review): this also drops images (e.g. logos) from the
                    # training screenshot. Confirm it matches how inference
                    # screenshots are captured.
                    await page.route("**/*.{png,jpg,jpeg,gif,svg,mp4,webm,ogg,woff,woff2,ttf,eot}",
                                     lambda route: route.abort())
                    await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                    return await page.screenshot(type="png")
                finally:
                    await browser.close()
        except Exception as e:
            logger.warning(f"Screenshot capture failed: {e}")
            return None

    def is_loaded(self) -> bool:
        """Return whether a model has been loaded (by load, lazy load in predict, or reload)."""
        return self._loaded