"""
database/feedback.py
====================

Persistence layer for RL-style feedback and experience replay.

Tables
------
- ``feedback`` — one row per user-submitted feedback event (text, predicted
  score, user rating, optional LLM reward, computed reward scalar).
- ``experience`` — training samples derived from those events (text,
  corrected label, sample weight), shaped as an experience-replay buffer that
  ``training/retrain.py`` can query to build a fine-tuning dataset.

Both tables are stored in the same SQLite file as the main user/session DB
(``stress_detection.db`` by default, overridable via the ``STRESS_DB_PATH``
environment variable), so that a single file contains the whole
application's state.
"""
from __future__ import annotations

import json
import logging
import os
import sqlite3
import threading
import time
from typing import Any, Optional
logger = logging.getLogger(__name__)

# Default database location, read once at import time; override it with the
# STRESS_DB_PATH environment variable.
_DEFAULT_DB_PATH = os.environ.get("STRESS_DB_PATH", "stress_detection.db")


class FeedbackStore:
    """Thin SQLite wrapper for feedback storage and experience replay.

    The connection is opened with ``check_same_thread=False`` so one store can
    be shared between threads. Every operation therefore runs under an
    internal lock, and the two inserts done by :meth:`save_feedback` run in a
    single transaction: both are committed or neither is.

    The store can also be used as a context manager; the connection is closed
    on exit.

    Parameters
    ----------
    db_path : str
        File path for the SQLite database. Pass ``":memory:"`` for
        ephemeral in-memory storage (useful in tests).
    """

    def __init__(self, db_path: str = _DEFAULT_DB_PATH) -> None:
        self._db_path = db_path
        # Serialises all use of the shared connection (see class docstring).
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._create_tables()
        except sqlite3.Error:
            # Do not leak the connection when schema setup fails.
            self._conn.close()
            raise

    def __enter__(self) -> "FeedbackStore":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _create_tables(self) -> None:
        """Create feedback tables if they do not already exist."""
        self._conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS feedback (
                id            INTEGER PRIMARY KEY AUTOINCREMENT,
                username      TEXT    NOT NULL,
                text          TEXT    NOT NULL,
                prediction    REAL    NOT NULL,
                user_feedback INTEGER NOT NULL,  -- 1 = correct, 0 = wrong
                llm_reward    INTEGER,           -- +1 / -1 / NULL
                reward        REAL    NOT NULL,  -- final combined reward
                created_at    REAL    NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_feedback_username
                ON feedback(username);
            CREATE INDEX IF NOT EXISTS idx_feedback_created_at
                ON feedback(created_at);

            CREATE TABLE IF NOT EXISTS experience (
                id         INTEGER PRIMARY KEY AUTOINCREMENT,
                text       TEXT    NOT NULL,
                label      INTEGER NOT NULL,  -- corrected label
                reward     REAL    NOT NULL,  -- sample weight for training
                source     TEXT    NOT NULL DEFAULT 'feedback',
                created_at REAL    NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_experience_created_at
                ON experience(created_at);
            """
        )

    # ------------------------------------------------------------------
    # Feedback CRUD
    # ------------------------------------------------------------------

    def save_feedback(
        self,
        username: str,
        text: str,
        prediction: float,
        user_feedback: int,
        reward: float,
        llm_reward: Optional[int] = None,
    ) -> int:
        """Persist one feedback event and derive a corrected training sample.

        The corrected label is:

        - ``round(prediction)`` when ``user_feedback == 1`` (prediction was right).
        - ``1 - round(prediction)`` when ``user_feedback == 0`` (prediction was wrong).

        Python's ``round`` rounds half to even, so a prediction of exactly
        0.5 counts as class 0.

        The corrected sample is also inserted into ``experience`` so that
        ``training/retrain.py`` can build a dataset without joining tables.
        Its sample weight is ``abs(reward)``; the reward's sign is not kept
        there. Both inserts happen in one transaction, so either both rows
        are stored or neither is.

        Parameters
        ----------
        username : str
            User who submitted the feedback.
        text : str
            Original input text that was analysed.
        prediction : float
            Raw stress probability returned by the model (0–1).
        user_feedback : int
            1 if the prediction was correct, 0 if it was wrong.
        reward : float
            Computed reward scalar (e.g. from ``utils.reward``).
        llm_reward : int | None
            Optional reward from an LLM judge (+1 / -1 / None).

        Returns
        -------
        int
            Row id of the newly inserted feedback row.

        Raises
        ------
        ValueError
            If ``user_feedback`` is not 0 or 1, or ``prediction`` is not a
            number in ``[0, 1]`` (NaN included). Nothing is written then.
        """
        if user_feedback not in (0, 1):
            raise ValueError(f"user_feedback must be 0 or 1, got {user_feedback!r}")
        # float() also normalises numpy scalars, which sqlite3 cannot bind.
        prediction = float(prediction)
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= prediction <= 1.0:
            raise ValueError(f"prediction must be in [0, 1], got {prediction!r}")
        reward = float(reward)

        # Derive the label before touching the database, so a failure here
        # can never leave a half-written event behind.
        predicted_class = int(round(prediction))
        corrected_label = predicted_class if user_feedback == 1 else 1 - predicted_class
        now = time.time()

        # `with self._conn` commits on success and rolls back on any error,
        # so a failing second INSERT cannot leave an orphan feedback row in
        # the open transaction (which a later commit would then persist).
        with self._lock, self._conn:
            cur = self._conn.execute(
                "INSERT INTO feedback "
                "(username, text, prediction, user_feedback, llm_reward, reward, created_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (
                    username,
                    text,
                    prediction,
                    int(user_feedback),
                    None if llm_reward is None else int(llm_reward),
                    reward,
                    now,
                ),
            )
            feedback_id = cur.lastrowid
            self._conn.execute(
                "INSERT INTO experience (text, label, reward, source, created_at) "
                "VALUES (?, ?, ?, 'feedback', ?)",
                (text, corrected_label, abs(reward), now),
            )
        return feedback_id  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_all_feedback(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return feedback rows ordered newest-first.

        Rows with identical ``created_at`` are ordered by descending ``id``,
        so LIMIT/OFFSET pagination is stable (no skipped or repeated rows).
        """
        with self._lock:
            rows = self._conn.execute(
                "SELECT id, username, text, prediction, user_feedback, "
                "llm_reward, reward, created_at "
                "FROM feedback ORDER BY created_at DESC, id DESC "
                "LIMIT ? OFFSET ?",
                (limit, offset),
            ).fetchall()
        return [dict(r) for r in rows]

    def get_user_stats(self, username: str) -> dict[str, Any]:
        """Return aggregated feedback statistics for one user.

        Keys: ``total``, ``mean_reward``, ``n_correct``, ``n_wrong`` and
        ``accuracy_rate`` (``n_correct / total``). A user without any
        feedback gets all-zero values.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT COUNT(*) as total, "
                "AVG(reward) as mean_reward, "
                "SUM(CASE WHEN user_feedback=1 THEN 1 ELSE 0 END) as n_correct, "
                "SUM(CASE WHEN user_feedback=0 THEN 1 ELSE 0 END) as n_wrong "
                "FROM feedback WHERE username = ?",
                (username,),
            ).fetchone()

        total = row["total"] if row is not None else 0
        if not total:
            return {
                "total": 0,
                "mean_reward": 0.0,
                "n_correct": 0,
                "n_wrong": 0,
                "accuracy_rate": 0.0,
            }

        n_correct = row["n_correct"] or 0
        return {
            "total": total,
            "mean_reward": float(row["mean_reward"] or 0.0),
            "n_correct": n_correct,
            "n_wrong": row["n_wrong"] or 0,
            "accuracy_rate": n_correct / total,
        }

    def get_experience_for_training(
        self,
        min_samples: int = 1,
        limit: int = 10_000,
    ) -> list[dict[str, Any]]:
        """Return experience rows suitable for building a training dataset.

        Parameters
        ----------
        min_samples : int
            Return an empty list when fewer than this many rows exist in
            total (avoids retraining on negligible data). This is checked
            against the whole table, not against ``limit``.
        limit : int
            Maximum rows to return (newest first; ties broken by ``id``).

        Returns
        -------
        list of dict with keys: ``text``, ``label``, ``reward``.
        """
        # Hold the lock across both queries so the count and the rows come
        # from the same state.
        with self._lock:
            count_row = self._conn.execute(
                "SELECT COUNT(*) as cnt FROM experience"
            ).fetchone()
            if (count_row["cnt"] or 0) < min_samples:
                return []
            rows = self._conn.execute(
                "SELECT text, label, reward FROM experience "
                "ORDER BY created_at DESC, id DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [dict(r) for r in rows]

    def get_feedback_count(self, username: Optional[str] = None) -> int:
        """Return the total number of feedback rows (optionally per user)."""
        sql = "SELECT COUNT(*) as cnt FROM feedback"
        params: tuple = ()
        if username is not None:
            sql += " WHERE username = ?"
            params = (username,)
        with self._lock:
            row = self._conn.execute(sql, params).fetchone()
        return row["cnt"] if row else 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the database connection."""
        with self._lock:
            self._conn.close()