"""
database/feedback.py
====================
Persistence layer for RL-style feedback and experience replay.
Tables
------
- ``feedback`` — one row per user-submitted feedback event (text, predicted
score, user rating, optional LLM reward, computed reward scalar).
- ``experience`` — the same data shaped as an experience-replay buffer that
``training/retrain.py`` can query to build a fine-tuning dataset.
Both tables are stored in the same SQLite file as the main user/session DB
(``stress_detection.db``) so that a single file contains the whole
application's state.
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import time
from typing import Any, Optional
logger = logging.getLogger(__name__)
_DEFAULT_DB_PATH = os.environ.get("STRESS_DB_PATH", "stress_detection.db")
class FeedbackStore:
    """Thin SQLite wrapper for feedback storage and experience replay.

    Owns a single :class:`sqlite3.Connection` opened with
    ``check_same_thread=False`` (so web workers may share it — note the
    connection itself is not additionally locked here) and WAL journal mode
    for better concurrent-read behaviour.

    The store is usable as a context manager::

        with FeedbackStore(":memory:") as store:
            store.save_feedback(...)

    Parameters
    ----------
    db_path : str | None
        File path for the SQLite database. Pass ``":memory:"`` for
        ephemeral in-memory storage (useful in tests). ``None`` (the
        default) falls back to ``_DEFAULT_DB_PATH``, resolved lazily at
        call time rather than at class-definition time.
    """

    def __init__(self, db_path: str | None = None) -> None:
        # Resolve the default lazily: the module constant is only read when
        # the caller did not supply an explicit path.
        self._db_path = _DEFAULT_DB_PATH if db_path is None else db_path
        self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
        # sqlite3.Row lets callers (and our own queries) access columns by
        # name, which `dict(row)` below relies on.
        self._conn.row_factory = sqlite3.Row
        # WAL allows concurrent readers while a writer is active.
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._create_tables()

    def __repr__(self) -> str:
        return f"{type(self).__name__}(db_path={self._db_path!r})"

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------
    def _create_tables(self) -> None:
        """Create feedback tables if they do not already exist."""
        # executescript issues an implicit COMMIT, so the schema is
        # persisted immediately.
        self._conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS feedback (
                id            INTEGER PRIMARY KEY AUTOINCREMENT,
                username      TEXT NOT NULL,
                text          TEXT NOT NULL,
                prediction    REAL NOT NULL,
                user_feedback INTEGER NOT NULL,   -- 1 = correct, 0 = wrong
                llm_reward    INTEGER,            -- +1 / -1 / NULL
                reward        REAL NOT NULL,      -- final combined reward
                created_at    REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_feedback_username
                ON feedback(username);
            CREATE INDEX IF NOT EXISTS idx_feedback_created_at
                ON feedback(created_at);
            CREATE TABLE IF NOT EXISTS experience (
                id         INTEGER PRIMARY KEY AUTOINCREMENT,
                text       TEXT NOT NULL,
                label      INTEGER NOT NULL,      -- corrected label
                reward     REAL NOT NULL,         -- sample weight for training
                source     TEXT NOT NULL DEFAULT 'feedback',
                created_at REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_experience_created_at
                ON experience(created_at);
            """
        )

    # ------------------------------------------------------------------
    # Feedback CRUD
    # ------------------------------------------------------------------
    def save_feedback(
        self,
        username: str,
        text: str,
        prediction: float,
        user_feedback: int,
        reward: float,
        llm_reward: Optional[int] = None,
    ) -> int:
        """Persist one feedback event and derive a corrected training sample.

        The corrected label is:
        - ``round(prediction)`` when ``user_feedback == 1`` (prediction was right).
        - ``1 - round(prediction)`` when ``user_feedback == 0`` (prediction was wrong).

        The corrected sample is also inserted into ``experience`` so that
        ``training/retrain.py`` can build a dataset without joining tables.

        Parameters
        ----------
        username : str
            User who submitted the feedback.
        text : str
            Original input text that was analysed.
        prediction : float
            Raw stress probability returned by the model (0–1).
        user_feedback : int
            1 if the prediction was correct, 0 if it was wrong.
        reward : float
            Computed reward scalar (e.g. from ``utils.reward``).
        llm_reward : int | None
            Optional reward from an LLM judge (+1 / -1 / None).

        Returns
        -------
        int
            Row id of the newly inserted feedback row.
        """
        now = time.time()
        cur = self._conn.execute(
            "INSERT INTO feedback "
            "(username, text, prediction, user_feedback, llm_reward, reward, created_at) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (username, text, prediction, user_feedback, llm_reward, reward, now),
        )
        # lastrowid is always an int after a successful INSERT; the explicit
        # conversion also satisfies type checkers (lastrowid is Optional[int]).
        feedback_id = int(cur.lastrowid)  # type: ignore[arg-type]

        # Derive corrected label for experience replay.
        predicted_class = int(round(prediction))
        corrected_label = predicted_class if user_feedback == 1 else 1 - predicted_class
        # abs(reward): experience rows use the reward magnitude as a
        # non-negative sample weight for training.
        self._conn.execute(
            "INSERT INTO experience (text, label, reward, source, created_at) "
            "VALUES (?, ?, ?, 'feedback', ?)",
            (text, corrected_label, abs(reward), now),
        )
        # Single commit covers both inserts (they share one implicit
        # transaction), so feedback and experience stay in sync.
        self._conn.commit()
        return feedback_id

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def get_all_feedback(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return feedback rows ordered newest-first.

        Parameters
        ----------
        limit : int
            Maximum number of rows to return (default 100).
        offset : int
            Number of rows to skip, for pagination (default 0).
        """
        rows = self._conn.execute(
            "SELECT id, username, text, prediction, user_feedback, "
            "llm_reward, reward, created_at "
            "FROM feedback ORDER BY created_at DESC LIMIT ? OFFSET ?",
            (limit, offset),
        ).fetchall()
        return [dict(r) for r in rows]

    def get_user_stats(self, username: str) -> dict[str, Any]:
        """Return aggregated feedback statistics for one user.

        Returns
        -------
        dict
            Keys: ``total``, ``mean_reward``, ``n_correct``, ``n_wrong``,
            ``accuracy_rate``. All zeros when the user has no feedback.
        """
        row = self._conn.execute(
            "SELECT COUNT(*) as total, "
            "AVG(reward) as mean_reward, "
            "SUM(CASE WHEN user_feedback=1 THEN 1 ELSE 0 END) as n_correct, "
            "SUM(CASE WHEN user_feedback=0 THEN 1 ELSE 0 END) as n_wrong "
            "FROM feedback WHERE username = ?",
            (username,),
        ).fetchone()
        if row is None or row["total"] == 0:
            # No feedback for this user: AVG/SUM would be NULL, so return
            # an explicit all-zero summary instead.
            return {
                "total": 0,
                "mean_reward": 0.0,
                "n_correct": 0,
                "n_wrong": 0,
                "accuracy_rate": 0.0,
            }
        total = row["total"]
        n_correct = row["n_correct"] or 0
        return {
            "total": total,
            "mean_reward": float(row["mean_reward"] or 0.0),
            "n_correct": n_correct,
            "n_wrong": row["n_wrong"] or 0,
            "accuracy_rate": n_correct / total if total > 0 else 0.0,
        }

    def get_experience_for_training(
        self,
        min_samples: int = 1,
        limit: int = 10_000,
    ) -> list[dict[str, Any]]:
        """Return experience rows suitable for building a training dataset.

        Parameters
        ----------
        min_samples : int
            Return an empty list when fewer than this many rows exist
            (avoids retraining on negligible data).
        limit : int
            Maximum rows to return.

        Returns
        -------
        list of dict with keys: ``text``, ``label``, ``reward``.
        """
        # Count the FULL table first: min_samples gates on total data
        # available, not on the (possibly smaller) limited result set.
        count_row = self._conn.execute(
            "SELECT COUNT(*) as cnt FROM experience"
        ).fetchone()
        if (count_row["cnt"] or 0) < min_samples:
            return []
        rows = self._conn.execute(
            "SELECT text, label, reward FROM experience "
            "ORDER BY created_at DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [dict(r) for r in rows]

    def get_feedback_count(self, username: Optional[str] = None) -> int:
        """Return the total number of feedback rows (optionally per user)."""
        if username is not None:
            row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM feedback WHERE username = ?",
                (username,),
            ).fetchone()
        else:
            row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM feedback"
            ).fetchone()
        return row["cnt"] if row else 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def __enter__(self) -> "FeedbackStore":
        """Support ``with FeedbackStore(...) as store:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection when the ``with`` block exits."""
        self.close()

    def close(self) -> None:
        """Close the database connection."""
        self._conn.close()
|