# Spjimr / database.py
# shahidshaikh's picture
# Upload 40 files
# a52bae4 verified
# ============================================================================
# database.py -- Supabase PostgreSQL + pgvector persistence layer
# ============================================================================
#
# PURPOSE
# -------
# Single module that owns ALL database interaction for the workbench.
# Every other module (vectorstore, phase2_agent, phase3_themes, etc.)
# imports from here. No other module should import psycopg2 directly.
#
# CONNECTION
# ----------
# Reads SUPABASE_DB_URL from environment (set as HF Space secret).
# Uses Session Pooler URL (IPv4 compatible with HuggingFace Spaces).
#
# TABLES
# ------
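# corpus -- uploaded sentences + MiniLM embeddings (vector 384)
# codebook -- Phase 2 codebook (code_name, definition, ...)
# coded_sentences -- Phase 2 per-sentence codes
# themes -- Phase 3 candidate themes
# theme_reviews -- Phase 4 reviewer verdicts
# chats -- chat exchanges (title, user/bot messages, topics JSON)
# papers -- papers attached to a chat, with embeddings (vector 384)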
#
# DESIGN
# ------
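# + The five workbench tables (corpus, codebook, coded_sentences, themes,
# theme_reviews) have session_id (TEXT) so multiple researchers can share
# one Supabase project without data collision. chats and papers are not
# session-scoped.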
# + create_tables() is idempotent -- safe to call on every startup.
# + All functions return plain Python dicts/lists -- no psycopg2 objects
# leak out of this module.
# + Graceful degradation: if SUPABASE_DB_URL is not set (or the database is
# unreachable), the public functions log the failure and return an empty
# result (0, [] or False) instead of raising. The app keeps running.
# ============================================================================
import os
import json
import logging
from datetime import datetime
from typing import Optional
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------
# Connection
# ----------------------------------------------------------------
# Connection string read once at import time; empty string means "no database".
_DB_URL = os.environ.get("SUPABASE_DB_URL", "")
# Process-wide psycopg2 connection reused by every function in this module.
_conn_cache = None


def _ping(conn) -> bool:
    """Return True if ``conn`` can execute a trivial query, False otherwise."""
    try:
        # The context manager closes the cursor even when execute() raises.
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
        return True
    except Exception:
        return False


def _get_conn():
    """
    Return a live psycopg2 connection (cached, auto-reconnect).

    The cached connection is reused when it answers a ``SELECT 1`` ping.
    If the ping fails, the connection is rolled back and pinged once more,
    because a transaction left in the aborted state rejects every statement
    until it is rolled back. Only when that also fails is the connection
    closed and replaced by a new one.

    New connections are opened with ``autocommit = False``; callers are
    responsible for ``commit()`` / ``rollback()``.

    Raises:
        RuntimeError: if SUPABASE_DB_URL is not set.
        Exception: whatever ``psycopg2.connect`` raises when a new
            connection cannot be opened (the connect is not retried here).
    """
    global _conn_cache
    if not _DB_URL:
        raise RuntimeError(
            "SUPABASE_DB_URL not set. Add it as a Space secret."
        )
    import psycopg2
    import psycopg2.extras

    conn = _conn_cache
    if conn is not None and not conn.closed:
        if _ping(conn):
            return conn
        # Ping failed: either the transaction is aborted or the link is dead.
        try:
            conn.rollback()
        except Exception:
            pass  # connection is unusable; it is replaced below
        else:
            if _ping(conn):
                return conn
        # Release the dead connection explicitly instead of waiting for GC.
        _conn_cache = None
        try:
            conn.close()
        except Exception:
            pass  # best effort; nothing more can be done with it

    fresh = psycopg2.connect(_DB_URL, connect_timeout=30)
    fresh.autocommit = False
    _conn_cache = fresh
    return fresh
def is_available() -> bool:
    """
    Report whether the database can be reached.

    Returns False straight away when no connection URL is configured.
    Otherwise obtains the shared connection and runs ``SELECT 1`` on it;
    any exception is logged as a warning and reported as False.
    """
    if _DB_URL:
        try:
            probe = _get_conn().cursor()
            probe.execute("SELECT 1")
        except Exception as e:
            logger.warning(f"[database] not available: {e}")
        else:
            return True
    return False
# ----------------------------------------------------------------
# Schema bootstrap -- call once on startup
# ----------------------------------------------------------------
# DDL script executed as one multi-statement string by create_tables().
# Every statement uses IF NOT EXISTS, so re-running it against an already
# initialised database creates nothing new.
# It enables the "vector" extension, creates seven tables (corpus, codebook,
# coded_sentences, themes, theme_reviews, chats, papers) and their lookup
# indexes. chats and papers have no session_id column; papers rows reference
# chats(id) and are removed with their chat (ON DELETE CASCADE).
CREATE_TABLES_SQL = """
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS corpus (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL DEFAULT 'default',
L1 TEXT,
L2 TEXT,
L3 TEXT,
L4 TEXT,
sentence_id TEXT,
sentence TEXT NOT NULL,
label TEXT,
embedding vector(384),
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS codebook (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL DEFAULT 'default',
code_name TEXT NOT NULL,
definition TEXT,
provenance TEXT,
sentence_count INT DEFAULT 1,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS coded_sentences (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL DEFAULT 'default',
sentence_idx INT,
sentence TEXT,
ai_code_iter1 TEXT,
ai_code_iter2 TEXT,
ai_code_iter3 TEXT,
human_code_iter1 TEXT,
human_code_iter2 TEXT,
human_code_iter3 TEXT,
final_code TEXT,
orientation TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS themes (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL DEFAULT 'default',
theme_id INT,
candidate_theme_name TEXT,
description TEXT,
rationale TEXT,
member_codes TEXT,
code_count INT,
researcher_theme_name TEXT,
researcher_notes TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS theme_reviews (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL DEFAULT 'default',
theme_id INT,
theme_name TEXT,
member_codes TEXT,
code_count INT,
member_sentence_count INT,
within_cohesion FLOAT,
llm_verdict TEXT,
llm_reasoning TEXT,
llm_action_suggestion TEXT,
researcher_verdict TEXT,
researcher_action_notes TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS chats (
id SERIAL PRIMARY KEY,
title TEXT,
user_message TEXT,
bot_message TEXT,
topics_json JSONB,
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS papers (
id SERIAL PRIMARY KEY,
chat_id INT REFERENCES chats(id) ON DELETE CASCADE,
title TEXT,
abstract TEXT,
doi TEXT,
date_of_publication TEXT,
journal TEXT,
no_of_citations INT,
web_link TEXT,
authors TEXT,
keywords TEXT,
confidence_score FLOAT,
paper_type TEXT,
topic_label TEXT,
embedding vector(384),
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_corpus_session ON corpus(session_id);
CREATE INDEX IF NOT EXISTS idx_codebook_session ON codebook(session_id);
CREATE INDEX IF NOT EXISTS idx_coded_session ON coded_sentences(session_id);
CREATE INDEX IF NOT EXISTS idx_themes_session ON themes(session_id);
CREATE INDEX IF NOT EXISTS idx_reviews_session ON theme_reviews(session_id);
CREATE INDEX IF NOT EXISTS idx_papers_chat ON papers(chat_id);
CREATE INDEX IF NOT EXISTS idx_papers_topic ON papers(topic_label);
"""
def create_tables() -> bool:
    """
    Create all tables if they don't exist. Safe to call on every startup.

    Executes CREATE_TABLES_SQL on the shared connection and commits.

    Returns:
        True when the script ran and was committed, False on any error
        (the error is logged and the transaction rolled back).
    """
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(CREATE_TABLES_SQL)
        conn.commit()
        logger.info("[database] Tables ready.")
        return True
    except Exception as e:
        logger.error(f"[database] create_tables error: {e}")
        # Roll back the connection that actually failed. Going through
        # _get_conn() again would ping the aborted transaction and could
        # hand back a different connection, making the rollback a no-op.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] create_tables rollback failed: {rollback_err}"
                )
        return False
# ----------------------------------------------------------------
# Corpus
# ----------------------------------------------------------------
def save_corpus(rows: list[dict], session_id: str = "default") -> int:
    """
    Save corpus sentences to database.

    Clears the existing corpus for this session first (fresh load). The
    delete and the inserts are committed together, so an error rolls both
    back.

    Args:
        rows: dicts read with the keys L1, L2, L3, L4, sentence_id,
            sentence and label; a missing key is stored as "".
        session_id: session whose corpus is replaced.

    Returns:
        Number of rows saved; 0 when ``rows`` is empty (nothing is touched)
        or when an error occurred (logged).
    """
    if not rows:
        return 0
    conn = None
    try:
        import psycopg2.extras

        # Build the parameter tuples before touching the database so that
        # malformed input fails without opening a transaction.
        params = [
            (
                session_id,
                r.get("L1", ""),
                r.get("L2", ""),
                r.get("L3", ""),
                r.get("L4", ""),
                r.get("sentence_id", ""),
                r.get("sentence", ""),
                r.get("label", ""),
            )
            for r in rows
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute("DELETE FROM corpus WHERE session_id = %s", (session_id,))
            psycopg2.extras.execute_batch(
                cur,
                """INSERT INTO corpus (session_id, L1, L2, L3, L4, sentence_id, sentence, label)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
                params,
            )
        conn.commit()
        return len(rows)
    except Exception as e:
        logger.error(f"[database] save_corpus error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_corpus rollback failed: {rollback_err}"
                )
        return 0
def load_corpus(session_id: str = "default") -> list[dict]:
    """
    Load corpus for a session.

    Returns:
        One dict per corpus row, ordered by row id, with the keys L1, L2,
        L3, L4, sentence_id, sentence and label. An empty list when the
        session has no rows or when an error occurred (logged).
    """
    cols = ["L1", "L2", "L3", "L4", "sentence_id", "sentence", "label"]
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                "SELECT L1, L2, L3, L4, sentence_id, sentence, label "
                "FROM corpus WHERE session_id = %s ORDER BY id",
                (session_id,),
            )
            fetched = cur.fetchall()
        return [dict(zip(cols, row)) for row in fetched]
    except Exception as e:
        logger.error(f"[database] load_corpus error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] load_corpus rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Corpus embeddings (pgvector)
# ----------------------------------------------------------------
def save_embeddings(sentence_embeddings: list[tuple[str, list[float]]], session_id: str = "default") -> int:
    """
    Save sentence embeddings to corpus table.

    Each embedding is written to every corpus row of the session whose
    ``sentence`` text equals the given sentence.

    Args:
        sentence_embeddings: list of (sentence_text, embedding) pairs. The
            embedding may be any iterable of numbers; each component is
            converted with ``float()`` before serialization.
        session_id: session whose corpus rows are updated.

    Returns:
        Number of pairs submitted (not the number of rows the UPDATEs
        matched); 0 for empty input or when an error occurred (logged).
    """
    if not sentence_embeddings:
        return 0
    conn = None
    try:
        import psycopg2.extras

        # Serialize each vector as a JSON array literal, which is then cast
        # with ::vector. float() lets non-builtin numeric scalars through
        # json.dumps.
        params = [
            (json.dumps([float(x) for x in emb]), session_id, sent)
            for sent, emb in sentence_embeddings
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            psycopg2.extras.execute_batch(
                cur,
                "UPDATE corpus SET embedding = %s::vector WHERE session_id = %s AND sentence = %s",
                params,
            )
        conn.commit()
        return len(sentence_embeddings)
    except Exception as e:
        logger.error(f"[database] save_embeddings error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_embeddings rollback failed: {rollback_err}"
                )
        return 0
def similarity_search(query_embedding: list[float], session_id: str = "default", top_k: int = 5) -> list[dict]:
    """
    Find top_k most similar sentences using pgvector cosine similarity.

    Only corpus rows of ``session_id`` that have an embedding are searched.

    Args:
        query_embedding: query vector; components are converted with
            ``float()`` before serialization.
        session_id: session whose corpus is searched.
        top_k: maximum number of results.

    Returns:
        List of dicts with ``sentence``, ``label`` and ``similarity``
        (1 minus the ``<=>`` distance), closest first. An empty list when
        the query vector is empty, ``top_k`` is not positive, nothing
        matches, or an error occurred (logged).
    """
    # An empty vector or a non-positive LIMIT can only yield an empty result
    # or a server error, so skip the round trip entirely.
    if top_k <= 0 or len(query_embedding) == 0:
        return []
    conn = None
    try:
        # Serialize once; the literal is used in both SELECT and ORDER BY.
        vector_literal = json.dumps([float(x) for x in query_embedding])
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                """SELECT sentence, label,
                1 - (embedding <=> %s::vector) AS similarity
                FROM corpus
                WHERE session_id = %s AND embedding IS NOT NULL
                ORDER BY embedding <=> %s::vector
                LIMIT %s""",
                (vector_literal, session_id, vector_literal, top_k),
            )
            fetched = cur.fetchall()
        return [
            {"sentence": row[0], "label": row[1], "similarity": float(row[2])}
            for row in fetched
        ]
    except Exception as e:
        logger.error(f"[database] similarity_search error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] similarity_search rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Phase 2 -- Codebook
# ----------------------------------------------------------------
def save_codebook(codebook_rows: list[dict], session_id: str = "default") -> int:
    """
    Save full codebook (replaces existing for this session).

    The session's codebook rows are deleted and the new rows inserted in one
    transaction. An empty ``codebook_rows`` therefore still clears the
    session's codebook.

    Args:
        codebook_rows: dicts read with the keys code_name, definition,
            provenance (missing -> "") and sentence_count (missing -> 1,
            converted with ``int()``).
        session_id: session whose codebook is replaced.

    Returns:
        Number of rows saved, or 0 when an error occurred (logged).
    """
    conn = None
    try:
        import psycopg2.extras

        # Build the parameter tuples first so a bad sentence_count fails
        # before any statement is sent.
        params = [
            (
                session_id,
                r.get("code_name", ""),
                r.get("definition", ""),
                r.get("provenance", ""),
                int(r.get("sentence_count", 1)),
            )
            for r in codebook_rows
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute("DELETE FROM codebook WHERE session_id = %s", (session_id,))
            psycopg2.extras.execute_batch(
                cur,
                """INSERT INTO codebook (session_id, code_name, definition, provenance, sentence_count)
                VALUES (%s, %s, %s, %s, %s)""",
                params,
            )
        conn.commit()
        return len(codebook_rows)
    except Exception as e:
        logger.error(f"[database] save_codebook error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_codebook rollback failed: {rollback_err}"
                )
        return 0
def load_codebook(session_id: str = "default") -> list[dict]:
    """
    Load codebook for a session.

    Returns:
        One dict per codebook row, ordered by row id, with the keys
        code_name, definition, provenance and sentence_count. An empty list
        when the session has no rows or when an error occurred (logged).
    """
    cols = ["code_name", "definition", "provenance", "sentence_count"]
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                "SELECT code_name, definition, provenance, sentence_count "
                "FROM codebook WHERE session_id = %s ORDER BY id",
                (session_id,),
            )
            fetched = cur.fetchall()
        return [dict(zip(cols, row)) for row in fetched]
    except Exception as e:
        logger.error(f"[database] load_codebook error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] load_codebook rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Phase 2 -- Coded sentences
# ----------------------------------------------------------------
def save_coded_sentences(coded_rows: list[dict], session_id: str = "default") -> int:
    """
    Save Phase 2 coded sentences (replaces existing for this session).

    The session's rows are deleted and the new rows inserted in one
    transaction. ``sentence_idx`` is the position of each row in
    ``coded_rows``.

    Args:
        coded_rows: dicts read with the keys sentence, ai_code_iter1..3,
            human_code_iter1..3 and final_code (missing -> "") and
            orientation (missing -> "semantic").
        session_id: session whose coded sentences are replaced.

    Returns:
        Number of rows saved, or 0 when an error occurred (logged).
    """
    conn = None
    try:
        import psycopg2.extras

        params = [
            (
                session_id,
                position,
                r.get("sentence", ""),
                r.get("ai_code_iter1", ""),
                r.get("ai_code_iter2", ""),
                r.get("ai_code_iter3", ""),
                r.get("human_code_iter1", ""),
                r.get("human_code_iter2", ""),
                r.get("human_code_iter3", ""),
                r.get("final_code", ""),
                r.get("orientation", "semantic"),
            )
            for position, r in enumerate(coded_rows)
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute("DELETE FROM coded_sentences WHERE session_id = %s", (session_id,))
            psycopg2.extras.execute_batch(
                cur,
                """INSERT INTO coded_sentences
                (session_id, sentence_idx, sentence,
                ai_code_iter1, ai_code_iter2, ai_code_iter3,
                human_code_iter1, human_code_iter2, human_code_iter3,
                final_code, orientation)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
                params,
            )
        conn.commit()
        return len(coded_rows)
    except Exception as e:
        logger.error(f"[database] save_coded_sentences error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_coded_sentences rollback failed: {rollback_err}"
                )
        return 0
def load_coded_sentences(session_id: str = "default") -> list[dict]:
    """
    Load Phase 2 coded sentences for a session.

    Returns:
        One dict per row, ordered by sentence_idx, with the keys
        sentence_idx, sentence, ai_code_iter1..3, human_code_iter1..3,
        final_code and orientation. An empty list when the session has no
        rows or when an error occurred (logged).
    """
    cols = [
        "sentence_idx", "sentence",
        "ai_code_iter1", "ai_code_iter2", "ai_code_iter3",
        "human_code_iter1", "human_code_iter2", "human_code_iter3",
        "final_code", "orientation",
    ]
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                """SELECT sentence_idx, sentence,
                ai_code_iter1, ai_code_iter2, ai_code_iter3,
                human_code_iter1, human_code_iter2, human_code_iter3,
                final_code, orientation
                FROM coded_sentences WHERE session_id = %s ORDER BY sentence_idx""",
                (session_id,),
            )
            fetched = cur.fetchall()
        return [dict(zip(cols, row)) for row in fetched]
    except Exception as e:
        logger.error(f"[database] load_coded_sentences error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] load_coded_sentences rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Phase 3 -- Themes
# ----------------------------------------------------------------
def save_themes(themes_rows: list[dict], session_id: str = "default") -> int:
    """
    Save Phase 3 themes (replaces existing for this session).

    The session's theme rows are deleted and the new rows inserted in one
    transaction.

    Args:
        themes_rows: dicts read with the keys theme_id and code_count
            (missing -> 0, converted with ``int()``) and
            candidate_theme_name, description, rationale, member_codes,
            researcher_theme_name, researcher_notes (missing -> "").
        session_id: session whose themes are replaced.

    Returns:
        Number of rows saved, or 0 when an error occurred (logged).
    """
    conn = None
    try:
        import psycopg2.extras

        # Build the parameter tuples first so a bad numeric field fails
        # before any statement is sent.
        params = [
            (
                session_id,
                int(r.get("theme_id", 0)),
                r.get("candidate_theme_name", ""),
                r.get("description", ""),
                r.get("rationale", ""),
                r.get("member_codes", ""),
                int(r.get("code_count", 0)),
                r.get("researcher_theme_name", ""),
                r.get("researcher_notes", ""),
            )
            for r in themes_rows
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute("DELETE FROM themes WHERE session_id = %s", (session_id,))
            psycopg2.extras.execute_batch(
                cur,
                """INSERT INTO themes
                (session_id, theme_id, candidate_theme_name, description,
                rationale, member_codes, code_count,
                researcher_theme_name, researcher_notes)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
                params,
            )
        conn.commit()
        return len(themes_rows)
    except Exception as e:
        logger.error(f"[database] save_themes error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_themes rollback failed: {rollback_err}"
                )
        return 0
def load_themes(session_id: str = "default") -> list[dict]:
    """
    Load Phase 3 themes for a session.

    Returns:
        One dict per theme row, ordered by theme_id, with the keys theme_id,
        candidate_theme_name, description, rationale, member_codes,
        code_count, researcher_theme_name and researcher_notes. An empty
        list when the session has no rows or when an error occurred (logged).
    """
    cols = [
        "theme_id", "candidate_theme_name", "description", "rationale",
        "member_codes", "code_count", "researcher_theme_name", "researcher_notes",
    ]
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                """SELECT theme_id, candidate_theme_name, description, rationale,
                member_codes, code_count, researcher_theme_name, researcher_notes
                FROM themes WHERE session_id = %s ORDER BY theme_id""",
                (session_id,),
            )
            fetched = cur.fetchall()
        return [dict(zip(cols, row)) for row in fetched]
    except Exception as e:
        logger.error(f"[database] load_themes error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] load_themes rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Phase 4 -- Theme reviews
# ----------------------------------------------------------------
def save_theme_reviews(review_rows: list[dict], session_id: str = "default") -> int:
    """
    Save Phase 4 theme reviews (replaces existing for this session).

    The session's review rows are deleted and the new rows inserted in one
    transaction.

    Args:
        review_rows: dicts read with the keys theme_id, code_count and
            member_sentence_count (missing -> 0, converted with ``int()``),
            within_cohesion (missing -> 0.0, converted with ``float()``) and
            theme_name, member_codes, llm_verdict, llm_reasoning,
            llm_action_suggestion, researcher_verdict,
            researcher_action_notes (missing -> "").
        session_id: session whose reviews are replaced.

    Returns:
        Number of rows saved, or 0 when an error occurred (logged).
    """
    conn = None
    try:
        import psycopg2.extras

        # Build the parameter tuples first so a bad numeric field fails
        # before any statement is sent.
        params = [
            (
                session_id,
                int(r.get("theme_id", 0)),
                r.get("theme_name", ""),
                r.get("member_codes", ""),
                int(r.get("code_count", 0)),
                int(r.get("member_sentence_count", 0)),
                float(r.get("within_cohesion", 0.0)),
                r.get("llm_verdict", ""),
                r.get("llm_reasoning", ""),
                r.get("llm_action_suggestion", ""),
                r.get("researcher_verdict", ""),
                r.get("researcher_action_notes", ""),
            )
            for r in review_rows
        ]
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute("DELETE FROM theme_reviews WHERE session_id = %s", (session_id,))
            psycopg2.extras.execute_batch(
                cur,
                """INSERT INTO theme_reviews
                (session_id, theme_id, theme_name, member_codes, code_count,
                member_sentence_count, within_cohesion,
                llm_verdict, llm_reasoning, llm_action_suggestion,
                researcher_verdict, researcher_action_notes)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
                params,
            )
        conn.commit()
        return len(review_rows)
    except Exception as e:
        logger.error(f"[database] save_theme_reviews error: {e}")
        # Roll back the connection that failed rather than asking
        # _get_conn() for one, which could return a different connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] save_theme_reviews rollback failed: {rollback_err}"
                )
        return 0
def load_theme_reviews(session_id: str = "default") -> list[dict]:
    """
    Load Phase 4 theme reviews for a session.

    Returns:
        One dict per review row, ordered by theme_id, with the keys
        theme_id, theme_name, member_codes, code_count,
        member_sentence_count, within_cohesion, llm_verdict, llm_reasoning,
        llm_action_suggestion, researcher_verdict and
        researcher_action_notes. An empty list when the session has no rows
        or when an error occurred (logged).
    """
    cols = [
        "theme_id", "theme_name", "member_codes", "code_count",
        "member_sentence_count", "within_cohesion",
        "llm_verdict", "llm_reasoning", "llm_action_suggestion",
        "researcher_verdict", "researcher_action_notes",
    ]
    conn = None
    try:
        conn = _get_conn()
        with conn.cursor() as cur:
            cur.execute(
                """SELECT theme_id, theme_name, member_codes, code_count,
                member_sentence_count, within_cohesion,
                llm_verdict, llm_reasoning, llm_action_suggestion,
                researcher_verdict, researcher_action_notes
                FROM theme_reviews WHERE session_id = %s ORDER BY theme_id""",
                (session_id,),
            )
            fetched = cur.fetchall()
        return [dict(zip(cols, row)) for row in fetched]
    except Exception as e:
        logger.error(f"[database] load_theme_reviews error: {e}")
        # A failed statement leaves the shared connection in an aborted
        # transaction; clear it so the next caller can reuse the connection.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                logger.warning(
                    f"[database] load_theme_reviews rollback failed: {rollback_err}"
                )
        return []
# ----------------------------------------------------------------
# Startup check
# ----------------------------------------------------------------
def startup_check() -> dict:
    """
    Probe the database at app startup and summarise the outcome for the UI.

    The returned dict has three keys: ``db_available`` (result of
    is_available()), ``tables_created`` (result of create_tables(), which is
    only attempted when the database is available) and ``error`` (text of
    any exception raised along the way, otherwise None).
    """
    report = dict(db_available=False, tables_created=False, error=None)
    try:
        reachable = is_available()
        report["db_available"] = reachable
        if reachable:
            report["tables_created"] = create_tables()
    except Exception as exc:
        report["error"] = str(exc)
    return report