# ============================================================
# PhishGuard AI - retraining_service.py
# Incremental retraining service for all 3 ML models.
#
# Receives labeled feedback samples from the Chrome extension.
# Runs parallel incremental updates for BERT, GNN, and CNN.
# Tracks model version and accuracy deltas.
# Supports hot-reload of all models without server restart.
# ============================================================

from __future__ import annotations

import asyncio
import inspect
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger("phishguard.retrain")

DATA_DIR = Path(__file__).parent / "data"
MODEL_VERSION_PATH = DATA_DIR / "model_version.json"
BERT_WEIGHTS_DIR = Path(__file__).parent / "bert_weights"

# Minimum number of validated feedback samples before a retrain is attempted.
MIN_VALID_SAMPLES = 10
# Minimum number of labeled (url, label) pairs required to run model updates.
MIN_LABELED_PAIRS = 5

# Accepted values for FeedbackRecord.verdict / FeedbackRecord.user_feedback.
_VALID_VERDICTS = ("phishing", "safe")
_VALID_FEEDBACK = ("correct", "incorrect")


@dataclass
class FeedbackRecord:
    """A single feedback record from the Chrome extension."""

    url: str
    verdict: str  # "phishing" or "safe" (the model's prediction)
    confidence: float = 0.0
    tier_used: int = 0  # detection tier that produced the verdict; 4 routes to the CNN
    heuristic_score: int = 0
    signals: List[str] = field(default_factory=list)
    user_feedback: Optional[str] = None  # "correct" or "incorrect" (user's judgement of verdict)
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    url_hash: str = ""
    session_id: str = ""


@dataclass
class RetrainResult:
    """Result from a retraining run."""

    status: str  # "success", "skipped", "error"
    models_updated: List[str] = field(default_factory=list)
    samples_used: int = 0
    duration_seconds: float = 0.0
    # Per-model value returned by incremental_update; None when the model was
    # not updated (no samples, returned None, or raised).
    accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict)
    next_retrain_hint: Dict = field(default_factory=dict)


class RetrainingService:
    """
    Orchestrates incremental retraining for all 3 ML models.

    Called by POST /retrain endpoint.
    """

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        cnn_inference,
    ) -> None:
        """
        Args:
            bert_classifier: used via ``incremental_update(pairs)`` and
                ``load_local(path)`` in this module.
            gnn_inference: used via ``incremental_update(pairs)`` and ``reload()``.
            cnn_inference: used via ``incremental_update(pairs)`` (sync or async)
                and ``reload()``.
        """
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._cnn = cnn_inference
        self._model_version = self._load_version()

    # ------------------------------------------------------------------
    # Version bookkeeping
    # ------------------------------------------------------------------

    def _load_version(self) -> int:
        """Load the current model version from disk (0 if missing or unreadable)."""
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        if not MODEL_VERSION_PATH.exists():
            return 0
        try:
            data = json.loads(MODEL_VERSION_PATH.read_text(encoding="utf-8"))
            # Coerce so a stringly-typed version cannot break the later "+ 1".
            return int(data.get("version", 0))
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            logger.warning(
                "Could not read model version from %s: %s", MODEL_VERSION_PATH, exc
            )
            return 0

    def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
        """Persist version+1 with the given accuracy deltas, then bump it in memory.

        The file is written to a temporary sibling and atomically swapped in so a
        crash mid-write cannot leave a truncated file (which would otherwise
        reset the version to 0 on the next load). The in-memory version only
        advances once the write has succeeded.
        """
        new_version = self._model_version + 1
        data = {
            "version": new_version,
            "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "accuracy": accuracy_delta,
        }
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = MODEL_VERSION_PATH.with_name(MODEL_VERSION_PATH.name + ".tmp")
        tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp_path.replace(MODEL_VERSION_PATH)
        self._model_version = new_version

    @property
    def model_version(self) -> int:
        """Current in-memory model version."""
        return self._model_version

    def get_version_info(self) -> dict:
        """Get current model version info for GET /model_version.

        Returns the on-disk version record when readable, otherwise a fallback
        built from the in-memory version.
        """
        if MODEL_VERSION_PATH.exists():
            try:
                return json.loads(MODEL_VERSION_PATH.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                logger.warning(
                    "Could not read model version info from %s: %s",
                    MODEL_VERSION_PATH,
                    exc,
                )
        return {
            "version": self._model_version,
            "updated_at": None,
            "accuracy": {},
        }

    # ------------------------------------------------------------------
    # Retraining
    # ------------------------------------------------------------------

    async def retrain(
        self,
        samples: List[FeedbackRecord],
    ) -> RetrainResult:
        """
        Perform incremental retraining on all models.

        Steps:
          1. Validate samples (min 10, feedback/verdict/URL checks)
          2. Convert to (url, label) pairs; Tier 4 samples also feed the CNN
          3. Run BERT + GNN updates in parallel (default thread-pool executor)
          4. Run CNN update if Tier 4 samples exist
          5. Record accuracy_delta for each model
          6. Increment model version (only if at least one model updated)
          7. Hot-reload updated models

        Returns:
            RetrainResult with status "skipped" (too few samples / nothing
            updated), "success", or "error" (unexpected failure; logged).
        """
        start_time = time.time()

        # 1. Validate
        valid_samples = self._validate_samples(samples)
        if len(valid_samples) < MIN_VALID_SAMPLES:
            return RetrainResult(
                status="skipped",
                samples_used=len(valid_samples),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": MIN_VALID_SAMPLES - len(valid_samples),
                },
            )

        # 2. Convert to (url, label) pairs
        url_label_pairs, tier4_pairs = self._build_training_pairs(valid_samples)
        if len(url_label_pairs) < MIN_LABELED_PAIRS:
            # Defensive: validation already guarantees every sample is labelable.
            return RetrainResult(
                status="skipped",
                samples_used=len(url_label_pairs),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": MIN_LABELED_PAIRS - len(url_label_pairs),
                },
            )

        models_updated: List[str] = []
        accuracy_delta: Dict[str, Optional[float]] = {}

        try:
            # 3. BERT + GNN in parallel on worker threads. Each job gets its own
            # copy of the pair list so one trainer shuffling/mutating its input
            # cannot race with the other.
            loop = asyncio.get_running_loop()
            bert_result, gnn_result = await asyncio.gather(
                loop.run_in_executor(
                    None, self._bert.incremental_update, list(url_label_pairs)
                ),
                loop.run_in_executor(
                    None, self._gnn.incremental_update, list(url_label_pairs)
                ),
                return_exceptions=True,
            )
            self._record_delta("bert", bert_result, accuracy_delta, models_updated)
            self._record_delta("gnn", gnn_result, accuracy_delta, models_updated)

            # 4. CNN update (only if Tier 4 samples exist)
            if tier4_pairs:
                try:
                    cnn_result: Any = await self._run_cnn_update(tier4_pairs)
                except Exception as exc:  # isolate CNN failures from BERT/GNN
                    cnn_result = exc
                self._record_delta("cnn", cnn_result, accuracy_delta, models_updated)
            else:
                accuracy_delta["cnn"] = None

            # 5./6. Update version
            if models_updated:
                self._save_version(accuracy_delta)

            # 7. Hot-reload
            await self._hot_reload(models_updated)

            return RetrainResult(
                status="success" if models_updated else "skipped",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.time() - start_time, 2),
                accuracy_delta=accuracy_delta,
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": MIN_VALID_SAMPLES,
                },
            )

        except Exception as exc:
            logger.exception("Retraining failed: %s", exc)
            return RetrainResult(
                status="error",
                duration_seconds=round(time.time() - start_time, 2),
                accuracy_delta=accuracy_delta,
            )

    async def _run_cnn_update(self, pairs: List[Tuple[str, int]]) -> Any:
        """Run the CNN update, awaiting its result only when it is awaitable.

        Supports both ``async def incremental_update`` and a synchronous one
        (which previously failed with TypeError on ``await``). A synchronous
        implementation runs on the event-loop thread, as the call itself did
        before.
        """
        result = self._cnn.incremental_update(pairs)
        if inspect.isawaitable(result):
            result = await result
        return result

    @staticmethod
    def _record_delta(
        name: str,
        result: Any,
        accuracy_delta: Dict[str, Optional[float]],
        models_updated: List[str],
    ) -> None:
        """Store one model's update outcome in ``accuracy_delta``/``models_updated``.

        A model counts as updated only when it returned a non-None value.
        BaseException is checked (not just Exception) because
        ``gather(return_exceptions=True)`` can return e.g. CancelledError.
        """
        if isinstance(result, BaseException):
            logger.error("%s update error: %s", name.upper(), result)
            accuracy_delta[name] = None
        elif result is not None:
            accuracy_delta[name] = result
            models_updated.append(name)
        else:
            accuracy_delta[name] = None

    @staticmethod
    def _true_label(sample: FeedbackRecord) -> Optional[int]:
        """Return the ground-truth label (1 = phishing, 0 = safe) or None.

        The label is derived from the model's verdict and the user's judgement
        of it: "correct" keeps the verdict, "incorrect" flips it. Unknown
        verdict or feedback values yield None rather than a guessed label.
        """
        if sample.verdict not in _VALID_VERDICTS:
            return None
        if sample.user_feedback not in _VALID_FEEDBACK:
            return None
        predicted_phishing = sample.verdict == "phishing"
        if sample.user_feedback == "correct":
            return int(predicted_phishing)
        return int(not predicted_phishing)

    def _build_training_pairs(
        self, samples: List[FeedbackRecord]
    ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Convert samples to (url, label) pairs for all models and for Tier 4 (CNN)."""
        all_pairs: List[Tuple[str, int]] = []
        tier4_pairs: List[Tuple[str, int]] = []
        for sample in samples:
            label = self._true_label(sample)
            if label is None:
                continue
            pair = (sample.url, label)
            all_pairs.append(pair)
            if sample.tier_used == 4:
                tier4_pairs.append(pair)
        return all_pairs, tier4_pairs

    def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
        """Keep only samples that can be turned into a trustworthy training label.

        A sample is kept when it has user feedback of "correct"/"incorrect",
        a verdict of "phishing"/"safe" (anything else would otherwise be
        silently labeled as safe), and an http(s) URL.
        """
        valid = []
        for s in samples:
            # Must have user feedback (None / "" are not in the accepted set)
            if s.user_feedback not in _VALID_FEEDBACK:
                continue
            # Must have a known verdict, otherwise the derived label is a guess
            if s.verdict not in _VALID_VERDICTS:
                continue
            # Must have a valid URL
            if not s.url or not s.url.startswith(("http://", "https://")):
                continue
            valid.append(s)
        return valid

    async def _hot_reload(self, models: List[str]) -> None:
        """Hot-reload updated models in-memory; failures are logged, not raised."""
        if "bert" in models:
            try:
                if BERT_WEIGHTS_DIR.exists():
                    self._bert.load_local(BERT_WEIGHTS_DIR)
                    logger.info("BERT hot-reloaded")
                else:
                    logger.warning(
                        "BERT weights not found at %s; skipping hot-reload",
                        BERT_WEIGHTS_DIR,
                    )
            except Exception as exc:
                logger.error("BERT hot-reload failed: %s", exc)

        for name, model in (("gnn", self._gnn), ("cnn", self._cnn)):
            if name not in models:
                continue
            try:
                model.reload()
                logger.info("%s hot-reloaded", name.upper())
            except Exception as exc:
                logger.error("%s hot-reload failed: %s", name.upper(), exc)