prashanth135 committed on
Commit
6663b5f
·
verified ·
1 Parent(s): d9246d4

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +700 -0
  2. requirements.txt +19 -0
main.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # PhishGuard AI - main.py
3
+ # FastAPI orchestrator — Full 4-tier phishing detection pipeline
4
+ # with feedback-driven incremental retraining.
5
+ #
6
+ # Endpoints:
7
+ # POST /analyze → 4-tier URL phishing analysis
8
+ # POST /analyze/email → BERT-only email body analysis
9
+ # POST /retrain → Incremental model retraining
10
+ # GET /model_version → Current model version info
11
+ # GET /health → All model load statuses
12
+ #
13
+ # Architecture:
14
+ # Tier 1: Whitelist O(1) → SAFE exit (~55% traffic)
15
+ # Tier 2: Heuristic 15 signals → BLOCK if >= 80 (~15% blocked)
16
+ # Tier 3: BERT+GNN parallel → BLOCK/SAFE/escalate (~15% exits)
17
+ # Tier 4: CNN visual + brand hash → BLOCK/SAFE (~15% borderline)
18
+ # ============================================================
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import sys
24
+ import asyncio
25
+ import time
26
+ import hashlib
27
+ import logging
28
+ import logging.handlers
29
+ from collections import OrderedDict
30
+ from contextlib import asynccontextmanager
31
+ from pathlib import Path
32
+ from typing import List, Optional
33
+
34
+ from fastapi import FastAPI
35
+ from fastapi.middleware.cors import CORSMiddleware
36
+ from pydantic import BaseModel
37
+
38
+ # ── Path setup ────────────────────────────────────────────────────────
39
# Directory containing this file; the sub-package and log paths below are
# resolved relative to it.
BASE_DIR = Path(__file__).parent
# Put the optional "gnn" and "cnn" sub-directories (when they exist) at the
# front of sys.path so their modules can also be imported by bare name.
for sub_dir in ["gnn", "cnn"]:
    sub_path = BASE_DIR / sub_dir
    if sub_path.is_dir():
        sys.path.insert(0, str(sub_path))
44
+
45
+ # ── Logging ───────────────────────────────────────────────────────────
46
# Log files live in <BASE_DIR>/logs; the directory is created on import.
log_dir = BASE_DIR / "logs"
log_dir.mkdir(exist_ok=True)

# Rotate phishguard.log once it reaches 5 MiB, keeping three backups.
_handler = logging.handlers.RotatingFileHandler(
    log_dir / "phishguard.log",
    maxBytes=5 * 1024 * 1024,
    backupCount=3,
    encoding="utf-8",
)
_handler.setFormatter(logging.Formatter(
    "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
))

# The "phishguard" logger writes INFO and above to the rotating file and to a
# default StreamHandler (which has no custom formatter attached).
logger = logging.getLogger("phishguard")
logger.setLevel(logging.INFO)
logger.addHandler(_handler)
logger.addHandler(logging.StreamHandler())
64
+
65
+
66
+ # ── Import project modules ───────────────────────────────────────────
67
+ from url_heuristics import HeuristicScorer, HeuristicResult
68
+ from bert_analyzer import BERTPhishingClassifier
69
+
70
+ # GNN imports
71
+ GNN_AVAILABLE = False
72
+ gnn_inference = None
73
+ try:
74
+ from gnn.gnn_inference import GNNInference
75
+ GNN_AVAILABLE = True
76
+ except ImportError:
77
+ try:
78
+ from gnn_inference import GNNInference
79
+ GNN_AVAILABLE = True
80
+ except ImportError:
81
+ logger.warning("GNN module not available")
82
+
83
+ # CNN imports
84
+ CNN_AVAILABLE = False
85
+ cnn_inference = None
86
+ brand_detector = None
87
+ try:
88
+ from cnn.cnn_inference import CNNInference
89
+ from cnn.screenshot_hasher import BrandHashDetector
90
+ from cnn.cnn_model import preprocess_screenshot
91
+ CNN_AVAILABLE = True
92
+ except ImportError:
93
+ try:
94
+ from cnn_inference import CNNInference
95
+ from screenshot_hasher import BrandHashDetector
96
+ from cnn_model import preprocess_screenshot
97
+ CNN_AVAILABLE = True
98
+ except ImportError:
99
+ logger.warning("CNN module not available")
100
+
101
+ from tier3_bert_gnn import Tier3Ensemble
102
+ from retraining_service import RetrainingService, FeedbackRecord, RetrainResult
103
+
104
+
105
+ # ── Whitelist (Tier 1) ────────────────────────────────────────────────
106
# Trusted domains. POST /analyze returns a Tier 1 "safe" verdict when
# get_root_domain(url) is in this set, and POST /analyze/email trusts a
# message whose lower-cased sender domain is in it. Entries are matched
# exactly, so keep them lower-case and without a "www." prefix.
WHITELIST: set[str] = {
    "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
    "twitter.com", "instagram.com", "linkedin.com", "microsoft.com", "apple.com",
    "github.com", "stackoverflow.com", "reddit.com", "netflix.com", "paypal.com",
    "bankofamerica.com", "chase.com", "wellsfargo.com", "yahoo.com", "bing.com",
    "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
    "zoom.us", "slack.com", "spotify.com", "twitch.tv", "ebay.com",
    "walmart.com", "target.com", "bestbuy.com", "airbnb.com",
    "x.com", "tiktok.com", "pinterest.com", "quora.com", "medium.com",
}
116
+
117
+
118
def get_root_domain(url: str) -> str:
    """Return the last two labels of the URL's host name.

    A single leading ``www.`` is dropped first. A host with fewer than two
    labels (e.g. ``localhost``) is returned unchanged.

    Args:
        url: Absolute URL. Without a scheme, ``urlparse`` finds no host.

    Returns:
        The root domain such as ``"google.com"``, or ``""`` when the URL has
        no host or cannot be parsed.
    """
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""
    except (ValueError, TypeError, AttributeError):
        # Malformed URL (e.g. unbalanced IPv6 brackets) or a non-string input.
        return ""

    # Strip "www." only as a prefix. Removing it anywhere in the host would
    # turn "googwww.le.com" into "google.com" and let an attacker-controlled
    # domain pass the whitelist check.
    if host.startswith("www."):
        host = host[len("www."):]

    parts = host.split(".")
    return ".".join(parts[-2:]) if len(parts) >= 2 else host
128
+
129
+
130
+ # ── URL Cache (LRU, 30-min TTL) ──────────────────────────────────────
131
CACHE_TTL = 30 * 60  # entry lifetime in seconds (30 minutes)
CACHE_MAX = 500      # maximum number of cached URLs


class URLCache:
    """Small in-memory LRU cache of analysis results with a per-entry TTL.

    Entries are kept in an ``OrderedDict`` ordered from least to most recently
    used. Expired entries are removed lazily, when they are next looked up.
    """

    def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
        self._maxsize = maxsize
        self._ttl = ttl
        self._cache: OrderedDict = OrderedDict()

    def get(self, url: str) -> Optional[dict]:
        """Return the cached result for *url*, or None if absent or expired."""
        entry = self._cache.get(url)
        if entry is None:
            return None
        if time.time() - entry["ts"] >= self._ttl:
            # Too old: drop it so it no longer occupies a slot.
            del self._cache[url]
            return None
        # A hit makes this the most recently used entry.
        self._cache.move_to_end(url)
        return entry["result"]

    def set(self, url: str, result: dict) -> None:
        """Store *result* for *url*, evicting the oldest entry when full."""
        # Remove-then-insert places the key at the most-recent end.
        self._cache.pop(url, None)
        self._cache[url] = {"ts": time.time(), "result": result}
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Drop every cached entry."""
        self._cache.clear()


_url_cache = URLCache()
162
+
163
+
164
+ # ── Request/Response Models ───────────────────────────────────────────
165
# Request body for POST /analyze.
class AnalyzeRequest(BaseModel):
    # URL to classify.
    url: str
    # Heuristic score reported by the browser side; the endpoint uses the
    # higher of this and the server-side heuristic score.
    heuristic_score: float = 0.0
    # Page title and text snippet, passed on to the Tier 3 models.
    page_title: str = ""
    page_snippet: str = ""
    # Not read by the code visible in this file.
    related_urls: list = []
171
+
172
+
173
# Request body for POST /analyze/email.
class EmailRequest(BaseModel):
    # Sender address; the part after the last "@" is checked against WHITELIST.
    sender: str
    # Subject and body text, scored by BERT when the message carries no URLs.
    subject: str = ""
    body: str = ""
    # URLs found in the message; only the first few are analyzed.
    urls: list = []
    # Not read by the code visible in this file.
    timestamp: str = ""
179
+
180
+
181
# One labeled feedback record sent to POST /retrain. Every field is copied
# unchanged into a FeedbackRecord by the retrain endpoint.
class FeedbackSample(BaseModel):
    url: str
    verdict: str = ""
    confidence: float = 0.0
    tier_used: int = 0
    heuristic_score: int = 0
    signals: list = []
    # presumably the user's correction of the verdict; None when no feedback
    # was given — TODO confirm against RetrainingService
    user_feedback: Optional[str] = None
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    url_hash: str = ""
    session_id: str = ""
193
+
194
+
195
# Request body for POST /retrain.
class RetrainRequest(BaseModel):
    # Feedback records to retrain on.
    samples: List[FeedbackSample]
    # The remaining fields are accepted but not read by the endpoint in this file.
    trigger: str = "count"
    session_id: str = ""
    extension_version: str = ""
200
+
201
+
202
+ # ── Global state ──────────────────────────────────────────────────────
203
# Model and service singletons. All start as None and are assigned by
# lifespan() at application startup.
_scorer: Optional[HeuristicScorer] = None
_bert: Optional[BERTPhishingClassifier] = None
_gnn: Optional[GNNInference] = None
_cnn: Optional[CNNInference] = None
_brand: Optional[BrandHashDetector] = None
_tier3: Optional[Tier3Ensemble] = None
_retrain_service: Optional[RetrainingService] = None
# Held by POST /retrain while a retraining job runs; GET /health reports
# whether it is currently locked.
_retrain_lock = asyncio.Lock()
211
+
212
+
213
+ # ── Lifespan (startup/shutdown) ───────────────────────────────────────
214
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup, log at shutdown.

    Fills the module-level singletons (_scorer, _bert, _gnn, _cnn, _brand,
    _tier3, _retrain_service). The GNN and CNN tiers are optional: when their
    modules failed to import, the matching globals stay None and the request
    handlers fall back to the tiers that are available.
    """
    global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service

    logger.info("=== PhishGuard AI starting up ===")

    # Tier 2: Heuristic Scorer
    _scorer = HeuristicScorer()
    logger.info(" Tier 2: HeuristicScorer initialized")

    # Tier 3a: BERT
    _bert = BERTPhishingClassifier()
    _bert.load_model()
    logger.info(" Tier 3a: BERT classifier initialized and loaded")

    # Tier 3b: GNN
    if GNN_AVAILABLE:
        _gnn = GNNInference()
        _gnn.load()
        logger.info(f" Tier 3b: GNN loaded={_gnn.is_loaded}")
    else:
        _gnn = None
        logger.warning(" Tier 3b: GNN not available")

    # Tier 3 Ensemble (needs the GNN)
    if _gnn:
        _tier3 = Tier3Ensemble(_bert, _gnn)
        logger.info(" Tier 3: Ensemble initialized")
    else:
        _tier3 = None
        logger.warning(" Tier 3: Ensemble not available (GNN missing)")

    # Tier 4: CNN + Brand Detection
    if CNN_AVAILABLE:
        _cnn = CNNInference()
        _cnn.load()
        _brand = BrandHashDetector()
        logger.info(f" Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
    else:
        _cnn = None
        _brand = None
        logger.warning(" Tier 4: CNN not available")

    # Retraining Service.
    # GNNInference is only defined when its import succeeded, so it must be
    # guarded by GNN_AVAILABLE just like the CNN argument; an unguarded
    # GNNInference() raises NameError and aborts startup when the GNN module
    # is missing.
    # NOTE(review): confirm RetrainingService accepts gnn_inference=None, as
    # it is already expected to for cnn_inference.
    _retrain_service = RetrainingService(
        bert_classifier=_bert,
        gnn_inference=_gnn or (GNNInference() if GNN_AVAILABLE else None),
        cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
    )
    logger.info(" Retraining service initialized")
    logger.info("=== PhishGuard AI ready ===")

    yield

    logger.info("=== PhishGuard AI shutting down ===")
270
+
271
+
272
+ # ── FastAPI App ───────────────────────────────────────────────────────
273
# The ASGI application. lifespan() loads the models before requests are served.
app = FastAPI(
    title="PhishGuard AI Backend",
    version="3.0",
    description="4-tier ML phishing detection with feedback-driven retraining",
    lifespan=lifespan,
)

# CORS is wide open: any origin, method and header is accepted.
# NOTE(review): presumably so the browser extension can call the API from any
# page — confirm, and consider restricting allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
286
+
287
+
288
+ # ── POST /analyze — Full 4-tier pipeline ──────────────────────────────
289
def _build_verdict(
    url: str,
    is_phishing: bool,
    confidence: float,
    method: str,
    tier: int,
    heuristic_score: int,
    signals: list,
    details: dict,
) -> dict:
    """Assemble the response dict shared by every exit of the /analyze pipeline."""
    return {
        "url": url,
        "is_phishing": is_phishing,
        "confidence": confidence,
        "method": method,
        "status": "blocked" if is_phishing else "safe",
        "tier": tier,
        "heuristic_score": heuristic_score,
        "signals": signals,
        "details": details,
    }


@app.post("/analyze")
async def analyze_endpoint(req: AnalyzeRequest) -> dict:
    """
    Analyze a URL through the 4-tier phishing detection pipeline.

    Tier 1: Whitelist → SAFE
    Tier 2: Heuristic → BLOCK if >= 80
    Tier 3: BERT+GNN ensemble → BLOCK/SAFE/escalate
    Tier 4: CNN visual + brand hash → BLOCK/SAFE

    Every verdict except the whitelist exit is stored in the URL cache, and a
    cached verdict is returned as-is on the next request for the same URL.

    Returns:
        A dict with the keys url, is_phishing, confidence, method, status,
        tier, heuristic_score, signals and details.
    """
    url = req.url
    details: dict = {}

    # ── TIER 1: Whitelist ────────────────────────────────────────
    root = get_root_domain(url)
    if root in WHITELIST:
        return _build_verdict(
            url, False, 0.0, "whitelist", 1, 0, [], {"whitelisted_domain": root},
        )

    # ── Cache check ──────────────────────────────────────────────
    cached = _url_cache.get(url)
    if cached is not None:
        return cached

    # ── TIER 2: Heuristic scoring ────────────────────────────────
    h_result: HeuristicResult = _scorer.score(url)

    # The browser-side score is client-supplied: clamp it to 0..100 so that an
    # out-of-range value cannot produce a confidence above 1.0, and treat
    # NaN/inf (which int() rejects) as "no score".
    try:
        browser_score = int(req.heuristic_score)
    except (ValueError, OverflowError):
        browser_score = 0
    browser_score = max(0, min(100, browser_score))

    # Use the higher of server-side and browser-side heuristic scores
    h_score = max(h_result.score, browser_score)
    details["heuristic"] = {
        "score": h_result.score,
        "raw_score": h_result.raw_score,
        "signals": h_result.signals,
        "browser_score": browser_score,
        "combined_score": h_score,
    }

    if h_score >= 80:
        result = _build_verdict(
            url, True, h_score / 100.0, "heuristic", 2,
            h_score, h_result.signals, details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 2 BLOCK | url={url[:60]} | score={h_score}")
        return result

    # Blocking model inference below is pushed to the default executor so it
    # does not stall the event loop.
    loop = asyncio.get_running_loop()

    # ── TIER 3: BERT + GNN Ensemble ──────────────────────────────
    if _tier3 is not None:
        try:
            p3 = await _tier3.predict(
                url=url,
                title=req.page_title,
                snippet=req.page_snippet,
                h_score=h_score,
            )
            details["tier3_score"] = p3
        except Exception as e:
            logger.error(f"Tier 3 error: {e}")
            p3 = h_score / 100.0  # fallback to heuristic
            details["tier3_error"] = str(e)
    else:
        # Tier 3 unavailable — use BERT alone + heuristic
        if _bert is not None:
            try:
                p_bert = await loop.run_in_executor(
                    None, _bert.predict, url, req.page_title, req.page_snippet,
                )
            except Exception as e:
                # Best effort: fall back to a neutral probability, but leave a
                # trace instead of failing silently.
                logger.warning(f"BERT fallback error: {e}")
                p_bert = 0.5
            h_norm = h_score / 100.0
            p3 = 0.60 * p_bert + 0.40 * h_norm
        else:
            p3 = h_score / 100.0
        details["tier3_score"] = p3
        details["tier3_note"] = "ensemble_unavailable"

    # Tier 3 decision
    decision = Tier3Ensemble.decide(p3)

    if decision == "block":
        result = _build_verdict(
            url, True, round(p3, 4), "bert_gnn_ensemble", 3,
            h_score, h_result.signals, details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 3 BLOCK | url={url[:60]} | P3={p3:.4f}")
        return result

    if decision == "safe":
        result = _build_verdict(
            url, False, round(p3, 4), "bert_gnn_ensemble", 3,
            h_score, h_result.signals, details,
        )
        _url_cache.set(url, result)
        logger.info(f"Tier 3 SAFE | url={url[:60]} | P3={p3:.4f}")
        return result

    # ── TIER 4: CNN Visual + Brand Hash (borderline 0.40 ≤ P3 < 0.85)
    if _cnn is not None and _cnn.is_loaded:
        try:
            # Capture screenshot
            screenshot_bytes = await _capture_screenshot_for_tier4(url)

            if screenshot_bytes:
                # CNN prediction
                p_cnn = await loop.run_in_executor(
                    None, _cnn.predict, screenshot_bytes,
                )
                details["cnn_prob"] = round(p_cnn, 4)

                # Brand hash check
                brand_boost = 0.0
                if _brand is not None:
                    is_impersonation, brand_name, brand_conf = (
                        await loop.run_in_executor(
                            None, _brand.detect, screenshot_bytes, url,
                        )
                    )
                    details["brand"] = {
                        "impersonation_detected": is_impersonation,
                        "brand": brand_name,
                        "confidence": round(brand_conf, 3),
                    }
                    if is_impersonation:
                        brand_boost = 0.25

                # P_final = 0.55·P3 + 0.30·P_cnn + brand_boost
                p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
                details["tier4_score"] = round(p_final, 4)

                is_phishing = p_final >= 0.65
                result = _build_verdict(
                    url, is_phishing, round(p_final, 4),
                    "full_ensemble_bert_gnn_cnn", 4,
                    h_score, h_result.signals, details,
                )
                _url_cache.set(url, result)
                logger.info(f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}")
                return result

        except Exception as e:
            logger.error(f"Tier 4 error: {e}")
            details["tier4_error"] = str(e)

    # Tier 4 unavailable/failed — use Tier 3 score with conservative threshold
    is_phishing = p3 >= 0.65
    result = _build_verdict(
        url, is_phishing, round(p3, 4), "bert_gnn_ensemble", 3,
        h_score, h_result.signals, details,
    )
    _url_cache.set(url, result)
    logger.info(f"Tier 4 fallback → Tier 3 | url={url[:60]} | P3={p3:.4f}")
    return result
482
+
483
+
484
async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]:
    """Capture a PNG screenshot of *url* for Tier 4 CNN analysis.

    Opens the page in headless Chromium (1280x800 viewport) with font and
    media downloads blocked, and waits at most 10 s for DOMContentLoaded.

    Returns:
        The PNG bytes, or None when Playwright is not installed, the page
        cannot be loaded, or any other error occurs (the error is logged).
    """
    # NOTE(review): the URL is caller-supplied and is fetched from the server,
    # so this can reach internal hosts (SSRF). Consider restricting targets.
    try:
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                page = await browser.new_page(
                    viewport={"width": 1280, "height": 800},
                    user_agent=(
                        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                        "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"
                    ),
                )

                # Block heavy resources
                await page.route(
                    "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}",
                    lambda route: route.abort(),
                )

                await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                return await page.screenshot(type="png")
            finally:
                # Close the browser even when navigation or the screenshot
                # fails, instead of skipping the close on error.
                await browser.close()

    except Exception as e:
        logger.warning(f"Tier 4 screenshot failed: {e}")
        return None
513
+
514
+
515
+ # ── POST /analyze/email ───────────────────────────────────────────────
516
@app.post("/analyze/email")
async def analyze_email_endpoint(req: EmailRequest) -> dict:
    """Analyze an email: trusted-sender check, then URL or text analysis.

    1. A sender whose domain is in WHITELIST is reported safe immediately.
    2. With no usable URLs, the subject and body are scored by BERT alone
       (phishing when the probability exceeds 0.6).
    3. Otherwise the first MAX_URLS URLs go through the /analyze pipeline
       concurrently, and the email is blocked if any of them is phishing.

    Returns:
        A dict with "status" ("safe" or "blocked") and an "analysis" dict
        holding isPhishing, probability, reason and, for URL analysis,
        flagged_urls.
    """
    # Sender whitelist check
    sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
    if sender_domain in WHITELIST:
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.0,
                "reason": "Trusted sender domain",
            },
        }

    # Analyze embedded URLs. `urls` is an untyped list, so skip anything that
    # is not a non-empty string; such an entry would otherwise fail
    # AnalyzeRequest validation and abort the whole request.
    MAX_URLS = 3
    urls_to_check = [u for u in req.urls if isinstance(u, str) and u][:MAX_URLS]

    if not urls_to_check:
        # Text-only analysis
        if _bert is not None:
            combined = f"{req.subject} {req.body}"
            # Model inference blocks, so run it off the event loop, as the
            # /analyze endpoint does for its BERT fallback.
            loop = asyncio.get_running_loop()
            prob = await loop.run_in_executor(
                None, _bert.predict, combined, req.subject, req.body,
            )
            is_phishing = prob > 0.6
            return {
                "status": "blocked" if is_phishing else "safe",
                "analysis": {
                    "isPhishing": is_phishing,
                    "probability": prob,
                    "reason": "BERT text analysis (no URLs)",
                },
            }
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.1,
                "reason": "No URLs and no ML model available",
            },
        }

    # Analyze URLs through the main pipeline
    tasks = [
        analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
        for u in urls_to_check
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    max_prob = 0.0
    phishing_detected = False
    flagged_urls = []

    for checked_url, r in zip(urls_to_check, results):
        if isinstance(r, Exception):
            # Best effort: one failing URL must not fail the email, but the
            # failure should be visible in the logs.
            logger.warning(f"Email URL analysis failed | url={checked_url[:60]} | error={r}")
            continue
        prob = r.get("confidence", 0.0)
        max_prob = max(max_prob, prob)
        if r.get("is_phishing"):
            phishing_detected = True
            flagged_urls.append(r.get("url", checked_url))

    return {
        "status": "blocked" if phishing_detected else "safe",
        "analysis": {
            "isPhishing": phishing_detected,
            "probability": max_prob,
            "flagged_urls": flagged_urls,
            "reason": "URL analysis via ML ensemble",
        },
    }
587
+
588
+
589
+ # ── POST /retrain — Incremental retraining ────────────────────────────
590
@app.post("/retrain")
async def retrain_endpoint(req: RetrainRequest) -> dict:
    """
    Accept labeled feedback samples and run one incremental retraining job.

    Only one job runs at a time: a request that arrives while the retrain
    lock is held is answered with status "skipped". A job is abandoned after
    600 seconds. After a successful job the URL cache is cleared, because
    cached verdicts came from the previous models.
    """
    if _retrain_service is None:
        return {"status": "error", "message": "Retraining service not initialized"}

    # Refuse to start a second job while one is running.
    if _retrain_lock.locked():
        return {
            "status": "skipped",
            "message": "Retraining already in progress",
            "models_updated": [],
        }

    # Fields copied one-to-one from each FeedbackSample into a FeedbackRecord.
    record_fields = (
        "url", "verdict", "confidence", "tier_used", "heuristic_score",
        "signals", "user_feedback", "timestamp", "feedback_ts",
        "url_hash", "session_id",
    )
    # Attributes of the retrain result echoed back to the caller, in order.
    reply_fields = (
        "status", "models_updated", "samples_used", "duration_seconds",
        "accuracy_delta", "next_retrain_hint",
    )

    async with _retrain_lock:
        records = []
        for sample in req.samples:
            kwargs = {name: getattr(sample, name) for name in record_fields}
            records.append(FeedbackRecord(**kwargs))

        try:
            outcome = await asyncio.wait_for(
                _retrain_service.retrain(records),
                timeout=600,
            )

            # Verdicts cached before retraining are stale now.
            if outcome.status == "success":
                _url_cache.clear()

            return {name: getattr(outcome, name) for name in reply_fields}

        except asyncio.TimeoutError:
            return {
                "status": "error",
                "message": "Retraining timed out (600s limit)",
            }
        except Exception as e:
            logger.error(f"Retrain endpoint error: {e}")
            return {
                "status": "error",
                "message": str(e),
            }
657
+
658
+
659
+ # ── GET /model_version ────────────────────────────────────────────────
660
@app.get("/model_version")
async def model_version_endpoint() -> dict:
    """Report the current model version; polled by the extension."""
    if not _retrain_service:
        # No retraining service yet: answer with a placeholder version.
        return {"version": 0, "updated_at": None, "accuracy": {}}
    return _retrain_service.get_version_info()
666
+
667
+
668
+ # ── GET /health ───────────────────────────────────────────────────────
669
@app.get("/health")
async def health_endpoint() -> dict:
    """Liveness probe reporting per-tier readiness and per-model load state."""
    scorer_ready = _scorer is not None
    bert_present = _bert is not None
    # An absent (or falsy) model object counts as "not loaded".
    gnn_ready = _gnn.is_loaded if _gnn else False
    cnn_ready = _cnn.is_loaded if _cnn else False
    model_version = _retrain_service.model_version if _retrain_service else 0

    modules = {
        "heuristic": scorer_ready,
        "bert": bert_present and _bert.is_loaded,
        # BERT object exists but its weights are not loaded yet.
        "bert_lazy": bert_present and not _bert.is_loaded,
        "gnn": gnn_ready,
        "cnn": cnn_ready,
        "brand_hash": _brand is not None,
    }

    return {
        "status": "ok",
        "version": "3.0",
        "tier1": True,
        "tier2": scorer_ready,
        "tier3": _tier3 is not None,
        "tier4": cnn_ready,
        "retraining_in_progress": _retrain_lock.locked(),
        "model_version": model_version,
        "modules": modules,
    }
690
+
691
+
692
+ # ── Legacy feedback endpoint (backward compat) ───────────────────────
693
@app.post("/feedback")
async def legacy_feedback_endpoint(req: dict) -> dict:
    """Backward-compatible stub: ignores the body and points callers to /retrain."""
    reply = {"status": "success"}
    reply["message"] = "Use POST /retrain for feedback-driven retraining"
    return reply
697
+
698
+
699
+ # ── Run directly ──────────────────────────────────────────────────────
700
+ # uvicorn main:app --reload --port 8000
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.29.0
3
+ transformers==4.40.0
4
+ torch==2.2.2
5
+ torch-geometric==2.5.2
6
+ torchvision==0.17.2
7
+ playwright==1.44.0
8
+ pillow==10.3.0
9
+ scikit-learn==1.4.2
10
+ pandas==2.2.2
11
+ numpy==1.26.4
12
+ httpx==0.27.0
13
+ imagehash==4.3.1
14
+ requests==2.31.0
15
+ aiohttp==3.9.5
16
+ aiofiles==23.2.1
17
+ python-multipart==0.0.9
18
+ apscheduler==3.10.4
19
+ huggingface-hub==0.23.2