File size: 8,344 Bytes
bebe233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# ============================================================
# PhishGuard AI - feedback_store.py
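# Coroutine-safe feedback storage, retraining trigger, analytics.
#
# Storage: feedback_data.jsonl (append-only, one JSON per line)
# Lock: asyncio.Lock serializes concurrent feedback writes within one event
#       loop (it is not an OS-thread lock); double-retrain is guarded
#       separately by the _retrain_running flag checked in should_retrain().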
# ============================================================

from __future__ import annotations

import asyncio
import json
import logging
import os
import shutil
import time
from collections import deque
from datetime import datetime, timezone
from typing import Optional

logger = logging.getLogger("phishguard.feedback")

# Both persistence files live next to this module, independent of the CWD.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FEEDBACK_FILE = os.path.join(_BASE_DIR, "feedback_data.jsonl")   # append-only JSONL
STATE_FILE = os.path.join(_BASE_DIR, "retrain_state.json")       # counters snapshot

# ── Serializes feedback appends across coroutines ────────────────────────────
# asyncio.Lock is coroutine-safe within one event loop; it does not guard
# against concurrent OS threads.
# NOTE(review): on Python < 3.10 a Lock created at import time binds to the
# loop current at import; confirm the app runs on 3.10+ (lazy loop binding).
_write_lock = asyncio.Lock()

# ── Retrain state (persisted to retrain_state.json) ──────────────────────────
# Mutated in place by the functions below; _load_state() merges the saved
# JSON over these defaults.
_retrain_state = dict(
    model_version=1,          # incremented by mark_retrain_complete()
    total_feedback=0,         # every feedback entry ever appended
    unprocessed_count=0,      # entries since the last retrain (reset to 0 then)
    phishing_corrections=0,   # entries labelled "phishing"
    safe_corrections=0,       # entries labelled "safe"
    last_retrain=None,        # ISO 8601 UTC timestamp of the last retrain
    retrain_history=[],       # [{timestamp, samples, accuracy, version}, ...]
)


def _load_state():
    """
    Merge the persisted retrain state from STATE_FILE into _retrain_state.

    Best-effort, because it runs at import time:
    - a missing file keeps the in-memory defaults;
    - an unreadable file, or a payload that is not a JSON object, is logged
      as a warning instead of raised.

    Keys from the file override the defaults; defaults absent from the file
    are kept.
    """
    try:
        with open(STATE_FILE, "r", encoding="utf-8") as f:
            saved = json.load(f)
    except FileNotFoundError:
        return  # first run: nothing has been persisted yet
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not load state: {e}")
        return

    # dict.update() also accepts iterables of pairs (JSON ["ab"] would add
    # key "a"), so insist on a JSON object before merging.
    if not isinstance(saved, dict):
        logger.warning(
            "[FeedbackStore] Could not load state: "
            f"expected a JSON object, got {type(saved).__name__}"
        )
        return

    # Update in place rather than rebinding: no `global` is needed and any
    # existing reference to the dict stays valid.
    _retrain_state.update(saved)
    logger.info(f"[FeedbackStore] State loaded | version={_retrain_state['model_version']} | total={_retrain_state['total_feedback']}")


def _save_state():
    """
    Persist _retrain_state to STATE_FILE atomically.

    The JSON is first written to a sibling ``.tmp`` file, then swapped in with
    os.replace(), so STATE_FILE is never left half-written.

    Best-effort: failures are logged as warnings, not raised. Any partially
    written temporary file is removed so it does not linger next to the real
    state file.
    """
    tmp = STATE_FILE + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            # default=str: any non-JSON value is stored as its str().
            json.dump(_retrain_state, f, indent=2, default=str)
        os.replace(tmp, STATE_FILE)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not save state: {e}")
        try:
            os.remove(tmp)
        except OSError:
            pass  # never created, or already gone


# Load state on module import: merges retrain_state.json (if present) into
# _retrain_state so the counters persisted by _save_state() survive process
# restarts. Note that importing this module therefore performs file I/O.
_load_state()


# ══════════════════════════════════════════════════════════════════════════════
#  FEEDBACK STORAGE
# ══════════════════════════════════════════════════════════════════════════════

async def append_feedback(
    url: str,
    label: str,
    source: str = "user_feedback",
    original_prediction: Optional[float] = None,
) -> dict:
    """
    Append one feedback record to FEEDBACK_FILE and update the counters.

    The file append, the counter updates and the state save all happen while
    holding ``_write_lock``, so concurrent callers on the same event loop are
    serialized. A label other than "phishing"/"safe" still counts towards
    ``total_feedback`` and ``unprocessed_count``, but towards neither
    correction counter.

    NOTE(review): the file I/O inside the lock is synchronous, so it blocks
    the event loop for its duration.

    Returns:
        On success: {"success": True, "feedback_count": N, "unprocessed": M}
        On write failure: {"success": False, "error": "<message>"}
    """
    record = {
        "url":       url,
        "label":     label,      # "phishing" or "safe"
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "source":    source,
        "original_prediction": (
            None if original_prediction is None else round(original_prediction, 4)
        ),
    }

    async with _write_lock:
        try:
            with open(FEEDBACK_FILE, "a") as out:
                out.write(json.dumps(record) + "\n")
        except Exception as exc:
            logger.error(f"[FeedbackStore] Write failed: {exc}")
            return {"success": False, "error": str(exc)}

        # Counters are only touched after the line reached the file.
        state = _retrain_state
        state["total_feedback"] += 1
        state["unprocessed_count"] += 1
        if label == "phishing":
            state["phishing_corrections"] += 1
        elif label == "safe":
            state["safe_corrections"] += 1

        _save_state()

    total = _retrain_state["total_feedback"]
    logger.info(f"[FeedbackStore] Saved | url={url} | label={label} | total={total}")

    return {
        "success": True,
        "feedback_count": total,
        "unprocessed": _retrain_state["unprocessed_count"],
    }


def get_unprocessed_count() -> int:
    """Return how many feedback entries were recorded since the last retrain."""
    pending = _retrain_state["unprocessed_count"]
    return pending


def get_model_version() -> int:
    """Return the current model version (defaults to 1; bumped per completed retrain)."""
    version = _retrain_state["model_version"]
    return version


def get_stats() -> dict:
    """
    Summarize feedback and retraining counters for the /feedback/stats endpoint.

    ``retrain_history`` is trimmed to its 10 most recent records. This is a
    new list, but the record dicts inside it are shared with the live state.
    """
    scalar_keys = (
        "total_feedback",
        "phishing_corrections",
        "safe_corrections",
        "unprocessed_count",
        "last_retrain",
        "model_version",
    )
    stats = {key: _retrain_state[key] for key in scalar_keys}
    stats["retrain_history"] = _retrain_state["retrain_history"][-10:]
    return stats


def get_recent_entries(n: int = 50) -> list:
    """
    Return up to the last *n* feedback entries from FEEDBACK_FILE, oldest first.

    Blank lines are ignored and do not count towards *n*. A malformed line
    (e.g. a partially written record) is skipped with a warning, rather than
    discarding every entry.

    Returns [] when *n* <= 0, when the file does not exist, or when it cannot
    be read.
    """
    # lines[-0:] is lines[0:], so without this guard n=0 returned the whole
    # file and a negative n returned all but the first |n| lines.
    if n <= 0:
        return []
    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            # A bounded deque keeps only the last n non-blank lines in memory
            # instead of loading the whole file.
            tail = deque((line for line in f if line.strip()), maxlen=n)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        logger.warning(f"[FeedbackStore] Could not read recent entries: {e}")
        return []

    entries = []
    for line in tail:
        try:
            entries.append(json.loads(line))
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            logger.warning(f"[FeedbackStore] Skipping malformed feedback line: {e}")
    return entries


# ══════════════════════════════════════════════════════════════════════════════
#  RETRAINING PIPELINE
# ══════════════════════════════════════════════════════════════════════════════

# Retraining becomes due once this many unprocessed feedback entries exist.
RETRAIN_THRESHOLD: int = 50
# Read by should_retrain(); nothing in this module sets it to True.
# NOTE(review): presumably toggled by whoever runs the retrain job — confirm.
_retrain_running: bool = False


def should_retrain() -> bool:
    """
    Report whether a retraining run is due.

    True once at least RETRAIN_THRESHOLD feedback entries have accumulated
    since the last retrain, unless ``_retrain_running`` is set.
    """
    backlog_reached = _retrain_state["unprocessed_count"] >= RETRAIN_THRESHOLD
    return backlog_reached and not _retrain_running


def mark_retrain_complete(samples: int, accuracy: float):
    """
    Record a successful retraining run.

    Steps:
    - bump ``model_version`` and zero ``unprocessed_count``;
    - stamp ``last_retrain`` with the current UTC time;
    - append a history record, keeping only the 50 most recent;
    - persist the state.

    Args:
        samples: number of feedback samples the model was retrained on.
        accuracy: accuracy of the retrained model. It is logged with the
            ``%`` format spec, so presumably a fraction in [0, 1].
    """
    state = _retrain_state
    state["model_version"] += 1
    state["unprocessed_count"] = 0
    finished_at = datetime.now(timezone.utc).isoformat()
    state["last_retrain"] = finished_at

    history = state["retrain_history"]
    history.append(
        dict(
            timestamp=finished_at,
            samples=samples,
            accuracy=round(accuracy, 4),
            version=state["model_version"],
        )
    )
    keep = 50  # history cap
    if len(history) > keep:
        state["retrain_history"] = history[-keep:]

    _save_state()
    logger.info(
        f"[FeedbackStore] Retrained on {samples} feedback samples. "
        f"New accuracy: {accuracy:.2%}. Model version: {state['model_version']}"
    )


def archive_feedback_file():
    """
    Move the processed feedback file to a timestamped backup.

    The backup is normally named ``<FEEDBACK_FILE>.<unix-seconds>.bak``. If
    that name is already taken (two archives within the same second), a
    counter is inserted: ``<FEEDBACK_FILE>.<unix-seconds>.<k>.bak``. Moving
    onto an existing path would otherwise overwrite the earlier backup on
    POSIX.

    Best-effort: failures are logged, not raised.

    Returns:
        The backup path on success, or None when there was nothing to
        archive or the move failed.
    """
    if not os.path.exists(FEEDBACK_FILE):
        return None

    stamp = int(time.time())
    archive = f"{FEEDBACK_FILE}.{stamp}.bak"
    k = 1
    while os.path.exists(archive):
        archive = f"{FEEDBACK_FILE}.{stamp}.{k}.bak"
        k += 1

    try:
        shutil.move(FEEDBACK_FILE, archive)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Archive failed: {e}")
        return None
    # ASCII arrow: the previous literal was mojibake of a Unicode arrow.
    logger.info(f"[FeedbackStore] Archived feedback -> {archive}")
    return archive


def load_feedback_entries() -> list:
    """
    Load every entry from FEEDBACK_FILE, in file order.

    - Blank lines are ignored.
    - A malformed line (for example a record truncated mid-write) is logged
      and skipped, so one bad line no longer drops every entry after it.
    - Returns [] if the file does not exist.
    - If reading fails part-way, the error is logged and the entries parsed
      so far are returned.
    """
    entries = []
    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except ValueError as e:  # json.JSONDecodeError subclasses ValueError
                    logger.warning(
                        f"[FeedbackStore] Skipping malformed line {lineno}: {e}"
                    )
    except FileNotFoundError:
        return []
    except Exception as e:
        logger.error(f"[FeedbackStore] Read failed: {e}")
    return entries