Spaces:
Running
Running
| # ============================================================ | |
| # PhishGuard AI - gnn/gnn_inference.py | |
| # GNN inference wrapper for main.py. | |
| # Loads model once at startup, reuses for every request. | |
| # Supports: predict, hot-reload, incremental_update. | |
| # ============================================================ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import random | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import torch | |
logger = logging.getLogger("phishguard.gnn.inference")

# Put this package's directory and its parent first on sys.path so the bare
# imports that follow (domain_graph_builder, gnn_model) resolve. The backend
# directory ends up at index 0, the gnn directory right after it.
_GNN_DIR = Path(__file__).parent
_BACKEND_DIR = _GNN_DIR.parent
sys.path[:0] = [str(_BACKEND_DIR), str(_GNN_DIR)]
| from domain_graph_builder import DomainGraphBuilder | |
| from gnn_model import load_gnn_model, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM | |
| if PYGEOM_AVAILABLE: | |
| from gnn_model import PhishGNN | |
# Default trained-weights file, stored next to this module.
MODEL_PATH = _GNN_DIR.joinpath("gnn_weights.pt")
# Rolling replay buffer read and written by GNNInference.incremental_update.
REPLAY_BUFFER_PATH = _BACKEND_DIR.joinpath("data", "gnn_replay_buffer.pt")
class GNNInference:
    """
    GNN inference wrapper with hot-reload and incremental update support.

    The model is loaded once (explicitly via ``load`` or lazily on the first
    ``predict``) and reused for every request. ``reload`` swaps in new weights
    without a restart; ``incremental_update`` fine-tunes on feedback samples.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        """Create an unloaded wrapper.

        Args:
            weights_path: weights file used by ``load``/``reload`` when called
                without a path, and the file ``incremental_update`` saves to.
                A ``str`` is accepted too. Defaults to MODEL_PATH.
        """
        self._weights_path = Path(weights_path) if weights_path else MODEL_PATH
        self._model: Optional[torch.nn.Module] = None
        self._builder = DomainGraphBuilder()
        self._loaded = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _graph_tensors(graph: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a builder graph dict into ``(x, edge_index)`` tensors.

        ``graph["features"]`` becomes a float tensor. ``graph["edges"]`` is
        transposed into a ``(2, E)`` long tensor (presumably a list of
        ``(src, dst)`` pairs — confirm against DomainGraphBuilder). With no
        edges, one self-loop per node is used so ``edge_index`` is never empty.
        """
        x = torch.tensor(graph["features"], dtype=torch.float)
        edges = graph["edges"]
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            n = x.size(0)
            edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
        return x, edge_index

    @staticmethod
    def _atomic_torch_save(obj, path: Path) -> None:
        """``torch.save`` to a sibling temp file, then rename it over ``path``.

        A crash mid-write leaves the previous file intact instead of a
        truncated one.
        """
        path = Path(path)
        tmp = path.with_name(path.name + ".tmp")
        torch.save(obj, tmp)
        os.replace(tmp, path)

    @staticmethod
    def _accuracy(model: torch.nn.Module, dataset: List[dict], device: torch.device) -> float:
        """Return the fraction of ``dataset`` items classified correctly.

        Puts ``model`` in eval mode. A prediction is 1 when the squeezed model
        output is >= 0.5, else 0. ``dataset`` must be non-empty.
        """
        model.eval()
        correct = 0
        with torch.no_grad():
            for item in dataset:
                out = model(item["x"].to(device), item["edge_index"].to(device))
                pred = 1 if out.squeeze().item() >= 0.5 else 0
                correct += int(pred == int(item["y"].item()))
        return correct / len(dataset)

    def _build_feedback_graphs(self, samples: List[Tuple[str, int]], chunk_size: int = 4) -> List[dict]:
        """Build one labelled training graph per chunk of same-label URLs.

        URLs with label 1 and label 0 are grouped separately and split into
        chunks of up to ``chunk_size``. Samples with any other label are
        ignored. Each item is ``{"x", "edge_index", "y"}``, where ``y`` is a
        1-element float tensor.
        """
        builder = DomainGraphBuilder()
        graphs: List[dict] = []
        phish = [url for url, label in samples if label == 1]
        legit = [url for url, label in samples if label == 0]
        for urls, label in ((phish, 1), (legit, 0)):
            for start in range(0, len(urls), chunk_size):
                x, ei = self._graph_tensors(builder.build_graph(urls[start:start + chunk_size]))
                graphs.append({
                    "x": x, "edge_index": ei,
                    "y": torch.tensor([float(label)]),
                })
        return graphs

    @staticmethod
    def _sample_replay(buf_path: Path) -> List[dict]:
        """Return a random ~20% sample of the on-disk replay buffer.

        At least one item is returned when the buffer is non-empty. Returns []
        when the file is missing or cannot be read (a warning is logged).
        """
        if not buf_path.exists():
            return []
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects. This
            # is only safe while this process is the sole writer of the file.
            all_replay = torch.load(buf_path, map_location="cpu", weights_only=False)
            replay_count = max(1, len(all_replay) // 5)  # 20% of the buffer
            return random.sample(all_replay, min(replay_count, len(all_replay)))
        except Exception as e:
            logger.warning(f"Replay buffer load failed: {e}")
            return []

    @classmethod
    def _update_replay_buffer(cls, buf_path: Path, new_graphs: List[dict], max_size: int = 500) -> None:
        """Append ``new_graphs`` to the replay buffer, keeping the newest ``max_size`` items.

        Best-effort: any failure is logged as a warning and not raised.
        """
        try:
            existing = []
            if buf_path.exists():
                # NOTE(review): same pickle caveat as _sample_replay.
                existing = torch.load(buf_path, map_location="cpu", weights_only=False)
            combined = (existing + new_graphs)[-max_size:]
            buf_path.parent.mkdir(parents=True, exist_ok=True)
            cls._atomic_torch_save(combined, buf_path)
        except Exception as e:
            logger.warning(f"Replay buffer update failed: {e}")

    # --------------------------------------------------------------- public API

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load the GNN model from a weights file.

        If the file does not exist, ``load_gnn_model(None)`` is called instead;
        what that returns is decided by gnn_model.

        Returns:
            True when a model object was obtained.
        """
        path = Path(weights_path) if weights_path else self._weights_path
        self._model = load_gnn_model(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        if self._loaded:
            logger.info(f"GNN model loaded from {path}")
        return self._loaded

    def predict(self, url: str, related_urls: Optional[List[str]] = None) -> float:
        """
        Predict the phishing probability for ``url``.

        Loads the model first if needed. A single URL yields a single-node
        graph; with ``related_urls`` a graph over ``[url, *related_urls]`` is
        built. The original code labels the single-URL case the "MLP fallback
        path"; presumably that choice is made inside the model — confirm.

        Returns:
            P_gnn ∈ [0, 1] as returned by ``model.predict_proba``, rounded to
            4 decimals, or 0.5 (neutral) when no model could be loaded.
        """
        if not self._loaded:
            self.load()
        if self._model is None:
            return 0.5  # Neutral when model unavailable

        urls = [url, *(related_urls or ())]
        if len(urls) == 1:
            graph = self._builder.build_single_node_graph(url)
        else:
            graph = self._builder.build_graph(urls)

        x, edge_index = self._graph_tensors(graph)
        with torch.no_grad():  # inference only; no autograd graph needed
            prob = self._model.predict_proba(x, edge_index)
        return round(float(prob), 4)

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload the model with new weights (no server restart needed).

        The current model is kept when loading fails. ``self._weights_path`` is
        not changed, so ``incremental_update`` keeps saving to the original
        path. NOTE(review): confirm that is intended.

        Returns:
            True when a new model was swapped in.
        """
        path = Path(weights_path) if weights_path else self._weights_path
        new_model = load_gnn_model(str(path))
        if new_model is None:
            logger.warning(f"GNN hot-reload failed from {path}")
            return False
        self._model = new_model
        self._loaded = True
        logger.info(f"GNN model hot-reloaded from {path}")
        return True

    def incremental_update(
        self,
        samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 5e-4,
        epochs: int = 5,
    ) -> Optional[float]:
        """
        Fine-tune on feedback samples mixed with a replay-buffer sample.

        Args:
            samples: (url, label) pairs with label 1 = phishing and
                0 = legitimate; other labels are ignored. At least 5 required.
            replay_buffer_path: replay buffer file (default REPLAY_BUFFER_PATH).
            lr: Adam learning rate (weight decay is fixed at 1e-4).
            epochs: passes over the merged dataset.

        Returns:
            Post-update minus pre-update accuracy, rounded to 4 decimals, or
            None when the update was skipped or failed. Both accuracies are
            measured on the training data itself, so this is a training-set
            delta, not a held-out score.

        Training runs on a deep copy of the serving model. The copy replaces
        the serving model only after training finishes, so a failure during
        training leaves the current model (including its eval/train mode)
        untouched. If saving the weights fails after the swap, the in-memory
        model keeps the update but None is returned.
        """
        if self._model is None:
            logger.warning("GNN not loaded, cannot incrementally update")
            return None
        if len(samples) < 5:
            logger.warning(f"Too few samples ({len(samples)}) for GNN update")
            return None

        try:
            import copy
            import torch.nn.functional as F

            device = torch.device("cpu")
            # Module.to() works in place and returns self, so train a private
            # copy instead of mutating the model that is serving requests.
            model = copy.deepcopy(self._model).to(device)

            new_graphs = self._build_feedback_graphs(samples)

            # Replay sample: ~20% of the *buffer* (not 20% of the merged set).
            # NOTE(review): the author described an 80/20 new/replay mix; with
            # a full 500-item buffer this adds ~100 replay graphs however few
            # new graphs there are. Confirm which ratio is intended.
            buf_path = Path(replay_buffer_path) if replay_buffer_path else REPLAY_BUFFER_PATH
            replay_graphs = self._sample_replay(buf_path)

            dataset = new_graphs + replay_graphs
            random.shuffle(dataset)
            if not dataset:
                return None

            pre_acc = self._accuracy(model, dataset, device)

            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
            model.train()
            for epoch in range(epochs):
                random.shuffle(dataset)
                total_loss = 0.0
                for item in dataset:
                    x = item["x"].to(device)
                    ei = item["edge_index"].to(device)
                    y = item["y"].to(device)
                    optimizer.zero_grad()
                    out = model(x, ei)
                    # BCE on raw output: assumes the model emits probabilities.
                    loss = F.binary_cross_entropy(out.squeeze(), y.squeeze())
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"GNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(dataset):.4f}")

            # _accuracy leaves the model in eval mode, ready to serve.
            post_acc = self._accuracy(model, dataset, device)
            delta = post_acc - pre_acc

            self._model = model
            self._atomic_torch_save(model.state_dict(), self._weights_path)
            logger.info(f"GNN incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")

            self._update_replay_buffer(buf_path, new_graphs)
            return round(delta, 4)
        except Exception as e:
            logger.error(f"GNN incremental update failed: {e}")
            return None

    def is_loaded(self) -> bool:
        """Return True if a model has been loaded or hot-reloaded."""
        return self._loaded
# -- Legacy compatibility functions ------------------------------------------
# Module-level GNNInference instance shared by the legacy helper functions.
_inference: GNNInference = GNNInference()
def analyze_url_with_gnn(url: str, related_urls: Optional[List[str]] = None) -> dict:
    """Legacy wrapper for backward compatibility.

    Scores ``url`` with the shared GNNInference instance, loading the model
    on first use.

    Returns:
        On success: ``gnn_phish_prob`` (the float from ``predict``),
        ``node_count`` (1 + number of related URLs), ``edge_count`` and
        ``graph_suspicious`` (``prob > 0.6``). If no model can be loaded:
        ``gnn_phish_prob`` None, ``tier3_status`` "model_not_loaded", zero
        counts and ``graph_suspicious`` False.
    """
    # is_loaded is a method and must be called. The bound method object itself
    # is always truthy, which would silently skip both checks below.
    if not _inference.is_loaded():
        _inference.load()
    if not _inference.is_loaded():
        return {
            "gnn_phish_prob": None,
            "tier3_status": "model_not_loaded",
            "node_count": 0,
            "edge_count": 0,
            "graph_suspicious": False,
        }
    prob = _inference.predict(url, related_urls)
    return {
        "gnn_phish_prob": prob,
        "node_count": 1 + len(related_urls or []),
        # NOTE(review): hard-coded; predict() does not expose the graph it built.
        "edge_count": 0,
        "graph_suspicious": prob > 0.6,
    }
def reload_model(new_weights_path: str = None) -> bool:
    """Hot-reload the shared GNN model.

    An empty or missing ``new_weights_path`` reloads from the instance's
    default weights path. Returns True when the reload succeeded.
    """
    target = None
    if new_weights_path:
        target = Path(new_weights_path)
    return _inference.reload(target)
def is_model_loaded() -> bool:
    """Return True if the shared GNN inference instance has a loaded model."""
    # Call the method: returning the bound method itself would always be truthy.
    return _inference.is_loaded()