File size: 8,902 Bytes
4fd1054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
database/feedback.py
====================
Persistence layer for RL-style feedback and experience replay.

Tables
------
- ``feedback``   — one row per user-submitted feedback event (text, predicted
  score, user rating, optional LLM reward, computed reward scalar).
- ``experience`` — the same data shaped as an experience-replay buffer that
  ``training/retrain.py`` can query to build a fine-tuning dataset.

By default both tables are stored in the same SQLite file as the main
user/session DB (``stress_detection.db``, overridable via the
``STRESS_DB_PATH`` environment variable) so that a single file contains the
whole application's state.
"""

from __future__ import annotations

import json
import logging
import os
import sqlite3
import threading
import time
from typing import Any, Optional

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# SQLite file used when FeedbackStore is built without an explicit path; the
# STRESS_DB_PATH environment variable overrides the built-in default.
_DEFAULT_DB_PATH = os.getenv("STRESS_DB_PATH", default="stress_detection.db")


class FeedbackStore:
    """Thin SQLite wrapper for feedback storage and experience replay.

    One connection is opened with ``check_same_thread=False`` so the store
    can be shared between threads.  Every database access goes through an
    internal lock so concurrent callers cannot interleave statements or
    transactions on that shared connection.

    The store may also be used as a context manager; the connection is
    closed when the ``with`` block exits.

    Parameters
    ----------
    db_path : str
        File path for the SQLite database.  Pass ``":memory:"`` for
        ephemeral in-memory storage (useful in tests).
    """

    def __init__(self, db_path: str = _DEFAULT_DB_PATH) -> None:
        self._db_path = db_path
        # With check_same_thread=False the sqlite3 module leaves serializing
        # writes to the caller; this lock does that for every method below.
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._create_tables()
        except Exception:
            # Do not leak the open connection if initialisation fails.
            self._conn.close()
            raise

    # ------------------------------------------------------------------
    # Context-manager support
    # ------------------------------------------------------------------

    def __enter__(self) -> "FeedbackStore":
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection on leaving the ``with`` block."""
        self.close()

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _create_tables(self) -> None:
        """Create feedback tables and indexes if they do not already exist."""
        with self._lock:
            self._conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS feedback (
                    id              INTEGER PRIMARY KEY AUTOINCREMENT,
                    username        TEXT    NOT NULL,
                    text            TEXT    NOT NULL,
                    prediction      REAL    NOT NULL,
                    user_feedback   INTEGER NOT NULL,  -- 1 = correct, 0 = wrong
                    llm_reward      INTEGER,            -- +1 / -1 / NULL
                    reward          REAL    NOT NULL,   -- final combined reward
                    created_at      REAL    NOT NULL
                );

                CREATE INDEX IF NOT EXISTS idx_feedback_username
                    ON feedback(username);
                CREATE INDEX IF NOT EXISTS idx_feedback_created_at
                    ON feedback(created_at);

                CREATE TABLE IF NOT EXISTS experience (
                    id              INTEGER PRIMARY KEY AUTOINCREMENT,
                    text            TEXT    NOT NULL,
                    label           INTEGER NOT NULL,  -- corrected label
                    reward          REAL    NOT NULL,  -- sample weight for training
                    source          TEXT    NOT NULL DEFAULT 'feedback',
                    created_at      REAL    NOT NULL
                );

                CREATE INDEX IF NOT EXISTS idx_experience_created_at
                    ON experience(created_at);
                """
            )

    # ------------------------------------------------------------------
    # Feedback CRUD
    # ------------------------------------------------------------------

    def save_feedback(
        self,
        username: str,
        text: str,
        prediction: float,
        user_feedback: int,
        reward: float,
        llm_reward: Optional[int] = None,
    ) -> int:
        """Persist one feedback event and derive a corrected training sample.

        The corrected label is:
        - ``round(prediction)`` when ``user_feedback == 1`` (prediction was right).
        - ``1 - round(prediction)`` when ``user_feedback == 0`` (prediction was wrong).

        The corrected sample is also inserted into ``experience`` so that
        ``training/retrain.py`` can build a dataset without joining tables.
        The two INSERTs are committed atomically: either both rows are
        stored or neither is.

        Parameters
        ----------
        username : str
            User who submitted the feedback.
        text : str
            Original input text that was analysed.
        prediction : float
            Raw stress probability returned by the model (0–1).
        user_feedback : int
            1 if the prediction was correct, 0 if it was wrong.
        reward : float
            Computed reward scalar (e.g. from ``utils.reward``).  Its
            absolute value is stored as the experience sample weight.
        llm_reward : int | None
            Optional reward from an LLM judge (+1 / -1 / None).

        Returns
        -------
        int
            Row id of the newly inserted feedback row.

        Raises
        ------
        ValueError
            If ``prediction`` is not a number in [0, 1] (this includes NaN)
            or ``user_feedback`` is not 0 or 1.  Nothing is written then.
        """
        # Validate and derive the label *before* writing anything.  Otherwise
        # a NaN/inf prediction makes round() raise after the first INSERT,
        # leaving a half-written transaction that a later commit would persist.
        # The chained comparison is False for NaN, so NaN is rejected here.
        if not 0.0 <= prediction <= 1.0:
            raise ValueError(f"prediction must be in [0, 1], got {prediction!r}")
        if user_feedback not in (0, 1):
            raise ValueError(f"user_feedback must be 0 or 1, got {user_feedback!r}")

        # NOTE: round() rounds half to even, so exactly 0.5 maps to class 0.
        predicted_class = int(round(prediction))
        corrected_label = predicted_class if user_feedback == 1 else 1 - predicted_class
        now = time.time()

        # The connection context manager commits both INSERTs on success and
        # rolls both back if either fails.
        with self._lock, self._conn:
            cur = self._conn.execute(
                "INSERT INTO feedback "
                "(username, text, prediction, user_feedback, llm_reward, reward, created_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (username, text, prediction, user_feedback, llm_reward, reward, now),
            )
            feedback_id = cur.lastrowid
            self._conn.execute(
                "INSERT INTO experience (text, label, reward, source, created_at) "
                "VALUES (?, ?, ?, 'feedback', ?)",
                (text, corrected_label, abs(reward), now),
            )
        return feedback_id  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_all_feedback(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return feedback rows ordered newest-first.

        Rows with identical ``created_at`` are ordered by descending id, so
        pagination with ``limit``/``offset`` is deterministic.
        """
        with self._lock:
            rows = self._conn.execute(
                "SELECT id, username, text, prediction, user_feedback, "
                "llm_reward, reward, created_at "
                "FROM feedback ORDER BY created_at DESC, id DESC LIMIT ? OFFSET ?",
                (limit, offset),
            ).fetchall()
        return [dict(r) for r in rows]

    def get_user_stats(self, username: str) -> dict[str, Any]:
        """Return aggregated feedback statistics for one user.

        Returns
        -------
        dict
            Keys ``total``, ``mean_reward``, ``n_correct``, ``n_wrong`` and
            ``accuracy_rate`` (``n_correct / total``).  All values are zero
            when the user has no feedback rows.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT COUNT(*) as total, "
                "AVG(reward) as mean_reward, "
                "SUM(CASE WHEN user_feedback=1 THEN 1 ELSE 0 END) as n_correct, "
                "SUM(CASE WHEN user_feedback=0 THEN 1 ELSE 0 END) as n_wrong "
                "FROM feedback WHERE username = ?",
                (username,),
            ).fetchone()
        if row is None or row["total"] == 0:
            return {
                "total": 0,
                "mean_reward": 0.0,
                "n_correct": 0,
                "n_wrong": 0,
                "accuracy_rate": 0.0,
            }

        total = row["total"]
        n_correct = row["n_correct"] or 0
        return {
            "total": total,
            "mean_reward": float(row["mean_reward"] or 0.0),
            "n_correct": n_correct,
            "n_wrong": row["n_wrong"] or 0,
            "accuracy_rate": n_correct / total,
        }

    def get_experience_for_training(
        self,
        min_samples: int = 1,
        limit: int = 10_000,
    ) -> list[dict[str, Any]]:
        """Return experience rows suitable for building a training dataset.

        Rows are ordered newest-first; ties on ``created_at`` are broken by
        descending id.

        Parameters
        ----------
        min_samples : int
            Return an empty list when fewer than this many rows exist
            (avoids retraining on negligible data).
        limit : int
            Maximum rows to return.

        Returns
        -------
        list of dict with keys: ``text``, ``label``, ``reward``.
        """
        # Count and fetch under one lock acquisition so both see the same data.
        with self._lock:
            count_row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM experience"
            ).fetchone()
            if (count_row["cnt"] or 0) < min_samples:
                return []

            rows = self._conn.execute(
                "SELECT text, label, reward FROM experience "
                "ORDER BY created_at DESC, id DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [dict(r) for r in rows]

    def get_feedback_count(self, username: Optional[str] = None) -> int:
        """Return the total number of feedback rows (optionally per user)."""
        if username is None:
            sql: str = "SELECT COUNT(*) as cnt FROM feedback"
            params: tuple[Any, ...] = ()
        else:
            sql = "SELECT COUNT(*) as cnt FROM feedback WHERE username = ?"
            params = (username,)
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return row["cnt"] if row else 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the database connection."""
        with self._lock:
            self._conn.close()