# NOTE: web-viewer residue collapsed here (banner "File size: 10,050 Bytes", revision "bebe233", and a 1-296 line-number gutter).
# ============================================================
# PhishGuard AI - retraining_service.py
# Incremental retraining service for all 3 ML models.
#
# Receives labeled feedback samples from the Chrome extension.
# Runs parallel incremental updates for BERT, GNN, and CNN.
# Tracks model version and accuracy deltas.
# Supports hot-reload of all models without server restart.
# ============================================================

from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger("phishguard.retrain")

DATA_DIR = Path(__file__).parent / "data"
MODEL_VERSION_PATH = DATA_DIR / "model_version.json"


@dataclass
class FeedbackRecord:
    """A single feedback record from the Chrome extension.

    Within this module only ``url``, ``verdict``, ``user_feedback`` and
    ``tier_used`` are read (by ``RetrainingService`` when validating samples
    and deriving training labels). The remaining fields are not inspected
    here and are simply carried on the record.
    """
    # URL the verdict refers to. Validation keeps only records whose URL
    # starts with "http://" or "https://".
    url: str
    # Verdict that was shown to the user. When labels are derived, any value
    # other than "phishing" is treated as the safe class.
    verdict: str                        # "phishing" or "safe"
    # Confidence reported with the verdict (not read in this module;
    # presumably in [0, 1] - TODO confirm against the extension payload).
    confidence: float = 0.0
    # Detection tier that produced the verdict. Records with tier_used == 4
    # are additionally used for the CNN update.
    tier_used: int = 0
    # Heuristic score reported with the verdict (not read in this module).
    heuristic_score: int = 0
    # Signal names reported with the verdict (not read in this module).
    signals: List[str] = field(default_factory=list)
    # User's judgement of the verdict. Records with any value other than
    # "correct" / "incorrect" (including None) are dropped by validation.
    user_feedback: Optional[str] = None  # "correct" or "incorrect"
    # Timestamp strings; their format is not parsed or enforced here.
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    # Identifiers supplied by the extension (not read in this module).
    url_hash: str = ""
    session_id: str = ""


@dataclass
class RetrainResult:
    """Result from a retraining run.

    Built by ``RetrainingService.retrain`` for every outcome: early "skipped"
    returns when too few usable samples were supplied, the normal completion
    path, and the "error" path when the run raised unexpectedly.
    """
    # Overall outcome of the run.
    status: str                                      # "success", "skipped", "error"
    # Names of the models whose incremental update returned a value
    # ("bert", "gnn", "cnn").
    models_updated: List[str] = field(default_factory=list)
    # Number of samples that survived validation / label derivation.
    samples_used: int = 0
    # Run time in seconds, rounded to 2 decimals by retrain(); left at the
    # default by the early "skipped" returns.
    duration_seconds: float = 0.0
    # Per-model value returned by that model's incremental_update(), or None
    # when the model was not updated (failure, no result, or not attempted).
    accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict)
    # Hint for the caller about when to retrain next; retrain() fills it with
    # the keys "recommended_trigger" and "min_samples_needed".
    next_retrain_hint: Dict = field(default_factory=dict)


class RetrainingService:
    """
    Orchestrates incremental retraining for all 3 ML models.
    Called by POST /retrain endpoint.

    The model wrappers are injected and used duck-typed:
      - bert_classifier: ``incremental_update(pairs)`` (called in the default
        executor) and ``load_local(path)``
      - gnn_inference:   ``incremental_update(pairs)`` (called in the default
        executor) and ``reload()``
      - cnn_inference:   ``incremental_update(pairs)`` (awaited) and
        ``reload()``
    where ``pairs`` is a list of ``(url, label)`` tuples, label 1 meaning
    phishing and 0 meaning safe.
    """

    # Minimum number of validated samples before a retrain is attempted.
    _MIN_VALID_SAMPLES = 10
    # Minimum number of labeled (url, label) pairs needed to run the updates.
    _MIN_LABELED_PAIRS = 5
    # Tier whose samples are additionally fed to the CNN update.
    _CNN_TIER = 4

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        cnn_inference,
    ) -> None:
        """Store the model wrappers and load the persisted model version."""
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._cnn = cnn_inference
        self._model_version = self._load_version()

    def _load_version(self) -> int:
        """
        Load current model version from disk.

        Creates the data directory if needed. Returns 0 when the version file
        is missing, unreadable, malformed, or holds a non-integer version.
        """
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        if not MODEL_VERSION_PATH.exists():
            return 0
        try:
            data = json.loads(MODEL_VERSION_PATH.read_text())
            # Coerce so a stringly-typed version cannot break the later
            # ``self._model_version += 1``.
            return int(data.get("version", 0))
        except (OSError, ValueError, TypeError, AttributeError) as exc:
            logger.warning(
                "Could not read model version from %s (%s); starting at 0",
                MODEL_VERSION_PATH,
                exc,
            )
            return 0

    def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
        """
        Bump the in-memory model version and write it to disk.

        The counter is incremented before the write, so it advances even if
        the write raises. Write/serialization errors propagate to the caller.
        """
        self._model_version += 1
        data = {
            "version": self._model_version,
            "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "accuracy": accuracy_delta,
        }
        MODEL_VERSION_PATH.write_text(json.dumps(data, indent=2))

    @property
    def model_version(self) -> int:
        """Current in-memory model version number."""
        return self._model_version

    def get_version_info(self) -> dict:
        """
        Get current model version info for GET /model_version.

        Returns the parsed content of the version file when it can be read;
        otherwise a fallback built from the in-memory version.
        """
        if MODEL_VERSION_PATH.exists():
            try:
                return json.loads(MODEL_VERSION_PATH.read_text())
            except (OSError, ValueError) as exc:
                logger.warning(
                    "Could not read version info from %s (%s); using fallback",
                    MODEL_VERSION_PATH,
                    exc,
                )
        return {
            "version": self._model_version,
            "updated_at": None,
            "accuracy": {},
        }

    async def retrain(
        self,
        samples: List[FeedbackRecord],
    ) -> RetrainResult:
        """
        Perform incremental retraining on all models.

        Steps:
          1. Validate samples (min 10, URL format check)
          2. Convert to (url, label) pairs; collect Tier 4 pairs separately
          3. Run BERT + GNN updates in parallel
          4. Run CNN update if Tier 4 samples exist
          5. Increment and persist the model version if anything was updated
          6. Hot-reload the updated models

        Returns RetrainResult with status and deltas. Status is "skipped"
        when there are too few usable samples or no model produced an
        update, "error" when the run raised unexpectedly, else "success".
        """
        # Monotonic clock: the duration must not jump with wall-clock changes.
        start_time = time.monotonic()

        # 1. Validate
        valid_samples = self._validate_samples(samples)
        if len(valid_samples) < self._MIN_VALID_SAMPLES:
            return RetrainResult(
                status="skipped",
                samples_used=len(valid_samples),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self._MIN_VALID_SAMPLES - len(valid_samples),
                },
            )

        # 2. Convert to (url, label) pairs
        url_label_pairs, tier4_pairs = self._build_training_pairs(valid_samples)
        if len(url_label_pairs) < self._MIN_LABELED_PAIRS:
            return RetrainResult(
                status="skipped",
                samples_used=len(url_label_pairs),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self._MIN_LABELED_PAIRS,
                },
            )

        # 3. Run updates
        models_updated: List[str] = []
        accuracy_delta: Dict[str, Optional[float]] = {}

        try:
            # BERT + GNN in parallel on the default executor. We are inside a
            # coroutine, so the running loop is the one to use.
            loop = asyncio.get_running_loop()
            bert_delta, gnn_delta = await asyncio.gather(
                loop.run_in_executor(
                    None, self._bert.incremental_update, url_label_pairs
                ),
                loop.run_in_executor(
                    None, self._gnn.incremental_update, url_label_pairs
                ),
                # One model failing must not discard the other's result.
                return_exceptions=True,
            )
            self._record_outcome("bert", bert_delta, accuracy_delta, models_updated)
            self._record_outcome("gnn", gnn_delta, accuracy_delta, models_updated)

            # 4. CNN update (only if Tier 4 samples exist)
            accuracy_delta["cnn"] = None
            if tier4_pairs:
                try:
                    cnn_delta = await self._cnn.incremental_update(tier4_pairs)
                except Exception as e:
                    logger.error("CNN update error: %s", e, exc_info=e)
                else:
                    if cnn_delta is not None:
                        accuracy_delta["cnn"] = cnn_delta
                        models_updated.append("cnn")

            # 5. Update version. Persisting it is bookkeeping: a failed write
            # must not stop the freshly updated models from being reloaded.
            if models_updated:
                try:
                    self._save_version(accuracy_delta)
                except (OSError, TypeError, ValueError) as exc:
                    logger.error("Could not persist model version: %s", exc)

            # 6. Hot-reload
            await self._hot_reload(models_updated)

            return RetrainResult(
                status="success" if models_updated else "skipped",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.monotonic() - start_time, 2),
                accuracy_delta=accuracy_delta,
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": self._MIN_VALID_SAMPLES,
                },
            )

        except Exception as e:
            logger.exception("Retraining failed: %s", e)
            # Report what had already been updated before the failure so the
            # caller is not told that nothing happened.
            return RetrainResult(
                status="error",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(time.monotonic() - start_time, 2),
                accuracy_delta=accuracy_delta,
            )

    @classmethod
    def _build_training_pairs(
        cls,
        samples: List[FeedbackRecord],
    ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """
        Derive (url, label) pairs from user feedback.

        "correct" feedback keeps the verdict as the label, "incorrect" flips
        it (1 = phishing, 0 = safe). Samples with any other feedback value are
        skipped. Returns ``(all_pairs, tier4_pairs)``.
        """
        all_pairs: List[Tuple[str, int]] = []
        tier4_pairs: List[Tuple[str, int]] = []
        for sample in samples:
            said_phishing = sample.verdict == "phishing"
            if sample.user_feedback == "correct":
                label = 1 if said_phishing else 0
            elif sample.user_feedback == "incorrect":
                label = 0 if said_phishing else 1
            else:
                continue
            all_pairs.append((sample.url, label))
            if sample.tier_used == cls._CNN_TIER:
                tier4_pairs.append((sample.url, label))
        return all_pairs, tier4_pairs

    @staticmethod
    def _record_outcome(
        name: str,
        outcome: object,
        accuracy_delta: Dict[str, Optional[float]],
        models_updated: List[str],
    ) -> None:
        """Record one gathered update result (delta, None, or exception)."""
        # BaseException, not Exception: gather(return_exceptions=True) can
        # also hand back a CancelledError, which is not an Exception subclass.
        if isinstance(outcome, BaseException):
            logger.error(
                "%s update error: %s", name.upper(), outcome, exc_info=outcome
            )
            accuracy_delta[name] = None
        elif outcome is None:
            accuracy_delta[name] = None
        else:
            accuracy_delta[name] = outcome
            models_updated.append(name)

    def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
        """
        Validate and filter feedback samples.

        Keeps samples whose ``user_feedback`` is "correct" or "incorrect" and
        whose ``url`` is a string starting with "http://" or "https://".
        ``None`` or an empty input yields an empty list; a non-string URL
        drops that record instead of raising.
        """
        return [
            s
            for s in (samples or ())
            if s.user_feedback in ("correct", "incorrect")
            and isinstance(s.url, str)
            and s.url.startswith(("http://", "https://"))
        ]

    async def _hot_reload(self, models: List[str]) -> None:
        """
        Hot-reload updated models in-memory.

        Best effort: a failure for one model is logged and does not prevent
        the remaining models from being reloaded.
        """
        if "bert" in models:
            try:
                bert_weights = Path(__file__).parent / "bert_weights"
                # Reload only when a local weights directory is present.
                if bert_weights.exists():
                    self._bert.load_local(bert_weights)
                    logger.info("BERT hot-reloaded")
            except Exception as e:
                logger.exception("BERT hot-reload failed: %s", e)

        if "gnn" in models:
            try:
                self._gnn.reload()
                logger.info("GNN hot-reloaded")
            except Exception as e:
                logger.exception("GNN hot-reload failed: %s", e)

        if "cnn" in models:
            try:
                self._cnn.reload()
                logger.info("CNN hot-reloaded")
            except Exception as e:
                logger.exception("CNN hot-reload failed: %s", e)