# phishguard-api / cnn_inference.py
# Uploaded by prashanth135 ("Upload 38 files", revision bebe233, verified)
# ============================================================
# PhishGuard AI - cnn/cnn_inference.py
# CNN inference wrapper for Tier 4 visual analysis.
# Supports: predict, hot-reload, incremental_update.
# ============================================================
from __future__ import annotations
import io
import random
import logging
from pathlib import Path
from typing import List, Optional, Tuple
import torch
from PIL import Image
# Module-wide logger for the CNN inference tier.
logger = logging.getLogger("phishguard.cnn.inference")

# Filesystem layout: this module lives in CNN_DIR; BACKEND_DIR is its parent.
CNN_DIR = Path(__file__).parent
BACKEND_DIR = CNN_DIR.parent
# Default weights file loaded by CNNInference and overwritten by incremental updates.
WEIGHTS_PATH = CNN_DIR.joinpath("cnn_weights.pt")
# Default replay buffer sampled during incremental updates.
REPLAY_BUFFER_PATH = BACKEND_DIR.joinpath("data", "cnn_replay_buffer.pt")
class CNNInference:
    """CNN inference wrapper with hot-reload and incremental update.

    The underlying model comes from ``cnn_model.load_cnn`` and is loaded lazily
    on the first :meth:`predict` call. Public entry points degrade gracefully:
    :meth:`predict` returns a neutral 0.5, :meth:`reload` returns False and
    :meth:`incremental_update` returns None when something goes wrong.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        """Create an unloaded wrapper.

        Args:
            weights_path: Weights file used by :meth:`load` / :meth:`reload` and
                overwritten by :meth:`incremental_update`. Defaults to
                ``WEIGHTS_PATH``; ``str`` values are converted to ``Path``.
        """
        self._weights_path = Path(weights_path) if weights_path else WEIGHTS_PATH
        self._model = None
        self._loaded = False

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load the CNN model.

        Args:
            weights_path: Overrides the constructor's weights file for this call.
                If the file does not exist, ``load_cnn(None)`` is called instead;
                what that produces is decided by ``cnn_model.load_cnn``.

        Returns:
            True if ``load_cnn`` returned a model, False otherwise.
        """
        from cnn_model import load_cnn

        path = Path(weights_path) if weights_path else self._weights_path
        self._model = load_cnn(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        return self._loaded

    def predict(self, screenshot_bytes: bytes) -> float:
        """
        Predict phishing probability from screenshot bytes.
        Returns P_cnn ∈ [0,1].

        Lazily loads the model on first use. Any failure (model load, the
        ``cnn_model`` import, preprocessing or inference) is logged and yields
        the neutral score 0.5 instead of raising.
        """
        if not self._loaded:
            try:
                self.load()
            except Exception as e:
                logger.error(f"CNN load failed: {e}")
                return 0.5
        if self._model is None:
            return 0.5
        try:
            from cnn_model import preprocess_screenshot

            tensor = preprocess_screenshot(screenshot_bytes)
            return self._model.predict_proba(tensor)
        except Exception as e:
            logger.error(f"CNN predict failed: {e}")
            return 0.5

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload model with new weights.

        The current model is replaced only when the weights file exists and
        ``load_cnn`` returns a model; otherwise the current model is kept.

        Returns:
            True if the model was swapped, False otherwise (never raises).
        """
        path = Path(weights_path) if weights_path else self._weights_path
        if not path.exists():
            # A missing file must not replace a working model with whatever
            # load_cnn falls back to.
            logger.warning(f"CNN reload skipped: weights not found at {path}")
            return False
        try:
            from cnn_model import load_cnn

            new_model = load_cnn(str(path))
        except Exception as e:
            logger.error(f"CNN reload from {path} failed: {e}")
            return False
        if new_model is None:
            return False
        self._model = new_model
        self._loaded = True
        logger.info(f"CNN hot-reloaded from {path}")
        return True

    async def incremental_update(
        self,
        tier4_samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 1e-4,
        epochs: int = 3,
    ) -> Optional[float]:
        """
        Incremental update on Tier 4 feedback samples.

        Re-captures screenshots via Playwright, adds a random ~20% sample of the
        replay buffer, and fine-tunes only ``model.backbone.fc`` with AdamW +
        BCELoss. Training happens on a copy of the current model; the copy is
        saved to the weights file and swapped in only if everything succeeds,
        so a failure leaves the serving model untouched.

        Args:
            tier4_samples: ``(url, label)`` pairs; labels are cast to float.
            replay_buffer_path: Replay buffer file (defaults to
                ``REPLAY_BUFFER_PATH``).
            lr: AdamW learning rate for the head.
            epochs: Number of passes over the combined samples.

        Returns:
            ``post_acc - pre_acc`` rounded to 4 places, or None if there are no
            Tier 4 samples, no model, fewer than 5 usable samples, or any error.
            NOTE: both accuracies are measured on the same samples used for
            training (no held-out split), so this is a training-set delta.
        """
        if not tier4_samples:
            logger.info("No Tier 4 samples — skipping CNN update")
            return None
        if self._model is None:
            logger.warning("CNN not loaded, cannot update")
            return None
        try:
            import copy

            import torch.nn as nn
            import torchvision.transforms as T
            from torch.optim import AdamW
            from torch.utils.data import DataLoader, TensorDataset

            device = torch.device("cpu")
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            tensors, labels = await self._capture_samples(tier4_samples, transform)
            buf_path = Path(replay_buffer_path) if replay_buffer_path else REPLAY_BUFFER_PATH
            replay_tensors, replay_labels = self._sample_replay_buffer(buf_path, transform)
            tensors.extend(replay_tensors)
            labels.extend(replay_labels)

            if len(tensors) < 5:
                logger.warning(f"Too few CNN samples ({len(tensors)}), skipping update")
                return None

            dataset = TensorDataset(
                torch.stack(tensors), torch.tensor(labels, dtype=torch.float)
            )
            loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Train a copy so a failure mid-way cannot leave the serving model
            # partially trained or stuck in train mode.
            model = copy.deepcopy(self._model).to(device)
            pre_acc = self._accuracy(model, loader, device)

            # Train (head only — backbone stays frozen)
            head = model.backbone.fc
            head_params = [p for p in head.parameters() if p.requires_grad]
            optimizer = AdamW(head_params, lr=lr)
            loss_fn = nn.BCELoss()
            # Keep the backbone in eval mode so its BatchNorm running statistics
            # are not updated; only the head switches to train mode.
            model.eval()
            head.train()
            for epoch in range(epochs):
                total_loss = 0.0
                for bx, by in loader:
                    bx, by = bx.to(device), by.to(device)
                    optimizer.zero_grad()
                    # reshape(-1) (not squeeze) keeps a size-1 final batch 1-D so
                    # it matches the target shape BCELoss requires.
                    out = model(bx).reshape(-1)
                    loss = loss_fn(out, by)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"CNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")

            post_acc = self._accuracy(model, loader, device)  # leaves model in eval mode
            delta = post_acc - pre_acc

            # Persist before swapping so disk and memory agree; write to a temp
            # file and rename so a crash mid-write cannot corrupt the old weights.
            weights_path = Path(self._weights_path)
            tmp_path = weights_path.with_name(weights_path.name + ".tmp")
            torch.save(model.state_dict(), tmp_path)
            tmp_path.replace(weights_path)
            self._model = model

            logger.info(f"CNN incremental: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")
            return round(delta, 4)
        except Exception as e:
            logger.error(f"CNN incremental update failed: {e}")
            return None

    async def _capture_samples(
        self, samples: List[Tuple[str, int]], transform
    ) -> Tuple[list, list]:
        """Screenshot each feedback URL; return paired (image tensors, float labels)."""
        tensors: list = []
        labels: list = []
        for url, label in samples:
            try:
                y = float(label)  # convert first so tensors/labels never drift apart
                screenshot_bytes = await self._capture_screenshot(url)
                if not screenshot_bytes:
                    continue
                with Image.open(io.BytesIO(screenshot_bytes)) as im:
                    img = im.convert("RGB")
                tensors.append(transform(img))
                labels.append(y)
            except Exception as e:
                logger.warning(f"Screenshot capture failed for {url}: {e}")
        return tensors, labels

    @staticmethod
    def _sample_replay_buffer(buf_path: Path, transform) -> Tuple[list, list]:
        """Load a random 20% (at least one) of the replay buffer's images.

        The buffer is read as a dict with parallel ``"paths"`` and ``"labels"``
        lists; only indices present in both are sampled. Unreadable entries are
        skipped, and a missing or unreadable buffer yields empty lists.
        """
        tensors: list = []
        labels: list = []
        if not buf_path.exists():
            return tensors, labels
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects; this
            # is only safe while the buffer file is produced locally/trusted.
            buf_data = torch.load(buf_path, map_location="cpu", weights_only=False)
            buf_paths = buf_data.get("paths", [])
            buf_labels = buf_data.get("labels", [])
            n = min(len(buf_paths), len(buf_labels))
            if n == 0:
                return tensors, labels
            replay_count = max(1, n // 5)
            for idx in random.sample(range(n), replay_count):
                try:
                    y = float(buf_labels[idx])
                    with Image.open(buf_paths[idx]) as im:
                        img = im.convert("RGB")
                    tensors.append(transform(img))
                    labels.append(y)
                except Exception:
                    continue
        except Exception as e:
            logger.warning(f"CNN replay buffer load failed: {e}")
        return tensors, labels

    @staticmethod
    def _accuracy(model, loader, device) -> float:
        """Fraction of samples whose output thresholded at 0.5 equals the label.

        Puts ``model`` in eval mode and runs without gradients.
        """
        model.eval()
        correct = 0
        with torch.no_grad():
            for bx, by in loader:
                bx, by = bx.to(device), by.to(device)
                preds = (model(bx).reshape(-1) >= 0.5).float()
                correct += (preds == by).sum().item()
        return correct / len(loader.dataset)

    async def _capture_screenshot(self, url: str) -> Optional[bytes]:
        """Capture a PNG screenshot of ``url`` using headless Chromium via Playwright.

        Returns None (after logging) on any failure, including a missing
        Playwright install or a navigation timeout (10 s).
        """
        try:
            from playwright.async_api import async_playwright

            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=True)
                try:
                    page = await browser.new_page(viewport={"width": 1280, "height": 800})
                    # Block heavy resources (images, media, fonts) to speed up loading.
                    # NOTE(review): this also strips logos/images from the rendered
                    # page; confirm inference-time screenshots are captured the same
                    # way, otherwise training and serving inputs differ.
                    await page.route(
                        "**/*.{png,jpg,jpeg,gif,svg,mp4,webm,ogg,woff,woff2,ttf,eot}",
                        lambda route: route.abort(),
                    )
                    await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                    return await page.screenshot(type="png")
                finally:
                    await browser.close()
        except Exception as e:
            logger.warning(f"Screenshot capture failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        """True once a model has been successfully loaded or hot-reloaded."""
        return self._loaded