Spaces:
Running
Running
# ============================================================
# PhishGuard AI - main.py
# FastAPI orchestrator — Full 4-tier phishing detection pipeline
# with feedback-driven incremental retraining.
#
# Endpoints:
#   POST /analyze        → 4-tier URL phishing analysis
#   POST /analyze/email  → BERT-only email body analysis
#   POST /retrain        → Incremental model retraining
#   GET  /model_version  → Current model version info
#   GET  /health         → All model load statuses
#
# Architecture:
#   Tier 1: Whitelist O(1)          → SAFE exit (~55% traffic)
#   Tier 2: Heuristic 15 signals    → BLOCK if >= 80 (~15% blocked)
#   Tier 3: BERT+GNN parallel       → BLOCK/SAFE/escalate (~15% exits)
#   Tier 4: CNN visual + brand hash → BLOCK/SAFE (~15% borderline)
# ============================================================
| from __future__ import annotations | |
| import os | |
| import sys | |
| import asyncio | |
| import time | |
| import hashlib | |
| import logging | |
| import logging.handlers | |
| from collections import OrderedDict | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# ── Path setup ────────────────────────────────────────────────────────
# Directory containing this file; the optional "gnn" and "cnn" sub-packages
# are pushed onto sys.path so their modules can also be imported flat.
BASE_DIR = Path(__file__).parent
for sub_dir in ("gnn", "cnn"):
    sub_path = BASE_DIR.joinpath(sub_dir)
    if not sub_path.is_dir():
        continue
    sys.path.insert(0, str(sub_path))

# ── Logging ───────────────────────────────────────────────────────────
# One rotating file (5 MiB per file, 3 backups) under <BASE_DIR>/logs plus a
# default stream handler, both attached to the "phishguard" logger.
log_dir = BASE_DIR / "logs"
log_dir.mkdir(exist_ok=True)

_handler = logging.handlers.RotatingFileHandler(
    log_dir / "phishguard.log",
    maxBytes=5 * 1024 * 1024,
    backupCount=3,
    encoding="utf-8",
)
_handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
)

logger = logging.getLogger("phishguard")
logger.setLevel(logging.INFO)
logger.addHandler(_handler)
logger.addHandler(logging.StreamHandler())
# ── Import project modules ────────────────────────────────────────────
from url_heuristics import HeuristicScorer, HeuristicResult
from bert_analyzer import BERTPhishingClassifier

# GNN imports: try the package layout first, then the flat layout made
# importable by the sys.path setup. GNN_AVAILABLE records whether either
# succeeded; when both fail, the name GNNInference is left undefined.
GNN_AVAILABLE = False
gnn_inference = None
try:
    from gnn.gnn_inference import GNNInference
except ImportError:
    try:
        from gnn_inference import GNNInference
    except ImportError:
        logger.warning("GNN module not available")
    else:
        GNN_AVAILABLE = True
else:
    GNN_AVAILABLE = True

# CNN imports: same two-step strategy. All three names must import from
# the same layout for CNN_AVAILABLE to become True.
CNN_AVAILABLE = False
cnn_inference = None
brand_detector = None
try:
    from cnn.cnn_inference import CNNInference
    from cnn.screenshot_hasher import BrandHashDetector
    from cnn.cnn_model import preprocess_screenshot
except ImportError:
    try:
        from cnn_inference import CNNInference
        from screenshot_hasher import BrandHashDetector
        from cnn_model import preprocess_screenshot
    except ImportError:
        logger.warning("CNN module not available")
    else:
        CNN_AVAILABLE = True
else:
    CNN_AVAILABLE = True

from tier3_bert_gnn import Tier3Ensemble
from retraining_service import RetrainingService, FeedbackRecord, RetrainResult
# ── Whitelist (Tier 1) ────────────────────────────────────────────────
# Root domains that short-circuit the pipeline with a SAFE verdict.
# The same set is reused for the sender-domain check on emails.
WHITELIST: set[str] = {
    # Search and reference
    "google.com", "bing.com", "yahoo.com", "wikipedia.org",
    # Social and publishing
    "facebook.com", "twitter.com", "x.com", "instagram.com", "linkedin.com",
    "reddit.com", "tiktok.com", "pinterest.com", "quora.com", "medium.com",
    # Streaming media
    "youtube.com", "netflix.com", "spotify.com", "twitch.tv",
    # Software, productivity and developer tools
    "microsoft.com", "apple.com", "github.com", "stackoverflow.com",
    "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
    "zoom.us", "slack.com",
    # Payments and banking
    "paypal.com", "bankofamerica.com", "chase.com", "wellsfargo.com",
    # Retail and travel
    "amazon.com", "ebay.com", "walmart.com", "target.com", "bestbuy.com",
    "airbnb.com",
}
def get_root_domain(url: str) -> str:
    """Return the last two labels of the URL's hostname.

    A single leading ``www.`` label is dropped first. Hosts with fewer than
    two labels are returned unchanged, and ``""`` is returned when the URL
    has no hostname or cannot be parsed.

    Note: this is a naive "last two labels" rule, not a public-suffix
    lookup, so ``a.example.co.uk`` yields ``co.uk``.
    """
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""
    except (ValueError, TypeError, AttributeError):
        # Malformed URL (e.g. unbalanced IPv6 brackets) or a non-string.
        return ""

    # Strip only a *leading* "www.". Removing the substring anywhere in the
    # host would let "goowww.gle.com" collapse into "google.com" and pass
    # the whitelist check.
    if host.startswith("www."):
        host = host[len("www."):]

    parts = host.split(".")
    return ".".join(parts[-2:]) if len(parts) >= 2 else host
# ── URL Cache (LRU, 30-min TTL) ───────────────────────────────────────
CACHE_TTL = 30 * 60  # entry lifetime in seconds
CACHE_MAX = 500  # maximum number of cached URLs


class URLCache:
    """Small in-memory LRU cache of analysis results with a per-entry TTL.

    Entries are keyed by the exact URL string. Reading a live entry marks it
    as most recently used; inserting beyond ``maxsize`` evicts the least
    recently used entry. Expired entries are removed when they are read.
    """

    def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
        """Create an empty cache holding at most *maxsize* entries for *ttl* seconds."""
        self._cache: OrderedDict = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl

    def get(self, url: str) -> Optional[dict]:
        """Return the cached result for *url*, or ``None`` if absent or expired."""
        entry = self._cache.get(url)
        if entry is None:
            return None
        # Age is measured on the monotonic clock so that wall-clock
        # adjustments cannot expire entries early or keep them alive.
        if time.monotonic() - entry["ts"] < self._ttl:
            self._cache.move_to_end(url)
            return entry["result"]
        del self._cache[url]
        return None

    def set(self, url: str, result: dict) -> None:
        """Store *result* for *url*, evicting the oldest entry when full."""
        self._cache[url] = {"result": result, "ts": time.monotonic()}
        self._cache.move_to_end(url)
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Drop every cached entry."""
        self._cache.clear()


# Module-wide cache instance used by the analysis and retraining endpoints.
_url_cache = URLCache()
# ── Request/Response Models ───────────────────────────────────────────
# Body of the URL analysis request.
class AnalyzeRequest(BaseModel):
    # URL to analyze (required).
    url: str
    # Browser-side heuristic score; combined with the server-side score.
    heuristic_score: float = 0.0
    # Page title and text snippet forwarded to the Tier 3 models.
    page_title: str = ""
    page_snippet: str = ""
    # Not read anywhere in this module.
    related_urls: list = []
# Body of the email analysis request.
class EmailRequest(BaseModel):
    # Sender address (required); its domain is checked against WHITELIST.
    sender: str
    # Subject and body text, used for the BERT text-only analysis.
    subject: str = ""
    body: str = ""
    # URLs extracted from the email; only the first few are analyzed.
    urls: list = []
    # Not read anywhere in this module.
    timestamp: str = ""
# One labeled feedback sample; each field is copied one-to-one into a
# FeedbackRecord by the retrain endpoint.
class FeedbackSample(BaseModel):
    # URL the feedback refers to (required).
    url: str
    # Values reported back from an earlier analysis.
    verdict: str = ""
    confidence: float = 0.0
    tier_used: int = 0
    heuristic_score: int = 0
    signals: list = []
    # Optional user-supplied label and timestamps.
    user_feedback: Optional[str] = None
    timestamp: str = ""
    feedback_ts: Optional[str] = None
    # Client-side identifiers.
    url_hash: str = ""
    session_id: str = ""
# Body of the retraining request.
class RetrainRequest(BaseModel):
    # Feedback samples to retrain on (required).
    samples: List[FeedbackSample]
    # Metadata fields; none of them is read in this module.
    trigger: str = "count"
    session_id: str = ""
    extension_version: str = ""
# ── Global state ──────────────────────────────────────────────────────
# Model singletons. All start as None and are populated by lifespan() at
# application startup; optional tiers stay None when unavailable.
_scorer: Optional[HeuristicScorer] = None  # Tier 2
_bert: Optional[BERTPhishingClassifier] = None  # Tier 3a
_gnn: Optional[GNNInference] = None  # Tier 3b
_cnn: Optional[CNNInference] = None  # Tier 4
_brand: Optional[BrandHashDetector] = None  # Tier 4 brand hash check
_tier3: Optional[Tier3Ensemble] = None  # BERT + GNN ensemble
_retrain_service: Optional[RetrainingService] = None

# Held for the duration of a retraining job so only one runs at a time.
_retrain_lock = asyncio.Lock()
# ── Lifespan (startup/shutdown) ───────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup and log on shutdown.

    Populates the module-level model singletons. The GNN and CNN tiers are
    optional: when their imports failed, the matching globals stay ``None``
    and a warning is logged. The Tier 3 ensemble is only built when the GNN
    is present.
    """
    global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service
    logger.info("=== PhishGuard AI starting up ===")

    # Tier 2: Heuristic Scorer
    _scorer = HeuristicScorer()
    logger.info(" Tier 2: HeuristicScorer initialized")

    # Tier 3a: BERT
    _bert = BERTPhishingClassifier()
    _bert.load_model()
    logger.info(" Tier 3a: BERT classifier initialized and loaded")

    # Tier 3b: GNN
    if GNN_AVAILABLE:
        _gnn = GNNInference()
        _gnn.load()
        logger.info(f" Tier 3b: GNN loaded={_gnn.is_loaded}")
    else:
        _gnn = None
        logger.warning(" Tier 3b: GNN not available")

    # Tier 3 Ensemble
    if _gnn:
        _tier3 = Tier3Ensemble(_bert, _gnn)
        logger.info(" Tier 3: Ensemble initialized")
    else:
        _tier3 = None
        logger.warning(" Tier 3: Ensemble not available (GNN missing)")

    # Tier 4: CNN + Brand Detection
    if CNN_AVAILABLE:
        _cnn = CNNInference()
        _cnn.load()
        _brand = BrandHashDetector()
        logger.info(f" Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
    else:
        _cnn = None
        _brand = None
        logger.warning(" Tier 4: CNN not available")

    # Retraining Service
    # GNNInference is undefined when the GNN import failed, so it may only
    # be referenced behind GNN_AVAILABLE, mirroring the CNN argument.
    # NOTE(review): this passes gnn_inference=None when the GNN is missing;
    # confirm RetrainingService accepts that.
    _retrain_service = RetrainingService(
        bert_classifier=_bert,
        gnn_inference=_gnn or (GNNInference() if GNN_AVAILABLE else None),
        cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
    )
    logger.info(" Retraining service initialized")
    logger.info("=== PhishGuard AI ready ===")
    yield
    logger.info("=== PhishGuard AI shutting down ===")
# ── FastAPI App ───────────────────────────────────────────────────────
# Application object; model loading is delegated to the lifespan handler.
app = FastAPI(
    title="PhishGuard AI Backend",
    description="4-tier ML phishing detection with feedback-driven retraining",
    version="3.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── POST /analyze — Full 4-tier pipeline ──────────────────────────────
def _clamp_browser_score(raw_score: float) -> int:
    """Coerce the client-supplied heuristic score to an int in [0, 100].

    The value comes straight from the request body. NaN and infinity cannot
    be converted by ``int()`` and are treated as 0; out-of-range values are
    clamped so the derived confidence never exceeds 1.0.
    """
    try:
        score = int(raw_score)
    except (TypeError, ValueError, OverflowError):
        return 0
    return max(0, min(100, score))


def _make_verdict(
    url: str,
    *,
    is_phishing: bool,
    confidence: float,
    method: str,
    tier: int,
    heuristic_score: int,
    signals: list,
    details: dict,
) -> dict:
    """Build the response dict shared by every exit of the pipeline."""
    return {
        "url": url,
        "is_phishing": is_phishing,
        "confidence": confidence,
        "method": method,
        "status": "blocked" if is_phishing else "safe",
        "tier": tier,
        "heuristic_score": heuristic_score,
        "signals": signals,
        "details": details,
    }


def _cache_and_log(url: str, verdict: dict, message: str) -> dict:
    """Store *verdict* in the URL cache, log *message*, and return the verdict."""
    _url_cache.set(url, verdict)
    logger.info(message)
    return verdict


async def _tier3_probability(req: AnalyzeRequest, h_score: int, details: dict) -> float:
    """Compute the Tier 3 phishing probability and record it in *details*.

    Uses the BERT+GNN ensemble when available. Otherwise blends BERT alone
    with the normalized heuristic score (60/40), or falls back to the
    heuristic score when no model is loaded.
    """
    h_norm = h_score / 100.0

    if _tier3 is not None:
        try:
            p3 = await _tier3.predict(
                url=req.url,
                title=req.page_title,
                snippet=req.page_snippet,
                h_score=h_score,
            )
        except Exception as e:
            logger.error(f"Tier 3 error: {e}")
            details["tier3_error"] = str(e)
            return h_norm  # fallback to heuristic
        details["tier3_score"] = p3
        return p3

    # Tier 3 unavailable — use BERT alone + heuristic
    if _bert is not None:
        loop = asyncio.get_running_loop()
        try:
            p_bert = await loop.run_in_executor(
                None, _bert.predict, req.url, req.page_title, req.page_snippet,
            )
        except Exception as e:
            # Neutral prior on failure, but leave a trace in the log.
            logger.error(f"BERT fallback error: {e}")
            p_bert = 0.5
        p3 = 0.60 * p_bert + 0.40 * h_norm
    else:
        p3 = h_norm
    details["tier3_score"] = p3
    details["tier3_note"] = "ensemble_unavailable"
    return p3


async def _tier4_probability(url: str, p3: float, details: dict) -> Optional[float]:
    """Run the CNN and brand-hash check; return the final score or ``None``.

    ``None`` means Tier 4 could not produce a score: CNN missing or not
    loaded, no screenshot captured, or an error (recorded in *details*).
    """
    if _cnn is None or not _cnn.is_loaded:
        return None
    try:
        screenshot_bytes = await _capture_screenshot_for_tier4(url)
        if not screenshot_bytes:
            return None

        # NOTE(review): the CNN and brand-hash calls are synchronous and run
        # on the event loop; consider an executor if they prove slow.
        p_cnn = _cnn.predict(screenshot_bytes)
        details["cnn_prob"] = round(p_cnn, 4)

        brand_boost = 0.0
        if _brand is not None:
            is_impersonation, brand_name, brand_conf = _brand.detect(
                screenshot_bytes, url
            )
            details["brand"] = {
                "impersonation_detected": is_impersonation,
                "brand": brand_name,
                "confidence": round(brand_conf, 3),
            }
            if is_impersonation:
                brand_boost = 0.25

        # P_final = 0.55·P3 + 0.30·P_cnn + brand_boost, capped at 1.0
        p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
        details["tier4_score"] = round(p_final, 4)
        return p_final
    except Exception as e:
        logger.error(f"Tier 4 error: {e}")
        details["tier4_error"] = str(e)
        return None


@app.post("/analyze")
async def analyze_endpoint(req: AnalyzeRequest) -> dict:
    """
    Analyze a URL through the 4-tier phishing detection pipeline.

    Tier 1: Whitelist → SAFE
    Tier 2: Heuristic → BLOCK if >= 80
    Tier 3: BERT+GNN ensemble → BLOCK/SAFE/escalate
    Tier 4: CNN visual + brand hash → BLOCK/SAFE

    Returns a dict with the keys url, is_phishing, confidence, method,
    status, tier, heuristic_score, signals and details. Every verdict
    except a whitelist hit is cached by URL.
    """
    url = req.url
    details: dict = {}

    # ── TIER 1: Whitelist ──
    root = get_root_domain(url)
    if root in WHITELIST:
        return _make_verdict(
            url,
            is_phishing=False,
            confidence=0.0,
            method="whitelist",
            tier=1,
            heuristic_score=0,
            signals=[],
            details={"whitelisted_domain": root},
        )

    # ── Cache check ──
    # NOTE(review): the cache key is the URL only, so a verdict influenced
    # by one client's heuristic_score or page text is served to later callers.
    cached = _url_cache.get(url)
    if cached is not None:
        return cached

    # ── TIER 2: Heuristic scoring ──
    h_result: HeuristicResult = _scorer.score(url)
    # Use the higher of server-side and browser-side heuristic scores
    browser_score = _clamp_browser_score(req.heuristic_score)
    h_score = max(h_result.score, browser_score)
    details["heuristic"] = {
        "score": h_result.score,
        "raw_score": h_result.raw_score,
        "signals": h_result.signals,
        "browser_score": browser_score,
        "combined_score": h_score,
    }
    if h_score >= 80:
        verdict = _make_verdict(
            url,
            is_phishing=True,
            confidence=h_score / 100.0,
            method="heuristic",
            tier=2,
            heuristic_score=h_score,
            signals=h_result.signals,
            details=details,
        )
        return _cache_and_log(
            url, verdict, f"Tier 2 BLOCK | url={url[:60]} | score={h_score}"
        )

    # ── TIER 3: BERT + GNN Ensemble ──
    p3 = await _tier3_probability(req, h_score, details)
    decision = Tier3Ensemble.decide(p3)
    if decision in ("block", "safe"):
        blocked = decision == "block"
        verdict = _make_verdict(
            url,
            is_phishing=blocked,
            confidence=round(p3, 4),
            method="bert_gnn_ensemble",
            tier=3,
            heuristic_score=h_score,
            signals=h_result.signals,
            details=details,
        )
        label = "BLOCK" if blocked else "SAFE"
        return _cache_and_log(
            url, verdict, f"Tier 3 {label} | url={url[:60]} | P3={p3:.4f}"
        )

    # ── TIER 4: CNN Visual + Brand Hash (borderline 0.40 ≤ P3 < 0.85) ──
    p_final = await _tier4_probability(url, p3, details)
    if p_final is not None:
        is_phishing = p_final >= 0.65
        verdict = _make_verdict(
            url,
            is_phishing=is_phishing,
            confidence=round(p_final, 4),
            method="full_ensemble_bert_gnn_cnn",
            tier=4,
            heuristic_score=h_score,
            signals=h_result.signals,
            details=details,
        )
        return _cache_and_log(
            url,
            verdict,
            f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}",
        )

    # Tier 4 unavailable/failed — use Tier 3 score with conservative threshold
    is_phishing = p3 >= 0.65
    verdict = _make_verdict(
        url,
        is_phishing=is_phishing,
        confidence=round(p3, 4),
        method="bert_gnn_ensemble",
        tier=3,
        heuristic_score=h_score,
        signals=h_result.signals,
        details=details,
    )
    return _cache_and_log(
        url, verdict, f"Tier 4 fallback → Tier 3 | url={url[:60]} | P3={p3:.4f}"
    )
| async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]: | |
| """Capture screenshot for Tier 4 CNN analysis.""" | |
| try: | |
| from playwright.async_api import async_playwright | |
| async with async_playwright() as p: | |
| browser = await p.chromium.launch(headless=True) | |
| page = await browser.new_page( | |
| viewport={"width": 1280, "height": 800}, | |
| user_agent=( | |
| "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " | |
| "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36" | |
| ), | |
| ) | |
| # Block heavy resources | |
| await page.route( | |
| "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}", | |
| lambda route: route.abort(), | |
| ) | |
| await page.goto(url, wait_until="domcontentloaded", timeout=10000) | |
| screenshot = await page.screenshot(type="png") | |
| await browser.close() | |
| return screenshot | |
| except Exception as e: | |
| logger.warning(f"Tier 4 screenshot failed: {e}") | |
| return None | |
# ── POST /analyze/email ───────────────────────────────────────────────
@app.post("/analyze/email")
async def analyze_email_endpoint(req: EmailRequest) -> dict:
    """Analyze an email: BERT on the text, or the URL pipeline on its links.

    Flow:
      1. A sender whose domain is in WHITELIST is reported safe immediately.
      2. With no URLs, the subject and body are scored by BERT
         (phishing if the probability is above 0.6).
      3. Otherwise the first three URLs go through ``analyze_endpoint``;
         the email is blocked if any of them is flagged.

    Returns a dict with ``status`` and an ``analysis`` sub-dict.
    """
    # Sender whitelist check
    # NOTE(review): the sender address is taken as given and WHITELIST
    # contains public mail providers; confirm this trust level is intended.
    sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
    if sender_domain in WHITELIST:
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.0,
                "reason": "Trusted sender domain",
            },
        }

    # Analyze embedded URLs
    MAX_URLS = 3
    urls_to_check = req.urls[:MAX_URLS]

    if not urls_to_check:
        # Text-only analysis
        if _bert:
            combined = f"{req.subject} {req.body}"
            # Run the synchronous model call in the default executor, as
            # analyze_endpoint does, so it does not block the event loop.
            loop = asyncio.get_running_loop()
            prob = await loop.run_in_executor(
                None, _bert.predict, combined, req.subject, req.body,
            )
            is_phishing = prob > 0.6
            return {
                "status": "blocked" if is_phishing else "safe",
                "analysis": {
                    "isPhishing": is_phishing,
                    "probability": prob,
                    "reason": "BERT text analysis (no URLs)",
                },
            }
        return {
            "status": "safe",
            "analysis": {
                "isPhishing": False,
                "probability": 0.1,
                "reason": "No URLs and no ML model available",
            },
        }

    # Analyze URLs through the main pipeline
    tasks = [
        analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
        for u in urls_to_check
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    max_prob = 0.0
    phishing_detected = False
    flagged_urls = []
    for idx, r in enumerate(results):
        # gather(return_exceptions=True) can also hand back BaseException
        # instances such as CancelledError.
        if isinstance(r, BaseException):
            logger.error(
                f"Email URL analysis failed | url={str(urls_to_check[idx])[:60]} | error={r}"
            )
            continue
        prob = r.get("confidence", 0.0)
        max_prob = max(max_prob, prob)
        if r.get("is_phishing"):
            phishing_detected = True
            flagged_urls.append(r.get("url", urls_to_check[idx]))

    return {
        "status": "blocked" if phishing_detected else "safe",
        "analysis": {
            "isPhishing": phishing_detected,
            "probability": max_prob,
            "flagged_urls": flagged_urls,
            "reason": "URL analysis via ML ensemble",
        },
    }
# ── POST /retrain — Incremental retraining ────────────────────────────
@app.post("/retrain")
async def retrain_endpoint(req: RetrainRequest) -> dict:
    """
    Receive labeled feedback and incrementally update all models.

    Uses the module-level asyncio lock to prevent concurrent retraining
    jobs: a request arriving while one is running is answered with status
    "skipped". The job is limited to 600 seconds. On success the URL cache
    is cleared, because cached verdicts came from the previous models.

    Returns a dict whose ``status`` is the service's result status,
    "skipped", or "error".
    """
    if _retrain_service is None:
        return {"status": "error", "message": "Retraining service not initialized"}

    # Prevent concurrent retraining
    if _retrain_lock.locked():
        return {
            "status": "skipped",
            "message": "Retraining already in progress",
            "models_updated": [],
        }

    async with _retrain_lock:
        # Convert Pydantic models to FeedbackRecord dataclasses
        records = [
            FeedbackRecord(
                url=s.url,
                verdict=s.verdict,
                confidence=s.confidence,
                tier_used=s.tier_used,
                heuristic_score=s.heuristic_score,
                signals=s.signals,
                user_feedback=s.user_feedback,
                timestamp=s.timestamp,
                feedback_ts=s.feedback_ts,
                url_hash=s.url_hash,
                session_id=s.session_id,
            )
            for s in req.samples
        ]
        try:
            result = await asyncio.wait_for(
                _retrain_service.retrain(records),
                timeout=600,
            )
        except asyncio.TimeoutError:
            return {
                "status": "error",
                "message": "Retraining timed out (600s limit)",
            }
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(f"Retrain endpoint error: {e}")
            return {
                "status": "error",
                "message": str(e),
            }

        # Clear URL cache after retraining (stale results)
        if result.status == "success":
            _url_cache.clear()
        return {
            "status": result.status,
            "models_updated": result.models_updated,
            "samples_used": result.samples_used,
            "duration_seconds": result.duration_seconds,
            "accuracy_delta": result.accuracy_delta,
            "next_retrain_hint": result.next_retrain_hint,
        }
# ── GET /model_version ────────────────────────────────────────────────
@app.get("/model_version")
async def model_version_endpoint() -> dict:
    """Return current model version info for extension polling.

    Delegates to the retraining service when it is initialized; otherwise
    returns a placeholder with version 0 and no accuracy data.
    """
    if _retrain_service:
        return _retrain_service.get_version_info()
    return {"version": 0, "updated_at": None, "accuracy": {}}
# ── GET /health ───────────────────────────────────────────────────────
@app.get("/health")
async def health_endpoint() -> dict:
    """Liveness probe with per-tier readiness and model statuses.

    ``status`` is always "ok"; the remaining fields report which tiers and
    modules are currently loaded and whether a retraining job is running.
    """
    gnn_ready = _gnn.is_loaded if _gnn else False
    cnn_ready = _cnn.is_loaded if _cnn else False
    return {
        "status": "ok",
        "version": "3.0",
        "tier1": True,
        "tier2": _scorer is not None,
        "tier3": _tier3 is not None,
        "tier4": cnn_ready,
        "retraining_in_progress": _retrain_lock.locked(),
        "model_version": _retrain_service.model_version if _retrain_service else 0,
        "modules": {
            "heuristic": _scorer is not None,
            "bert": _bert is not None and _bert.is_loaded,
            "bert_lazy": _bert is not None and not _bert.is_loaded,
            "gnn": gnn_ready,
            "cnn": cnn_ready,
            "brand_hash": _brand is not None,
        },
    }
# ── Legacy feedback endpoint (backward compat) ────────────────────────
# NOTE(review): no route registration for this handler is visible in this
# file — confirm the path it is supposed to be mounted on.
async def legacy_feedback_endpoint(req: dict) -> dict:
    """Acknowledge a legacy feedback call and point the caller at /retrain.

    The request payload is accepted but ignored.
    """
    response = {
        "status": "success",
        "message": "Use POST /retrain for feedback-driven retraining",
    }
    return response
# ── Run directly ──────────────────────────────────────────────────────
# uvicorn main:app --reload --port 8000