# phishguard-api / feedback_store.py
# prashanth135's picture
# Upload 38 files
# bebe233 verified
# ============================================================
# PhishGuard AI - feedback_store.py
# Thread-safe feedback storage, retraining trigger, analytics.
#
# Storage: feedback_data.jsonl (append-only, one JSON per line)
# Lock: asyncio.Lock serializes concurrent coroutine writes & guards against
#       double-retrain (single event loop only; it is not a cross-thread lock)
# ============================================================
from __future__ import annotations
import os
import json
import time
import asyncio
import shutil
import logging
from datetime import datetime, timezone
from typing import Optional
logger = logging.getLogger("phishguard.feedback")

# Data files are kept in the same directory as this module.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FEEDBACK_FILE = os.path.join(_BASE_DIR, "feedback_data.jsonl")
STATE_FILE = os.path.join(_BASE_DIR, "retrain_state.json")

# Serializes feedback appends between coroutines (asyncio lock, not a
# threading lock).
_write_lock = asyncio.Lock()

# In-memory retrain bookkeeping. _load_state() merges the persisted copy from
# retrain_state.json into it and _save_state() writes it back.
_retrain_state = {
    "model_version": 1,
    "total_feedback": 0,
    "unprocessed_count": 0,
    "phishing_corrections": 0,
    "safe_corrections": 0,
    "last_retrain": None,   # ISO 8601 timestamp of the last retrain, if any
    "retrain_history": [],  # list of {timestamp, samples, accuracy, version}
}
def _load_state():
    """Merge the retrain state persisted in STATE_FILE into memory.

    Does nothing when the state file does not exist. Any error raised while
    reading, parsing, or merging is logged as a warning and not propagated.
    """
    if not os.path.exists(STATE_FILE):
        return

    try:
        with open(STATE_FILE, "r") as handle:
            persisted = json.load(handle)
        # Overlay the saved values on top of the built-in defaults.
        _retrain_state.update(persisted)
        logger.info(f"[FeedbackStore] State loaded | version={_retrain_state['model_version']} | total={_retrain_state['total_feedback']}")
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not load state: {e}")
def _save_state():
    """Write the in-memory retrain state to STATE_FILE.

    The JSON is written to a sibling ".tmp" file first and then moved over
    the real file with os.replace, so the state file is swapped in one step.
    Failures are logged as warnings and not raised.
    """
    try:
        staging_path = f"{STATE_FILE}.tmp"
        with open(staging_path, "w") as out:
            # default=str stringifies any value json cannot encode natively.
            json.dump(_retrain_state, out, indent=2, default=str)
        os.replace(staging_path, STATE_FILE)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not save state: {e}")
# Populate _retrain_state from retrain_state.json once, at import time.
# _load_state() logs and swallows its own errors, so importing this module
# does not fail on a missing or unreadable state file.
_load_state()
# ══════════════════════════════════════════════════════════════════════════════
# FEEDBACK STORAGE
# ══════════════════════════════════════════════════════════════════════════════
async def append_feedback(
    url: str,
    label: str,
    source: str = "user_feedback",
    original_prediction: Optional[float] = None,
) -> dict:
    """
    Append one feedback record to feedback_data.jsonl and update the counters.

    The write and the counter update happen while holding _write_lock, so
    concurrent coroutines cannot interleave them.

    Args:
        url: The URL the feedback refers to.
        label: The corrected label; "phishing" and "safe" each bump their own
            correction counter, any other value only bumps the totals.
        source: Where the feedback came from.
        original_prediction: The model's earlier score, stored rounded to
            4 decimals, or None when unknown.

    Returns:
        {"success": True, "feedback_count": N, "unprocessed": M} on success,
        {"success": False, "error": "<message>"} when the file write fails.
    """
    rounded_prediction = None
    if original_prediction is not None:
        rounded_prediction = round(original_prediction, 4)

    record = {
        "url": url,
        "label": label,  # "phishing" or "safe"
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "source": source,
        "original_prediction": rounded_prediction,
    }

    async with _write_lock:
        try:
            with open(FEEDBACK_FILE, "a") as sink:
                sink.write(json.dumps(record) + "\n")
        except Exception as e:
            logger.error(f"[FeedbackStore] Write failed: {e}")
            return {"success": False, "error": str(e)}

        # The record is on disk; now bring the in-memory counters in line.
        _retrain_state["total_feedback"] += 1
        _retrain_state["unprocessed_count"] += 1
        if label == "phishing":
            _retrain_state["phishing_corrections"] += 1
        elif label == "safe":
            _retrain_state["safe_corrections"] += 1
        _save_state()

        logger.info(f"[FeedbackStore] Saved | url={url} | label={label} | total={_retrain_state['total_feedback']}")
        total = _retrain_state["total_feedback"]
        pending = _retrain_state["unprocessed_count"]
        return {
            "success": True,
            "feedback_count": total,
            "unprocessed": pending,
        }
def get_unprocessed_count() -> int:
    """Return how many feedback entries have arrived since the last retrain."""
    pending = _retrain_state["unprocessed_count"]
    return pending
def get_model_version() -> int:
    """Return the version number of the current model."""
    version = _retrain_state["model_version"]
    return version
def get_stats() -> dict:
    """Build the analytics payload for the /feedback/stats endpoint.

    Returns a new dict with the scalar counters copied from the retrain state
    plus the ten most recent retrain history records.
    """
    scalar_keys = (
        "total_feedback",
        "phishing_corrections",
        "safe_corrections",
        "unprocessed_count",
        "last_retrain",
        "model_version",
    )
    stats = {key: _retrain_state[key] for key in scalar_keys}
    # Slicing copies, so callers get at most the last 10 history records.
    stats["retrain_history"] = _retrain_state["retrain_history"][-10:]
    return stats
def get_recent_entries(n: int = 50) -> list:
    """Return up to the last ``n`` feedback entries from the JSONL file.

    Args:
        n: Maximum number of entries to return. Zero or a negative value
            yields an empty list.

    Returns:
        The parsed entries in file order (oldest first). An empty list is
        returned when the file is missing or cannot be read. Blank lines are
        ignored, and a line that is not valid JSON is skipped with a warning
        instead of discarding the whole result.
    """
    # Function-scope import so the module's import block stays untouched.
    from collections import deque

    # Guard first: with n == 0 a "[-n:]" slice would return *everything*.
    if n <= 0:
        return []

    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            # A bounded deque keeps only the last n non-blank lines, so the
            # whole file is never held in memory at once.
            tail = deque((line for line in f if line.strip()), maxlen=n)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers undecodable bytes (UnicodeDecodeError).
        logger.warning(f"[FeedbackStore] Could not read recent entries: {e}")
        return []

    entries = []
    for line in tail:
        try:
            entries.append(json.loads(line))
        except json.JSONDecodeError as e:
            logger.warning(f"[FeedbackStore] Skipping malformed feedback line: {e}")
    return entries
# ══════════════════════════════════════════════════════════════════════════════
# RETRAINING PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
# Number of unprocessed feedback entries at which should_retrain() starts
# returning True.
RETRAIN_THRESHOLD = 50
# Read by should_retrain() to suppress a second trigger. It is not modified in
# the visible part of this module — presumably the retraining code sets it;
# TODO confirm.
_retrain_running = False
def should_retrain() -> bool:
    """Report whether a retraining run should be started now.

    True only when no retrain is flagged as running and the number of
    unprocessed feedback entries has reached RETRAIN_THRESHOLD.
    """
    if _retrain_running:
        return False
    pending = _retrain_state["unprocessed_count"]
    return pending >= RETRAIN_THRESHOLD
def mark_retrain_complete(samples: int, accuracy: float):
    """
    Record a finished retraining run.

    Bumps model_version, zeroes the unprocessed counter, stamps last_retrain
    with the current UTC time, appends a history record (capped at the 50
    most recent), persists the state, and logs a summary.

    Args:
        samples: Number of feedback samples the retrain used.
        accuracy: Accuracy of the new model; stored rounded to 4 decimals.
    """
    new_version = _retrain_state["model_version"] + 1
    finished_at = datetime.now(timezone.utc).isoformat()

    _retrain_state["model_version"] = new_version
    _retrain_state["unprocessed_count"] = 0
    _retrain_state["last_retrain"] = finished_at

    history = _retrain_state["retrain_history"]
    history.append({
        "timestamp": finished_at,
        "samples": samples,
        "accuracy": round(accuracy, 4),
        "version": new_version,
    })
    # Cap the stored history at the 50 most recent runs.
    if len(history) > 50:
        _retrain_state["retrain_history"] = history[-50:]

    _save_state()
    logger.info(
        f"[FeedbackStore] Retrained on {samples} feedback samples. "
        f"New accuracy: {accuracy:.2%}. Model version: {_retrain_state['model_version']}"
    )
def archive_feedback_file():
    """Move the processed feedback file to a timestamped ``.bak`` backup.

    Does nothing when the feedback file does not exist. The backup is named
    ``<FEEDBACK_FILE>.<unix-seconds>.bak``; if that name is already taken
    (two archives within the same second) a numeric suffix is inserted so an
    earlier backup is never overwritten. A failed move is logged as a warning
    and not raised.
    """
    if not os.path.exists(FEEDBACK_FILE):
        return

    stamp = int(time.time())
    archive = f"{FEEDBACK_FILE}.{stamp}.bak"
    # shutil.move replaces an existing destination file, so pick a free name.
    attempt = 0
    while os.path.exists(archive):
        attempt += 1
        archive = f"{FEEDBACK_FILE}.{stamp}.{attempt}.bak"

    try:
        shutil.move(FEEDBACK_FILE, archive)
        logger.info(f"[FeedbackStore] Archived feedback → {archive}")
    except Exception as e:
        # Best effort: the caller continues even if archiving fails.
        logger.warning(f"[FeedbackStore] Archive failed: {e}")
def load_feedback_entries() -> list:
    """Load every entry from the feedback JSONL file.

    Returns:
        The parsed entries in file order. An empty list is returned when the
        file does not exist. Blank lines are ignored. A line that is not
        valid JSON is skipped with a warning so that one corrupt record does
        not drop every entry after it. If reading the file itself fails, the
        error is logged and the entries parsed so far are returned.
    """
    entries = []
    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.warning(
                        f"[FeedbackStore] Skipping malformed line {line_no}: {e}"
                    )
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers undecodable bytes (UnicodeDecodeError).
        logger.error(f"[FeedbackStore] Read failed: {e}")
    return entries