# phishguard-api / retraining_service.py
# prashanth135's picture
# Upload 38 files
# bebe233 verified
# ============================================================
# PhishGuard AI - retraining_service.py
# Incremental retraining service for all 3 ML models.
#
# Receives labeled feedback samples from the Chrome extension.
# Runs parallel incremental updates for BERT, GNN, and CNN.
# Tracks model version and accuracy deltas.
# Supports hot-reload of all models without server restart.
# ============================================================
from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlsplit
# Module logger; the name is a fixed string rather than derived from __name__.
logger = logging.getLogger("phishguard.retrain")

# Directory named "data" located next to this source file.
DATA_DIR = Path(__file__).parent.joinpath("data")
# JSON file inside DATA_DIR that stores the current model version record.
MODEL_VERSION_PATH = DATA_DIR.joinpath("model_version.json")
@dataclass
class FeedbackRecord:
    """One feedback sample received from the Chrome extension.

    Only ``url`` and ``verdict`` are required; every other field carries a
    default, so a record can be built from a partially-filled payload.
    Field order is part of the positional constructor signature.
    """

    # --- required ---
    url: str
    verdict: str  # "phishing" or "safe"

    # --- detection details ---
    confidence: float = 0.0
    tier_used: int = 0
    heuristic_score: int = 0
    signals: List[str] = field(default_factory=list)  # fresh list per record

    # --- user judgement of the verdict ---
    user_feedback: Optional[str] = None  # "correct" or "incorrect"

    # --- bookkeeping ---
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    url_hash: str = ""
    session_id: str = ""
@dataclass
class RetrainResult:
    """Outcome of a single retraining run.

    ``status`` is the only required field; the remaining fields default to
    empty or zero values so a minimal result can be built from the status
    alone.
    """

    status: str  # "success", "skipped", "error"

    # Names of the models that were updated during the run.
    models_updated: List[str] = field(default_factory=list)
    # Number of samples counted for the run.
    samples_used: int = 0
    # Elapsed time of the run, in seconds.
    duration_seconds: float = 0.0
    # Per-model accuracy change; None where no value is available.
    accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict)
    # Free-form hint describing when the next retrain should be triggered.
    next_retrain_hint: Dict = field(default_factory=dict)
class RetrainingService:
    """Orchestrate incremental retraining of the BERT, GNN and CNN models.

    Called by the ``POST /retrain`` endpoint.  The service validates the
    feedback samples, derives training labels from the user's judgement,
    runs the per-model incremental updates, persists a bumped model
    version and finally hot-reloads the models that were updated.
    """

    # Minimum number of validated feedback samples needed to start a run.
    MIN_VALID_SAMPLES = 10
    # Minimum number of labeled (url, label) pairs needed after labeling.
    MIN_LABELED_SAMPLES = 5
    # ``tier_used`` value whose samples are additionally fed to the CNN.
    CNN_TIER = 4

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        cnn_inference,
    ) -> None:
        """Store the three model wrappers and load the persisted version."""
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._cnn = cnn_inference
        self._model_version = self._load_version()

    def _load_version(self) -> int:
        """Return the model version stored on disk, or 0 if unavailable.

        Also makes sure the data directory exists.  An unreadable, corrupt
        or malformed version file is logged and treated as version 0
        instead of being silently ignored.
        """
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        try:
            data = json.loads(MODEL_VERSION_PATH.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return 0
        except (OSError, ValueError) as exc:
            # ValueError covers both JSON and text-decoding errors.
            logger.warning(
                "Could not read %s (%s); starting from version 0",
                MODEL_VERSION_PATH,
                exc,
            )
            return 0

        version = data.get("version", 0) if isinstance(data, dict) else None
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(version, bool) or not isinstance(version, int):
            logger.warning(
                "Ignoring malformed version record in %s; starting from version 0",
                MODEL_VERSION_PATH,
            )
            return 0
        return version

    def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
        """Persist the next model version together with the given deltas.

        The record is written to a temporary sibling file and then moved
        over the real file, so a concurrent reader never sees a partial
        write.  The in-memory counter advances only after the write
        succeeded, keeping memory and disk in agreement.

        Raises:
            OSError: if the file cannot be written or replaced.
            TypeError: if ``accuracy_delta`` is not JSON-serializable.
        """
        new_version = self._model_version + 1
        data = {
            "version": new_version,
            "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "accuracy": accuracy_delta,
        }
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = MODEL_VERSION_PATH.with_name(MODEL_VERSION_PATH.name + ".tmp")
        tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp_path.replace(MODEL_VERSION_PATH)
        self._model_version = new_version

    @property
    def model_version(self) -> int:
        """Current in-memory model version."""
        return self._model_version

    def get_version_info(self) -> dict:
        """Return the version record for ``GET /model_version``.

        The record is read from the version file when it exists and holds
        a JSON object; otherwise a minimal record built from the in-memory
        version is returned.
        """
        try:
            info = json.loads(MODEL_VERSION_PATH.read_text(encoding="utf-8"))
        except FileNotFoundError:
            info = None
        except (OSError, ValueError) as exc:
            logger.warning("Could not read %s: %s", MODEL_VERSION_PATH, exc)
            info = None

        if isinstance(info, dict):
            return info
        return {
            "version": self._model_version,
            "updated_at": None,
            "accuracy": {},
        }

    async def retrain(
        self,
        samples: List[FeedbackRecord],
    ) -> RetrainResult:
        """Perform one incremental retraining run.

        Steps:
            1. Validate the samples; skip when fewer than
               ``MIN_VALID_SAMPLES`` remain.
            2. Turn them into ``(url, label)`` pairs, collecting the pairs
               whose ``tier_used`` equals ``CNN_TIER`` separately.
            3. Run the BERT and GNN updates concurrently in the default
               executor.
            4. Await the CNN update if any tier-4 pairs exist.
            5. If at least one model was updated, persist a new version
               and hot-reload the updated models.

        Returns:
            A ``RetrainResult`` whose status is "success" when at least one
            model was updated, "skipped" when there were too few samples or
            no model produced a delta, and "error" when an unexpected
            exception occurred.  The error result still reports the models
            updated before the failure.
        """
        started = time.monotonic()

        # 1. Validate
        valid_samples = self._validate_samples(samples)
        if len(valid_samples) < self.MIN_VALID_SAMPLES:
            return RetrainResult(
                status="skipped",
                samples_used=len(valid_samples),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self.MIN_VALID_SAMPLES - len(valid_samples),
                },
            )

        # 2. Convert to (url, label) pairs
        url_label_pairs, tier4_pairs = self._label_samples(valid_samples)
        # Defensive: with the built-in validator every valid sample yields
        # a label, so this only triggers if validation is overridden.
        if len(url_label_pairs) < self.MIN_LABELED_SAMPLES:
            return RetrainResult(
                status="skipped",
                samples_used=len(url_label_pairs),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self.MIN_LABELED_SAMPLES - len(url_label_pairs),
                },
            )

        # 3. Run updates
        models_updated: List[str] = []
        accuracy_delta: Dict[str, Optional[float]] = {}
        try:
            # BERT + GNN concurrently; their updates are plain callables,
            # so they run in the loop's default executor.
            loop = asyncio.get_running_loop()
            bert_outcome, gnn_outcome = await asyncio.gather(
                loop.run_in_executor(
                    None, self._bert.incremental_update, url_label_pairs
                ),
                loop.run_in_executor(
                    None, self._gnn.incremental_update, url_label_pairs
                ),
                return_exceptions=True,
            )
            self._record_outcome("bert", bert_outcome, accuracy_delta, models_updated)
            self._record_outcome("gnn", gnn_outcome, accuracy_delta, models_updated)

            # 4. CNN update (only if Tier 4 samples exist); awaited directly.
            if tier4_pairs:
                try:
                    cnn_outcome = await self._cnn.incremental_update(tier4_pairs)
                except Exception as exc:
                    cnn_outcome = exc
                self._record_outcome("cnn", cnn_outcome, accuracy_delta, models_updated)
            else:
                accuracy_delta["cnn"] = None

            # 5. Persist the new version, then hot-reload what changed.
            if models_updated:
                self._save_version(accuracy_delta)
                await self._hot_reload(models_updated)

            return RetrainResult(
                status="success" if models_updated else "skipped",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.monotonic() - started, 2),
                accuracy_delta=accuracy_delta,
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self.MIN_VALID_SAMPLES,
                },
            )
        except Exception as exc:
            logger.exception("Retraining failed: %s", exc)
            return RetrainResult(
                status="error",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.monotonic() - started, 2),
                accuracy_delta=accuracy_delta,
            )

    def _label_samples(
        self,
        samples: List[FeedbackRecord],
    ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Turn feedback into ``(url, label)`` pairs (1 = phishing, 0 = safe).

        "correct" feedback keeps the verdict as the label and "incorrect"
        flips it; samples with any other feedback value are dropped.

        Returns:
            ``(all_pairs, tier4_pairs)`` where ``tier4_pairs`` holds the
            pairs whose sample has ``tier_used == CNN_TIER``.
        """
        all_pairs: List[Tuple[str, int]] = []
        tier4_pairs: List[Tuple[str, int]] = []
        for sample in samples:
            # NOTE(review): any verdict other than "phishing" counts as
            # "safe" here - confirm no other verdict values are sent.
            predicted_phishing = sample.verdict == "phishing"
            if sample.user_feedback == "correct":
                label = int(predicted_phishing)
            elif sample.user_feedback == "incorrect":
                label = int(not predicted_phishing)
            else:
                continue
            pair = (sample.url, label)
            all_pairs.append(pair)
            if sample.tier_used == self.CNN_TIER:
                tier4_pairs.append(pair)
        return all_pairs, tier4_pairs

    @staticmethod
    def _record_outcome(
        name: str,
        outcome,
        accuracy_delta: Dict[str, Optional[float]],
        models_updated: List[str],
    ) -> None:
        """Record one model's update outcome in the run's bookkeeping.

        An exception (including a cancellation surfaced by ``gather``) is
        logged and stored as ``None``; a ``None`` outcome is stored as is;
        anything else is stored as the delta and marks the model updated.
        """
        if isinstance(outcome, BaseException):
            logger.error("%s update error: %s", name.upper(), outcome)
            accuracy_delta[name] = None
        elif outcome is None:
            accuracy_delta[name] = None
        else:
            accuracy_delta[name] = outcome
            models_updated.append(name)

    def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
        """Return the samples that are usable for retraining, in order.

        A sample is kept when its ``user_feedback`` is "correct" or
        "incorrect" and its ``url`` is a string that starts with
        ``http://`` or ``https://`` and contains a host.  ``None`` or an
        empty list yields an empty result.
        """
        valid: List[FeedbackRecord] = []
        for sample in samples or ():
            if sample.user_feedback not in ("correct", "incorrect"):
                continue
            url = sample.url
            if not isinstance(url, str) or not url.startswith(("http://", "https://")):
                continue
            try:
                has_host = bool(urlsplit(url).hostname)
            except ValueError:
                # urlsplit rejects e.g. malformed IPv6 literals.
                has_host = False
            if has_host:
                valid.append(sample)
        return valid

    async def _hot_reload(self, models: List[str]) -> None:
        """Reload the named models in memory; failures are logged, not raised.

        NOTE(review): the reload calls are synchronous and run on the event
        loop thread - confirm they are fast enough not to stall requests.
        """
        if "bert" in models:
            try:
                bert_weights = Path(__file__).parent / "bert_weights"
                if bert_weights.exists():
                    self._bert.load_local(bert_weights)
                    logger.info("BERT hot-reloaded")
                else:
                    logger.warning(
                        "BERT weights not found at %s; hot-reload skipped",
                        bert_weights,
                    )
            except Exception as exc:
                logger.exception("BERT hot-reload failed: %s", exc)

        if "gnn" in models:
            try:
                self._gnn.reload()
                logger.info("GNN hot-reloaded")
            except Exception as exc:
                logger.exception("GNN hot-reload failed: %s", exc)

        if "cnn" in models:
            try:
                self._cnn.reload()
                logger.info("CNN hot-reloaded")
            except Exception as exc:
                logger.exception("CNN hot-reload failed: %s", exc)