Spaces:
Sleeping
Sleeping
| # api/db.py | |
| """ | |
| Minimal PostgreSQL persistence for ClareVoice (pilot). | |
| Uses psycopg2 SimpleConnectionPool + raw SQL. | |
| Graceful degradation: if DATABASE_URL is unset or DB is unreachable, | |
| all public functions silently no-op and the app continues normally. | |
| Environment variable: | |
| DATABASE_URL=postgres://user:password@host:5432/dbname | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import secrets | |
| import uuid | |
| from contextlib import contextmanager | |
| from typing import Any, Dict, List, Optional | |
# Module logger; the helpers below report their failures through it.
logger = logging.getLogger("clare.db")
# Connection string, read once at import time. An empty value makes init_db()
# leave persistence disabled.
DATABASE_URL: str = os.getenv("DATABASE_URL", "").strip()
# Module-level connection pool (None when DB is disabled)
_pool = None
def init_db() -> None:
    """
    Create the connection pool and make sure the schema exists.

    Called once on FastAPI startup. Safe to call when DATABASE_URL is
    absent — logs a warning and returns with persistence disabled.

    Any failure (driver import, connecting, DDL) is logged and leaves the
    module in the disabled state (``_pool is None``). A pool that was already
    opened is closed first so a failed start-up does not leak connections.
    """
    global _pool
    if not DATABASE_URL:
        logger.warning("[db] DATABASE_URL not set — DB persistence disabled.")
        return
    try:
        # Imported lazily so the module still loads when psycopg2 is missing.
        from psycopg2 import pool as pg_pool

        _pool = pg_pool.SimpleConnectionPool(minconn=1, maxconn=10, dsn=DATABASE_URL)
        logger.info("[db] Connection pool created.")
        _create_tables()
        logger.info("[db] Tables ready.")
    except Exception as exc:
        logger.error("[db] init failed — DB persistence disabled: %s", exc)
        # Disable persistence first, then release whatever the pool opened.
        failed_pool, _pool = _pool, None
        if failed_pool is not None:
            try:
                failed_pool.closeall()
            except Exception as close_exc:
                logger.warning("[db] closing failed pool raised: %s", close_exc)
@contextmanager
def _get_conn():
    """
    Yield a pooled connection; commit on success, rollback on error.

    Intended for ``with _get_conn() as conn:``. The ``@contextmanager``
    decorator is what turns this generator into a context manager.

    Raises:
        RuntimeError: if the pool has not been initialised.
    """
    # Hold a local reference so the connection goes back to the pool it came
    # from even if the module-level pool is reset while it is in use.
    pool = _pool
    if pool is None:
        raise RuntimeError("DB pool not initialised")
    conn = pool.getconn()
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        pool.putconn(conn)
| # --------------------------------------------------------------------------- | |
| # DDL | |
| # --------------------------------------------------------------------------- | |
# One row per session_id, with the login that owns it. Referenced by the
# chats and interactions tables below.
_CREATE_LOGIN_SESSIONS_SQL = """
CREATE TABLE IF NOT EXISTS login_sessions (
    session_id TEXT PRIMARY KEY,
    login_id TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS login_sessions_login_id_idx ON login_sessions (login_id);
"""
# One row per chat. chat_id is a TEXT key supplied by the application
# (see create_chat); the two session columns reference login_sessions.
_CREATE_CHATS_SQL = """
CREATE TABLE IF NOT EXISTS chats (
    chat_id TEXT PRIMARY KEY,
    login_id TEXT NOT NULL,
    name TEXT NOT NULL DEFAULT '',
    chat_mode TEXT NOT NULL DEFAULT 'ask',
    created_session_id TEXT REFERENCES login_sessions(session_id),
    last_updated_session_id TEXT REFERENCES login_sessions(session_id),
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS chats_login_id_idx ON chats (login_id);
"""
# One row per user/assistant turn. session_id must exist in login_sessions;
# chat_id is optional and, because of ON DELETE CASCADE, deleting a chat
# deletes its interaction rows too.
_CREATE_INTERACTIONS_SQL = """
CREATE TABLE IF NOT EXISTS interactions (
    id UUID PRIMARY KEY,
    session_id TEXT NOT NULL REFERENCES login_sessions(session_id),
    chat_id TEXT REFERENCES chats(chat_id) ON DELETE CASCADE,
    login_id TEXT,
    turn_index INTEGER NOT NULL,
    user_message TEXT NOT NULL,
    assistant_reply TEXT NOT NULL DEFAULT '',
    learning_mode TEXT NOT NULL DEFAULT '',
    total_tokens INTEGER NOT NULL DEFAULT 0,
    estimated_cost DOUBLE PRECISION NOT NULL DEFAULT 0,
    user_ts TIMESTAMPTZ,
    first_token_ts TIMESTAMPTZ,
    last_token_ts TIMESTAMPTZ,
    suggestions_ts TIMESTAMPTZ,
    doc_references TEXT[] NOT NULL DEFAULT '{}',
    suggested_questions TEXT[] NOT NULL DEFAULT '{}',
    error_flag BOOLEAN NOT NULL DEFAULT FALSE,
    timeout_flag BOOLEAN NOT NULL DEFAULT FALSE,
    run_id TEXT,
    thumbs_rating TEXT,
    free_text_feedback TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS interactions_session_id_idx ON interactions (session_id);
CREATE INDEX IF NOT EXISTS interactions_chat_id_idx ON interactions (chat_id);
CREATE INDEX IF NOT EXISTS interactions_login_id_idx ON interactions (login_id);
"""
# One row per submitted survey: q1..q10 are SMALLINT answers, each with an
# optional free-text feedback column; q11 is TEXT only.
# NOTE(review): gen_random_uuid() is built in from PostgreSQL 13; older servers
# need the pgcrypto extension — confirm the target server version.
_CREATE_SURVEY_RESPONSES_SQL = """
CREATE TABLE IF NOT EXISTS survey_responses (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    login_id TEXT,
    submitted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    q1 SMALLINT, q1_feedback TEXT,
    q2 SMALLINT, q2_feedback TEXT,
    q3 SMALLINT, q3_feedback TEXT,
    q4 SMALLINT, q4_feedback TEXT,
    q5 SMALLINT, q5_feedback TEXT,
    q6 SMALLINT, q6_feedback TEXT,
    q7 SMALLINT, q7_feedback TEXT,
    q8 SMALLINT, q8_feedback TEXT,
    q9 SMALLINT, q9_feedback TEXT,
    q10 SMALLINT, q10_feedback TEXT,
    q11 TEXT
);
CREATE INDEX IF NOT EXISTS survey_responses_login_id_idx ON survey_responses (login_id);
"""
def _create_tables() -> None:
    """Run the CREATE TABLE / CREATE INDEX scripts, then apply migrations."""
    ddl_scripts = (
        _CREATE_LOGIN_SESSIONS_SQL,
        _CREATE_CHATS_SQL,
        _CREATE_INTERACTIONS_SQL,
        _CREATE_SURVEY_RESPONSES_SQL,
    )
    # Order matters: chats and interactions reference login_sessions.
    with _get_conn() as conn, conn.cursor() as cur:
        for script in ddl_scripts:
            cur.execute(script)
    # The DDL above is committed before the migration borrows its own connection.
    _migrate_schema()
def _migrate_schema() -> None:
    """
    Idempotent ALTER TABLE migrations for existing installations.

    Runs on one pooled connection, so all steps commit together or roll back
    together. The steps are order-dependent: the chat_id foreign key is
    dropped before the column types change and re-created afterwards.
    """
    with _get_conn() as conn:
        with conn.cursor() as cur:
            # Migrate chat_id UUID → TEXT (only if still UUID type)
            # NOTE(review): the information_schema lookup is not restricted to
            # a schema, so a same-named table in another schema would also
            # match — confirm only one schema holds a `chats` table.
            cur.execute("""
                DO $$
                BEGIN
                    IF EXISTS (
                        SELECT 1 FROM information_schema.columns
                        WHERE table_name = 'chats'
                        AND column_name = 'chat_id'
                        AND data_type = 'uuid'
                    ) THEN
                        ALTER TABLE interactions DROP CONSTRAINT IF EXISTS interactions_chat_id_fkey;
                        ALTER TABLE interactions ALTER COLUMN chat_id TYPE TEXT USING chat_id::TEXT;
                        ALTER TABLE chats ALTER COLUMN chat_id TYPE TEXT USING chat_id::TEXT;
                    END IF;
                END $$;
            """)
            # Re-add FK if missing (the block above drops it while the column
            # types are being changed)
            cur.execute("""
                DO $$
                BEGIN
                    IF NOT EXISTS (
                        SELECT 1 FROM pg_constraint WHERE conname = 'interactions_chat_id_fkey'
                    ) THEN
                        ALTER TABLE interactions ADD CONSTRAINT interactions_chat_id_fkey
                        FOREIGN KEY (chat_id) REFERENCES chats(chat_id) ON DELETE CASCADE;
                    END IF;
                END $$;
            """)
            # Drop removed columns if still present
            cur.execute("ALTER TABLE interactions DROP COLUMN IF EXISTS latency_ms;")
            cur.execute("ALTER TABLE interactions DROP COLUMN IF EXISTS retrieved_ids;")
            cur.execute("ALTER TABLE interactions DROP COLUMN IF EXISTS cited_ids;")
            # Add new columns if missing
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS user_ts TIMESTAMPTZ;")
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS first_token_ts TIMESTAMPTZ;")
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS last_token_ts TIMESTAMPTZ;")
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS suggestions_ts TIMESTAMPTZ;")
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS doc_references TEXT[] NOT NULL DEFAULT '{}';")
            cur.execute("ALTER TABLE interactions ADD COLUMN IF NOT EXISTS suggested_questions TEXT[] NOT NULL DEFAULT '{}';")
            # Remove DEFAULT gen_random_uuid() from chats.chat_id if set
            # (chat ids are generated by the application in create_chat)
            cur.execute("""
                ALTER TABLE chats ALTER COLUMN chat_id DROP DEFAULT;
            """)
| # --------------------------------------------------------------------------- | |
| # Public helpers | |
| # --------------------------------------------------------------------------- | |
def upsert_session(
    *,
    session_id: str,
    login_id: str,
    learning_mode: str = "",
) -> None:
    """
    Insert a login_sessions row, or refresh ``updated_at`` if it already exists.

    ``learning_mode`` is accepted but not written to the table. Does nothing
    when the DB is unavailable; failures are logged, never raised.
    """
    if _pool is None:
        return
    statement = """
        INSERT INTO login_sessions
            (session_id, login_id)
        VALUES (%s, %s)
        ON CONFLICT (session_id) DO UPDATE SET
            updated_at = now();
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(statement, (session_id, login_id))
    except Exception as exc:
        logger.error("[db] upsert_session failed: %s", exc)
| # --------------------------------------------------------------------------- | |
| # Chat CRUD stubs (schema ready; DB wiring deferred — localStorage is active store) | |
| # --------------------------------------------------------------------------- | |
def create_chat(
    *,
    login_id: str,
    name: str,
    chat_mode: str,
    created_session_id: Optional[str] = None,
) -> Optional[str]:
    """
    Insert a new chat row and return its id.

    The id is a random 8-character hex string generated here. The creating
    session is also stored as the chat's last-updated session. Returns None
    when the DB is unavailable or the insert fails (the failure is logged).
    """
    if _pool is None:
        return None
    new_chat_id = secrets.token_hex(4)
    statement = """
        INSERT INTO chats (chat_id, login_id, name, chat_mode, created_session_id, last_updated_session_id)
        VALUES (%s, %s, %s, %s, %s, %s);
    """
    row = (
        new_chat_id,
        login_id,
        name,
        chat_mode,
        created_session_id,
        created_session_id,  # a new chat was last touched by its creator
    )
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(statement, row)
    except Exception as exc:
        logger.error("[db] create_chat failed: %s", exc)
        return None
    return new_chat_id
def get_chats_for_user(login_id: str) -> List[Dict[str, Any]]:
    """
    Fetch every chat owned by ``login_id``, most recently updated first.

    Each row is returned as a dict keyed by column name. Returns [] when the
    DB is unavailable or the query fails (the failure is logged).
    """
    if _pool is None:
        return []
    query = """
        SELECT chat_id, login_id, name, chat_mode,
               created_session_id, last_updated_session_id,
               created_at, updated_at
        FROM chats
        WHERE login_id = %s
        ORDER BY updated_at DESC;
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(query, (login_id,))
            column_names = [column[0] for column in cur.description]
            records = cur.fetchall()
        return [dict(zip(column_names, record)) for record in records]
    except Exception as exc:
        logger.error("[db] get_chats_for_user failed: %s", exc)
        return []
def rename_chat(
    *,
    chat_id: str,
    name: str,
    session_id: Optional[str] = None,
) -> None:
    """
    Set a chat's name and bump its ``updated_at``.

    When ``session_id`` is given it becomes the chat's last-updated session;
    when it is None the stored value is kept (COALESCE). Does nothing when the
    DB is unavailable; failures are logged, never raised.
    """
    if _pool is None:
        return
    statement = """
        UPDATE chats
        SET name = %s,
            last_updated_session_id = COALESCE(%s, last_updated_session_id),
            updated_at = now()
        WHERE chat_id = %s;
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(statement, (name, session_id, chat_id))
    except Exception as exc:
        logger.error("[db] rename_chat failed: %s", exc)
def delete_chat(*, chat_id: str) -> None:
    """
    Delete a chat row. Safe no-op if DB unavailable.

    The interactions.chat_id foreign key is declared ON DELETE CASCADE, so the
    interaction rows that belong to this chat are deleted with it (they are
    not kept with a NULL chat_id). Failures are logged, never raised.
    """
    if _pool is None:
        return
    try:
        with _get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM chats WHERE chat_id = %s;", (chat_id,))
    except Exception as exc:
        logger.error("[db] delete_chat failed: %s", exc)
def insert_interaction(
    *,
    session_id: str,
    chat_id: Optional[str] = None,
    login_id: str,
    user_message: str,
    assistant_reply: str = "",
    learning_mode: str = "",
    total_tokens: int = 0,
    estimated_cost: float = 0.0,
    user_ts=None,
    first_token_ts=None,
    last_token_ts=None,
    suggestions_ts=None,
    doc_references: Optional[List[str]] = None,
    suggested_questions: Optional[List[str]] = None,
    error_flag: bool = False,
    timeout_flag: bool = False,
    run_id: Optional[str] = None,
) -> Optional[str]:
    """
    Insert one interaction row. Returns the new UUID string, or None if DB unavailable.

    The row id is generated here with uuid4. ``turn_index`` is computed inside
    the INSERT as MAX(turn_index) + 1 over the rows with the same chat_id,
    falling back to 1. ``doc_references`` and ``suggested_questions`` are
    stored as empty arrays when None. The ``*_ts`` values are passed to the
    driver unchanged (the columns are TIMESTAMPTZ). Failures are logged and
    reported as None.

    NOTE(review): when ``chat_id`` is None the subquery's ``chat_id = NULL``
    comparison matches no rows, so every chat-less interaction gets
    turn_index 1 — confirm this is intended.
    NOTE(review): MAX + 1 is not protected against two concurrent inserts for
    the same chat; both could receive the same turn_index.
    """
    if _pool is None:
        return None
    interaction_id = str(uuid.uuid4())
    try:
        with _get_conn() as conn:
            with conn.cursor() as cur:
                # Placeholders are positional: the parameter tuple below must
                # stay in the same order as the column list.
                cur.execute(
                    """
                    INSERT INTO interactions (
                        id, session_id, chat_id, login_id, turn_index,
                        user_message, assistant_reply, learning_mode,
                        total_tokens, estimated_cost,
                        user_ts, first_token_ts, last_token_ts, suggestions_ts,
                        doc_references, suggested_questions,
                        error_flag, timeout_flag, run_id
                    ) VALUES (
                        %s, %s, %s, %s,
                        COALESCE((SELECT MAX(turn_index) + 1 FROM interactions WHERE chat_id = %s), 1),
                        %s, %s, %s,
                        %s, %s,
                        %s, %s, %s, %s,
                        %s, %s,
                        %s, %s, %s
                    );
                    """,
                    (
                        interaction_id, session_id, chat_id, login_id,
                        chat_id,  # for the subquery
                        user_message, assistant_reply, learning_mode,
                        total_tokens, estimated_cost,
                        user_ts, first_token_ts, last_token_ts, suggestions_ts,
                        doc_references or [], suggested_questions or [],
                        error_flag, timeout_flag, run_id,
                    ),
                )
        return interaction_id
    except Exception as exc:
        logger.error("[db] insert_interaction failed: %s", exc)
        return None
def get_messages_for_chat(chat_id: str) -> List[Dict[str, Any]]:
    """
    Fetch the interactions of one chat in ascending ``turn_index`` order.

    Each row is returned as a dict keyed by column name; splitting a row into
    a user message and an assistant message is left to the caller. Returns []
    when the DB is unavailable or the query fails (the failure is logged).
    """
    if _pool is None:
        return []
    query = """
        SELECT id, turn_index, user_message, assistant_reply,
               user_ts, first_token_ts, last_token_ts, suggestions_ts,
               doc_references, suggested_questions
        FROM interactions
        WHERE chat_id = %s
        ORDER BY turn_index ASC;
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(query, (chat_id,))
            column_names = [column[0] for column in cur.description]
            records = cur.fetchall()
        return [dict(zip(column_names, record)) for record in records]
    except Exception as exc:
        logger.error("[db] get_messages_for_chat failed: %s", exc)
        return []
def get_user_overview() -> List[Dict[str, Any]]:
    """
    Return one summary row per login_id, ordered by login_id.

    Columns: sessions (distinct session ids), turns (interaction count),
    avg_latency_ms (mean of last_token_ts - user_ts), survey_completed,
    avg_survey_rating (mean of q1..q10 per survey, averaged over surveys),
    first_turn_at, last_turn_at and duration_minutes.

    Returns [] when the DB is unavailable or the query fails (the failure is
    logged).
    """
    if _pool is None:
        return []
    # The LEFT JOIN repeats every interaction once per survey row of the same
    # login, so turns must be counted with DISTINCT to stay correct for users
    # who submitted more than one survey.
    # Both ROUND(x, n) calls cast to numeric first: the two-argument form only
    # exists for numeric, and EXTRACT returns double precision before
    # PostgreSQL 14.
    query = """
        SELECT
            i.login_id,
            COUNT(DISTINCT i.session_id) AS sessions,
            COUNT(DISTINCT i.id) AS turns,
            ROUND(AVG(
                EXTRACT(EPOCH FROM (i.last_token_ts - i.user_ts)) * 1000
            )::numeric, 0) AS avg_latency_ms,
            (MAX(sr.submitted_at) IS NOT NULL) AS survey_completed,
            ROUND(AVG(
                CASE WHEN sr.q1 IS NOT NULL THEN
                    (sr.q1 + sr.q2 + sr.q3 + sr.q4 + sr.q5 +
                     sr.q6 + sr.q7 + sr.q8 + sr.q9 + sr.q10)::numeric / 10
                END
            ), 2) AS avg_survey_rating,
            MIN(i.created_at) AS first_turn_at,
            MAX(i.created_at) AS last_turn_at,
            ROUND((EXTRACT(EPOCH FROM (MAX(i.created_at) - MIN(i.created_at))) / 60)::numeric, 1) AS duration_minutes
        FROM interactions i
        LEFT JOIN survey_responses sr ON sr.login_id = i.login_id
        WHERE i.login_id IS NOT NULL
        GROUP BY i.login_id
        ORDER BY i.login_id;
    """
    try:
        with _get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(query)
                cols = [d[0] for d in cur.description]
                return [dict(zip(cols, row)) for row in cur.fetchall()]
    except Exception as exc:
        logger.error("[db] get_user_overview failed: %s", exc)
        return []
def get_interactions_for_user(login_id: str) -> List[Dict[str, Any]]:
    """
    Fetch every interaction logged for ``login_id``, oldest first.

    Rows are ordered by ``created_at`` and returned as dicts keyed by column
    name. Returns [] when the DB is unavailable or the query fails (the
    failure is logged).
    """
    if _pool is None:
        return []
    query = """
        SELECT
            id, session_id, turn_index,
            user_message, assistant_reply,
            learning_mode, total_tokens,
            user_ts, first_token_ts, last_token_ts,
            thumbs_rating, free_text_feedback, created_at
        FROM interactions
        WHERE login_id = %s
        ORDER BY created_at ASC;
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(query, (login_id,))
            column_names = [column[0] for column in cur.description]
            records = cur.fetchall()
        return [dict(zip(column_names, record)) for record in records]
    except Exception as exc:
        logger.error("[db] get_interactions_for_user failed: %s", exc)
        return []
def insert_survey_response(*, login_id: Optional[str] = None, responses: Dict[str, Any]) -> Optional[str]:
    """
    Store one survey submission and return its generated UUID string.

    ``responses`` is read with the keys ``q1``..``q10``, ``q1_feedback``..
    ``q10_feedback`` and ``q11``; missing keys are stored as NULL. Returns
    None when the DB is unavailable or the insert fails (the failure is
    logged).
    """
    if _pool is None:
        return None
    row_id = str(uuid.uuid4())
    # Build the parameters in column order: id, login_id, then each rating
    # followed by its feedback, and finally q11.
    values: List[Any] = [row_id, login_id]
    for number in range(1, 11):
        values.append(responses.get(f"q{number}"))
        values.append(responses.get(f"q{number}_feedback"))
    values.append(responses.get("q11"))
    statement = """
        INSERT INTO survey_responses (
            id, login_id,
            q1, q1_feedback, q2, q2_feedback, q3, q3_feedback,
            q4, q4_feedback, q5, q5_feedback, q6, q6_feedback,
            q7, q7_feedback, q8, q8_feedback, q9, q9_feedback,
            q10, q10_feedback, q11
        ) VALUES (
            %s, %s,
            %s, %s, %s, %s, %s, %s,
            %s, %s, %s, %s, %s, %s,
            %s, %s, %s, %s, %s, %s,
            %s, %s, %s
        )
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(statement, tuple(values))
    except Exception as exc:
        logger.error("[db] insert_survey_response failed: %s", exc)
        return None
    return row_id
def get_history_for_session(session_id: str, max_turns: int = 10) -> List[tuple]:
    """
    Return the last `max_turns` (user_message, assistant_reply) pairs for a session,
    ordered by turn_index ASC. Returns [] if DB unavailable or no rows found.

    The newest rows are picked first (turn_index DESC, LIMIT) and then put
    back in ascending order, so a long session yields its most recent turns
    rather than its first ones. ``created_at`` breaks ties between rows that
    share a turn_index. A non-positive ``max_turns`` returns [] without
    querying. Query failures are logged and reported as [].
    """
    if _pool is None or max_turns <= 0:
        return []
    try:
        with _get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT user_message, assistant_reply
                    FROM (
                        SELECT user_message, assistant_reply, turn_index, created_at
                        FROM interactions
                        WHERE session_id = %s
                        ORDER BY turn_index DESC, created_at DESC
                        LIMIT %s
                    ) AS recent
                    ORDER BY turn_index ASC, created_at ASC
                    """,
                    (session_id, max_turns),
                )
                return list(cur.fetchall())  # List of (user_message, assistant_reply)
    except Exception as exc:
        logger.error("[db] get_history_for_session failed: %s", exc)
        return []
def update_interaction_feedback(
    *,
    run_id: str,
    thumbs_rating: str,
    free_text_feedback: str = "",
) -> None:
    """
    Store a thumbs rating and free-text feedback on the rows matching ``run_id``.

    Does nothing when the DB is unavailable or ``run_id`` is empty; an unknown
    run_id simply updates no rows. Failures are logged, never raised.
    """
    if _pool is None:
        return
    if not run_id:
        return
    statement = """
        UPDATE interactions
        SET thumbs_rating = %s,
            free_text_feedback = %s
        WHERE run_id = %s;
    """
    try:
        with _get_conn() as conn, conn.cursor() as cur:
            cur.execute(statement, (thumbs_rating, free_text_feedback, run_id))
    except Exception as exc:
        logger.error("[db] update_interaction_feedback failed: %s", exc)