StressDetect / database /feedback.py
Ace-119's picture
Define SQLite schema for users, sessions, and feedback
4fd1054
"""
database/feedback.py
====================
Persistence layer for RL-style feedback and experience replay.
Tables
------
- ``feedback`` — one row per user-submitted feedback event (text, predicted
score, user rating, optional LLM reward, computed reward scalar).
- ``experience`` — the same data shaped as an experience-replay buffer that
``training/retrain.py`` can query to build a fine-tuning dataset.
Both tables are stored in the same SQLite file as the main user/session DB
(``stress_detection.db``) so that a single file contains the whole
application's state.
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import time
from typing import Any, Optional
logger = logging.getLogger(__name__)

# Path of the SQLite file; overridable through the STRESS_DB_PATH env variable.
_DEFAULT_DB_PATH = os.environ.get("STRESS_DB_PATH", "stress_detection.db")


class FeedbackStore:
    """Thin SQLite wrapper for feedback storage and experience replay.

    The store owns a single connection opened with
    ``check_same_thread=False``. The class itself adds no locking, so
    callers sharing one instance across threads must serialise writes
    themselves.

    Instances can be used as context managers; the connection is closed on
    exit::

        with FeedbackStore(":memory:") as store:
            store.save_feedback("alice", "some text", 0.9, 1, 1.0)

    Parameters
    ----------
    db_path : str
        File path for the SQLite database.  Pass ``":memory:"`` for
        ephemeral in-memory storage (useful in tests).
    """

    def __init__(self, db_path: str = _DEFAULT_DB_PATH) -> None:
        self._db_path = db_path
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._create_tables()
        except Exception:
            # Do not leak the open connection when setup fails.
            self._conn.close()
            raise

    def __enter__(self) -> "FeedbackStore":
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the connection when leaving the ``with`` block."""
        self.close()

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------
    def _create_tables(self) -> None:
        """Create the feedback and experience tables if they do not exist."""
        self._conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS feedback (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                text TEXT NOT NULL,
                prediction REAL NOT NULL,
                user_feedback INTEGER NOT NULL, -- 1 = correct, 0 = wrong
                llm_reward INTEGER, -- +1 / -1 / NULL
                reward REAL NOT NULL, -- final combined reward
                created_at REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_feedback_username
                ON feedback(username);
            CREATE INDEX IF NOT EXISTS idx_feedback_created_at
                ON feedback(created_at);
            CREATE TABLE IF NOT EXISTS experience (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT NOT NULL,
                label INTEGER NOT NULL, -- corrected label
                reward REAL NOT NULL, -- sample weight for training
                source TEXT NOT NULL DEFAULT 'feedback',
                created_at REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_experience_created_at
                ON experience(created_at);
            """
        )

    # ------------------------------------------------------------------
    # Feedback CRUD
    # ------------------------------------------------------------------
    def save_feedback(
        self,
        username: str,
        text: str,
        prediction: float,
        user_feedback: int,
        reward: float,
        llm_reward: Optional[int] = None,
    ) -> int:
        """Persist one feedback event and derive a corrected training sample.

        The corrected label is:

        - ``round(prediction)`` when ``user_feedback == 1`` (prediction was
          right).
        - ``1 - round(prediction)`` otherwise (prediction was wrong).

        The rounded class is clamped to ``{0, 1}`` so that an out-of-range
        ``prediction`` cannot produce a non-binary label.

        The corrected sample is also inserted into ``experience`` (weighted
        by ``abs(reward)``) so that a dataset can be built without joining
        tables.  Both inserts run in one transaction: either both rows are
        stored or neither is.

        Parameters
        ----------
        username : str
            User who submitted the feedback.
        text : str
            Original input text that was analysed.
        prediction : float
            Raw stress probability returned by the model (0–1).
        user_feedback : int
            1 if the prediction was correct, 0 if it was wrong.
        reward : float
            Computed reward scalar.
        llm_reward : int | None
            Optional reward from an LLM judge (+1 / -1 / None).

        Returns
        -------
        int
            Row id of the newly inserted feedback row.

        Raises
        ------
        sqlite3.Error
            If either insert fails; the transaction is rolled back first.
        """
        # Derive the label before touching the database so that a bad
        # ``prediction`` (e.g. NaN) cannot fail in the middle of a write.
        raw_class = int(round(prediction))
        predicted_class = min(1, max(0, raw_class))
        if predicted_class != raw_class:
            logger.warning(
                "prediction %r is outside [0, 1]; clamping class to %d",
                prediction,
                predicted_class,
            )
        corrected_label = (
            predicted_class if user_feedback == 1 else 1 - predicted_class
        )
        now = time.time()

        # The connection context manager commits on success and rolls back
        # on any exception, keeping ``feedback`` and ``experience`` in sync.
        with self._conn:
            cur = self._conn.execute(
                "INSERT INTO feedback "
                "(username, text, prediction, user_feedback, llm_reward, "
                "reward, created_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (
                    username,
                    text,
                    prediction,
                    user_feedback,
                    llm_reward,
                    reward,
                    now,
                ),
            )
            self._conn.execute(
                "INSERT INTO experience "
                "(text, label, reward, source, created_at) "
                "VALUES (?, ?, ?, 'feedback', ?)",
                (text, corrected_label, abs(reward), now),
            )
        return cur.lastrowid  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def get_all_feedback(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return feedback rows ordered newest-first.

        Parameters
        ----------
        limit : int
            Maximum number of rows to return.
        offset : int
            Number of newest rows to skip (for pagination).

        Returns
        -------
        list of dict
            One dict per row with every ``feedback`` column.
        """
        # ``id DESC`` breaks ties between rows sharing a timestamp so that
        # pagination is stable.
        rows = self._conn.execute(
            "SELECT id, username, text, prediction, user_feedback, "
            "llm_reward, reward, created_at "
            "FROM feedback ORDER BY created_at DESC, id DESC "
            "LIMIT ? OFFSET ?",
            (limit, offset),
        ).fetchall()
        return [dict(r) for r in rows]

    def get_user_stats(self, username: str) -> dict[str, Any]:
        """Return aggregated feedback statistics for one user.

        Returns
        -------
        dict
            Keys ``total``, ``mean_reward``, ``n_correct``, ``n_wrong`` and
            ``accuracy_rate`` (``n_correct / total``).  All values are zero
            when the user has no feedback rows.
        """
        row = self._conn.execute(
            "SELECT COUNT(*) as total, "
            "AVG(reward) as mean_reward, "
            "SUM(CASE WHEN user_feedback=1 THEN 1 ELSE 0 END) as n_correct, "
            "SUM(CASE WHEN user_feedback=0 THEN 1 ELSE 0 END) as n_wrong "
            "FROM feedback WHERE username = ?",
            (username,),
        ).fetchone()
        if row is None or not row["total"]:
            return {
                "total": 0,
                "mean_reward": 0.0,
                "n_correct": 0,
                "n_wrong": 0,
                "accuracy_rate": 0.0,
            }
        total = row["total"]
        n_correct = row["n_correct"] or 0
        return {
            "total": total,
            "mean_reward": float(row["mean_reward"] or 0.0),
            "n_correct": n_correct,
            "n_wrong": row["n_wrong"] or 0,
            "accuracy_rate": n_correct / total,
        }

    def get_experience_for_training(
        self,
        min_samples: int = 1,
        limit: int = 10_000,
    ) -> list[dict[str, Any]]:
        """Return experience rows suitable for building a training dataset.

        Parameters
        ----------
        min_samples : int
            Return an empty list when fewer than this many rows exist
            (avoids retraining on negligible data).
        limit : int
            Maximum rows to return.

        Returns
        -------
        list of dict
            Newest-first rows with keys ``text``, ``label``, ``reward``.
        """
        count_row = self._conn.execute(
            "SELECT COUNT(*) as cnt FROM experience"
        ).fetchone()
        if (count_row["cnt"] or 0) < min_samples:
            return []
        rows = self._conn.execute(
            "SELECT text, label, reward FROM experience "
            "ORDER BY created_at DESC, id DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [dict(r) for r in rows]

    def get_feedback_count(self, username: Optional[str] = None) -> int:
        """Return the total number of feedback rows (optionally per user)."""
        if username is None:
            row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM feedback"
            ).fetchone()
        else:
            row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM feedback WHERE username = ?",
                (username,),
            ).fetchone()
        return row["cnt"] if row else 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def close(self) -> None:
        """Close the database connection."""
        self._conn.close()