Spaces:
Running
Running
# ============================================================
# PhishGuard AI - feedback_store.py
# Coroutine-safe feedback storage, retraining trigger, analytics.
#
# Storage: feedback_data.jsonl (append-only, one JSON object per line)
# Lock: asyncio.Lock prevents concurrent writes & double-retrain
#       (it serializes coroutines; it is not a thread lock)
# ============================================================
| from __future__ import annotations | |
| import os | |
| import json | |
| import time | |
| import asyncio | |
| import shutil | |
| import logging | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
# Logger shared by every helper in this module.
logger = logging.getLogger("phishguard.feedback")

# Both data files are stored in the directory that contains this module.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FEEDBACK_FILE = os.path.join(_BASE_DIR, "feedback_data.jsonl")
STATE_FILE = os.path.join(_BASE_DIR, "retrain_state.json")

# -- Async lock guarding feedback writes --------------------------------------
# NOTE: asyncio.Lock serializes coroutines; it is not a thread lock.
_write_lock = asyncio.Lock()

# -- Retrain state (persisted to retrain_state.json) --------------------------
_retrain_state = dict(
    model_version=1,
    total_feedback=0,
    unprocessed_count=0,
    phishing_corrections=0,
    safe_corrections=0,
    last_retrain=None,    # ISO 8601 timestamp
    retrain_history=[],   # [{ts, samples, accuracy, version}]
)
def _load_state():
    """Merge the state saved in STATE_FILE into the in-memory retrain state.

    Returns immediately when STATE_FILE does not exist. Any error raised
    while opening, parsing or merging the file is logged as a warning
    instead of being propagated.
    """
    if not os.path.exists(STATE_FILE):
        return

    try:
        with open(STATE_FILE, "r") as fh:
            persisted = json.load(fh)
        # Keys present on disk override the defaults; other keys are kept.
        _retrain_state.update(persisted)
        logger.info(f"[FeedbackStore] State loaded | version={_retrain_state['model_version']} | total={_retrain_state['total_feedback']}")
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not load state: {e}")
def _save_state():
    """Write the in-memory retrain state to STATE_FILE.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    STATE_FILE with ``os.replace``. Failures are logged as warnings rather
    than raised.
    """
    try:
        scratch_path = f"{STATE_FILE}.tmp"
        with open(scratch_path, "w") as fh:
            # default=str stringifies any value json cannot encode natively.
            json.dump(_retrain_state, fh, indent=2, default=str)
        os.replace(scratch_path, STATE_FILE)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not save state: {e}")
# Import-time side effect: pull any previously persisted state into memory.
_load_state()
# ──────────────────────────────────────────────────────────────────────────────
# FEEDBACK STORAGE
# ──────────────────────────────────────────────────────────────────────────────
async def append_feedback(
    url: str,
    label: str,
    source: str = "user_feedback",
    original_prediction: Optional[float] = None,
) -> dict:
    """
    Append one feedback record to feedback_data.jsonl and update the counters.

    The file write and the counter update happen while holding the module's
    asyncio write lock.

    Args:
        url: URL the feedback refers to.
        label: Corrected label; "phishing" and "safe" also bump their
            respective correction counters.
        source: Origin tag stored with the record.
        original_prediction: Optional score, stored rounded to 4 decimals.

    Returns:
        {"success": True, "feedback_count": N, "unprocessed": M} on success,
        or {"success": False, "error": "<message>"} when the write fails.
    """
    if original_prediction is None:
        stored_prediction = None
    else:
        stored_prediction = round(original_prediction, 4)

    record = {
        "url": url,
        "label": label,  # "phishing" or "safe"
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "source": source,
        "original_prediction": stored_prediction,
    }

    async with _write_lock:
        try:
            with open(FEEDBACK_FILE, "a") as fh:
                fh.write(json.dumps(record) + "\n")
        except Exception as e:
            logger.error(f"[FeedbackStore] Write failed: {e}")
            return {"success": False, "error": str(e)}

        # The record is on disk; now bring the in-memory counters in line.
        for counter in ("total_feedback", "unprocessed_count"):
            _retrain_state[counter] += 1
        if label == "phishing":
            _retrain_state["phishing_corrections"] += 1
        elif label == "safe":
            _retrain_state["safe_corrections"] += 1
        _save_state()

    logger.info(f"[FeedbackStore] Saved | url={url} | label={label} | total={_retrain_state['total_feedback']}")
    result = {"success": True}
    result["feedback_count"] = _retrain_state["total_feedback"]
    result["unprocessed"] = _retrain_state["unprocessed_count"]
    return result
def get_unprocessed_count() -> int:
    """Return the current value of the ``unprocessed_count`` state counter."""
    pending = _retrain_state["unprocessed_count"]
    return pending
def get_model_version() -> int:
    """Return the ``model_version`` value held in the retrain state."""
    version = _retrain_state["model_version"]
    return version
def get_stats() -> dict:
    """Build the analytics payload for the /feedback/stats endpoint.

    Returns:
        A new dict with the scalar counters copied from the retrain state
        plus ``retrain_history`` limited to its 10 most recent entries.
    """
    scalar_keys = (
        "total_feedback",
        "phishing_corrections",
        "safe_corrections",
        "unprocessed_count",
        "last_retrain",
        "model_version",
    )
    stats = {key: _retrain_state[key] for key in scalar_keys}
    stats["retrain_history"] = _retrain_state["retrain_history"][-10:]  # last 10
    return stats
def get_recent_entries(n: int = 50) -> list:
    """Return up to the last ``n`` feedback entries from the JSONL file.

    Args:
        n: Maximum number of entries to return. Zero or a negative value
            yields an empty list.

    Returns:
        Parsed entries in file order (oldest first). Blank lines are ignored
        and do not count towards ``n``. Malformed lines inside the window are
        skipped with a warning. An empty list is returned when the file is
        missing or cannot be read.
    """
    # Guard first: `lines[-0:]` is the WHOLE list, and a negative n would
    # slice from the front, so neither may reach the tail logic below.
    if n <= 0:
        return []

    from collections import deque

    try:
        with open(FEEDBACK_FILE, "r") as fh:
            non_blank = (text for text in (raw.strip() for raw in fh) if text)
            # A bounded deque keeps only the last n lines in memory.
            tail = deque(non_blank, maxlen=n)
    except FileNotFoundError:
        return []
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning(f"[FeedbackStore] Could not read recent entries: {exc}")
        return []

    entries = []
    for text in tail:
        try:
            entries.append(json.loads(text))
        except json.JSONDecodeError as exc:
            # One bad line must not hide every other recent entry.
            logger.warning(f"[FeedbackStore] Skipping malformed feedback line: {exc}")
    return entries
# ──────────────────────────────────────────────────────────────────────────────
# RETRAINING PIPELINE
# ──────────────────────────────────────────────────────────────────────────────
# Unprocessed feedback entries required before retraining is triggered.
RETRAIN_THRESHOLD = 50
# Flag read by should_retrain(); nothing in this section sets it.
_retrain_running = False


def should_retrain() -> bool:
    """Tell whether a retraining run should start now.

    Returns:
        True when the unprocessed feedback count has reached
        RETRAIN_THRESHOLD and ``_retrain_running`` is falsy.
    """
    backlog_ready = _retrain_state["unprocessed_count"] >= RETRAIN_THRESHOLD
    if not backlog_ready:
        return False
    return not _retrain_running
def mark_retrain_complete(samples: int, accuracy: float):
    """
    Record a finished retraining run in the retrain state.

    Bumps ``model_version``, zeroes ``unprocessed_count``, stamps
    ``last_retrain`` with the current UTC time, appends a history record
    (trimmed to the newest 50), persists the state and logs a summary.

    Args:
        samples: Number of samples reported for this run.
        accuracy: Accuracy of the new model; stored rounded to 4 decimals.
    """
    finished_at = datetime.now(timezone.utc).isoformat()

    _retrain_state["model_version"] = _retrain_state["model_version"] + 1
    _retrain_state["unprocessed_count"] = 0
    _retrain_state["last_retrain"] = finished_at

    history = _retrain_state["retrain_history"]
    history.append(
        {
            "timestamp": finished_at,
            "samples": samples,
            "accuracy": round(accuracy, 4),
            "version": _retrain_state["model_version"],
        }
    )
    # Keep only last 50 history entries
    if len(history) > 50:
        _retrain_state["retrain_history"] = history[-50:]

    _save_state()

    summary = (
        f"[FeedbackStore] Retrained on {samples} feedback samples. "
        f"New accuracy: {accuracy:.2%}. Model version: {_retrain_state['model_version']}"
    )
    logger.info(summary)
def archive_feedback_file():
    """Rename the feedback file to ``<FEEDBACK_FILE>.<unix-seconds>.bak``.

    Does nothing when the feedback file is absent. A failed move is logged
    as a warning and not raised.
    """
    if not os.path.exists(FEEDBACK_FILE):
        return

    stamp = int(time.time())
    archive = f"{FEEDBACK_FILE}.{stamp}.bak"
    try:
        shutil.move(FEEDBACK_FILE, archive)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Archive failed: {e}")
    else:
        logger.info(f"[FeedbackStore] Archived feedback β {archive}")
def load_feedback_entries() -> list:
    """Load every parseable entry from the feedback JSONL file.

    Returns:
        Parsed entries in file order. Blank lines are ignored. A malformed
        line is skipped with a warning, so one bad record (for example a
        half-written final line) does not discard the valid entries after it.
        An empty list is returned when the file does not exist. If reading
        fails part-way, the error is logged and the entries parsed so far
        are returned.
    """
    entries = []
    try:
        with open(FEEDBACK_FILE, "r") as fh:
            for lineno, raw in enumerate(fh, start=1):
                text = raw.strip()
                if not text:
                    continue
                # Parse per line: a decode error here must not abort the loop.
                try:
                    entries.append(json.loads(text))
                except json.JSONDecodeError as exc:
                    logger.warning(
                        f"[FeedbackStore] Skipping malformed feedback line {lineno}: {exc}"
                    )
    except FileNotFoundError:
        return []
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"[FeedbackStore] Read failed: {e}")
    return entries