Spaces:
Running
Running
| # ============================================================ | |
| # PhishGuard AI - retraining_service.py | |
| # Incremental retraining service for all 3 ML models. | |
| # | |
| # Receives labeled feedback samples from the Chrome extension. | |
| # Runs BERT and GNN incremental updates in parallel, then a CNN update for Tier 4 samples. | |
| # Tracks model version and accuracy deltas. | |
| # Supports hot-reload of all models without server restart. | |
| # ============================================================ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
# Module logger, named under the "phishguard" logger hierarchy.
logger: logging.Logger = logging.getLogger("phishguard.retrain")

# Directory (next to this module) that holds persisted retraining state.
DATA_DIR: Path = Path(__file__).parent.joinpath("data")

# JSON file recording the current model version, update timestamp and the
# accuracy deltas of the last successful retrain.
MODEL_VERSION_PATH: Path = DATA_DIR.joinpath("model_version.json")
| class FeedbackRecord: | |
| """A single feedback record from the Chrome extension.""" | |
| url: str | |
| verdict: str # "phishing" or "safe" | |
| confidence: float = 0.0 | |
| tier_used: int = 0 | |
| heuristic_score: int = 0 | |
| signals: List[str] = field(default_factory=list) | |
| user_feedback: Optional[str] = None # "correct" or "incorrect" | |
| timestamp: str = "" | |
| feedback_ts: Optional[str] = None | |
| url_hash: str = "" | |
| session_id: str = "" | |
| class RetrainResult: | |
| """Result from a retraining run.""" | |
| status: str # "success", "skipped", "error" | |
| models_updated: List[str] = field(default_factory=list) | |
| samples_used: int = 0 | |
| duration_seconds: float = 0.0 | |
| accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict) | |
| next_retrain_hint: Dict = field(default_factory=dict) | |
class RetrainingService:
    """
    Orchestrates incremental retraining for all 3 ML models.

    Called by the POST /retrain endpoint. BERT and GNN updates run concurrently
    in the event loop's default executor; the CNN update is awaited afterwards,
    and only when Tier 4 samples are present. When at least one model reports
    an update, the persisted model version is bumped and the updated models
    are hot-reloaded in place.
    """

    # Minimum number of validated samples before a retrain is attempted.
    MIN_VALID_SAMPLES: int = 10
    # Minimum number of labeled (url, label) pairs required to train.
    MIN_LABELED_PAIRS: int = 5
    # Feedback values that can be turned into a ground-truth label.
    _VALID_FEEDBACK: Tuple[str, ...] = ("correct", "incorrect")
    # URL prefixes accepted for training samples.
    _URL_PREFIXES: Tuple[str, ...] = ("http://", "https://")

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        cnn_inference,
    ) -> None:
        """
        Args:
            bert_classifier: model wrapper exposing ``incremental_update(pairs)``
                and ``load_local(path)``.
            gnn_inference: model wrapper exposing ``incremental_update(pairs)``
                and ``reload()``.
            cnn_inference: model wrapper whose ``incremental_update(pairs)`` is
                awaited (so presumably a coroutine function) and which exposes
                ``reload()``.
        """
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._cnn = cnn_inference
        self._model_version = self._load_version()

    def _load_version(self) -> int:
        """Load the current model version from disk; 0 if missing or unreadable."""
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        if not MODEL_VERSION_PATH.exists():
            return 0
        try:
            data = json.loads(MODEL_VERSION_PATH.read_text())
            # int() so a hand-edited "3" does not break the += 1 in _save_version.
            return int(data.get("version", 0))
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            logger.warning(
                "Could not read model version from %s (%s); starting at 0",
                MODEL_VERSION_PATH, exc,
            )
            return 0

    def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
        """Persist an incremented model version together with the accuracy deltas.

        The file is written to a temporary sibling and then atomically renamed
        over the target, so readers never see a half-written JSON file. The
        in-memory counter is only advanced after the write succeeds.

        Raises:
            OSError: if the file cannot be written.
            TypeError/ValueError: if ``accuracy_delta`` is not JSON-serializable.
        """
        new_version = self._model_version + 1
        payload = {
            "version": new_version,
            "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "accuracy": accuracy_delta,
        }
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = MODEL_VERSION_PATH.with_name(MODEL_VERSION_PATH.name + ".tmp")
        tmp_path.write_text(json.dumps(payload, indent=2))
        tmp_path.replace(MODEL_VERSION_PATH)
        self._model_version = new_version

    def model_version(self) -> int:
        """Return the in-memory model version counter."""
        # NOTE(review): kept as a plain method to preserve the visible interface;
        # confirm whether callers expect a @property instead.
        return self._model_version

    def get_version_info(self) -> dict:
        """Get current model version info for GET /model_version.

        Returns the persisted JSON object when it is readable and is a dict,
        otherwise a fallback built from the in-memory version.
        """
        if MODEL_VERSION_PATH.exists():
            try:
                data = json.loads(MODEL_VERSION_PATH.read_text())
            except (OSError, ValueError) as exc:
                logger.warning("Could not read %s: %s", MODEL_VERSION_PATH, exc)
            else:
                if isinstance(data, dict):
                    return data
        return {
            "version": self._model_version,
            "updated_at": None,
            "accuracy": {},
        }

    async def retrain(
        self,
        samples: List[FeedbackRecord],
    ) -> RetrainResult:
        """
        Perform incremental retraining on all models.

        Steps:
            1. Validate samples (trainable user feedback, http(s) URL) and
               require at least ``MIN_VALID_SAMPLES`` of them.
            2. Derive ground-truth labels from the feedback; collect Tier 4 pairs.
            3. Run BERT + GNN updates concurrently in the default executor.
            4. Await the CNN update if Tier 4 samples exist.
            5. Record accuracy_delta per model (None when an update failed or
               returned None).
            6. Persist an incremented model version if any model was updated.
            7. Hot-reload the updated models.

        Args:
            samples: feedback records received from the Chrome extension.

        Returns:
            RetrainResult with status "success", "skipped" or "error".
        """
        start_time = time.monotonic()  # monotonic: immune to wall-clock jumps

        # 1. Validate
        valid_samples = self._validate_samples(samples)
        if len(valid_samples) < self.MIN_VALID_SAMPLES:
            return RetrainResult(
                status="skipped",
                samples_used=len(valid_samples),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self.MIN_VALID_SAMPLES - len(valid_samples),
                },
            )

        # 2. Convert to (url, label) pairs
        url_label_pairs, tier4_pairs = self._build_training_pairs(valid_samples)
        if len(url_label_pairs) < self.MIN_LABELED_PAIRS:
            return RetrainResult(
                status="skipped",
                samples_used=len(url_label_pairs),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    # Report the actual shortfall, consistent with the check above.
                    "min_samples_needed": self.MIN_LABELED_PAIRS - len(url_label_pairs),
                },
            )

        models_updated: List[str] = []
        accuracy_delta: Dict[str, Optional[float]] = {}
        try:
            # 3. BERT + GNN in parallel. get_running_loop() is the correct call
            # inside a coroutine; get_event_loop() is deprecated here.
            loop = asyncio.get_running_loop()
            bert_outcome, gnn_outcome = await asyncio.gather(
                loop.run_in_executor(None, self._bert.incremental_update, url_label_pairs),
                loop.run_in_executor(None, self._gnn.incremental_update, url_label_pairs),
                return_exceptions=True,
            )
            self._record_outcome("bert", bert_outcome, accuracy_delta, models_updated)
            self._record_outcome("gnn", gnn_outcome, accuracy_delta, models_updated)

            # 4. CNN update (only if Tier 4 samples exist)
            if tier4_pairs:
                try:
                    cnn_outcome = await self._cnn.incremental_update(tier4_pairs)
                except Exception as exc:
                    cnn_outcome = exc
                self._record_outcome("cnn", cnn_outcome, accuracy_delta, models_updated)
            else:
                accuracy_delta["cnn"] = None

            # 5. Update version. A persistence failure must not discard the
            # in-memory model updates, so it is logged and the reload proceeds.
            if models_updated:
                try:
                    self._save_version(accuracy_delta)
                except (OSError, TypeError, ValueError):
                    logger.exception(
                        "Failed to persist model version to %s", MODEL_VERSION_PATH
                    )

            # 6. Hot-reload (no-op when nothing was updated)
            await self._hot_reload(models_updated)

            return RetrainResult(
                status="success" if models_updated else "skipped",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.monotonic() - start_time, 2),
                accuracy_delta=accuracy_delta,
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self.MIN_VALID_SAMPLES,
                },
            )
        except Exception as exc:
            logger.exception("Retraining failed: %s", exc)
            return RetrainResult(
                status="error",
                duration_seconds=round(time.monotonic() - start_time, 2),
                accuracy_delta=accuracy_delta,
            )

    @staticmethod
    def _true_label(sample) -> Optional[int]:
        """Return the ground-truth label implied by user feedback.

        1 means phishing and 0 means safe. "correct" confirms the shown verdict
        and "incorrect" flips it. Any other feedback yields None.
        """
        shown_phishing = sample.verdict == "phishing"
        if sample.user_feedback == "correct":
            return 1 if shown_phishing else 0
        if sample.user_feedback == "incorrect":
            return 0 if shown_phishing else 1
        return None

    @classmethod
    def _build_training_pairs(
        cls, samples: List[FeedbackRecord]
    ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Split samples into (all labeled pairs, Tier 4 labeled pairs), in input order."""
        all_pairs: List[Tuple[str, int]] = []
        tier4_pairs: List[Tuple[str, int]] = []
        for sample in samples:
            label = cls._true_label(sample)
            if label is None:
                continue
            pair = (sample.url, label)
            all_pairs.append(pair)
            if sample.tier_used == 4:
                tier4_pairs.append(pair)
        return all_pairs, tier4_pairs

    @staticmethod
    def _record_outcome(
        name: str,
        outcome,
        accuracy_delta: Dict[str, Optional[float]],
        models_updated: List[str],
    ) -> None:
        """Record one model's update outcome into ``accuracy_delta``/``models_updated``.

        ``outcome`` is either the value returned by the model's update or the
        exception it raised. BaseException is checked, rather than Exception,
        because ``gather(return_exceptions=True)`` can hand back a CancelledError.
        """
        if isinstance(outcome, BaseException):
            logger.error("%s update error: %s", name.upper(), outcome, exc_info=outcome)
            accuracy_delta[name] = None
        elif outcome is None:
            accuracy_delta[name] = None
        else:
            accuracy_delta[name] = outcome
            models_updated.append(name)

    def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
        """Keep samples with trainable feedback ("correct"/"incorrect") and an http(s) URL."""
        return [
            s
            for s in samples
            if s.user_feedback in self._VALID_FEEDBACK
            and isinstance(s.url, str)
            and s.url.startswith(self._URL_PREFIXES)
        ]

    async def _hot_reload(self, models: List[str]) -> None:
        """Hot-reload updated models in-memory; failures are logged, not raised.

        BERT is reloaded from the ``bert_weights`` directory next to this module
        when that directory exists. GNN and CNN are reloaded through their own
        ``reload()`` methods.
        """
        if "bert" in models:
            try:
                bert_weights = Path(__file__).parent / "bert_weights"
                if bert_weights.exists():
                    self._bert.load_local(bert_weights)
                    logger.info("BERT hot-reloaded")
                else:
                    logger.warning("BERT weights not found at %s; skipping reload", bert_weights)
            except Exception:
                logger.exception("BERT hot-reload failed")
        if "gnn" in models:
            try:
                self._gnn.reload()
                logger.info("GNN hot-reloaded")
            except Exception:
                logger.exception("GNN hot-reload failed")
        if "cnn" in models:
            try:
                self._cnn.reload()
                logger.info("CNN hot-reloaded")
            except Exception:
                logger.exception("CNN hot-reload failed")