RAG-PSYCH / api /rag.py
arjun10g's picture
Initial deploy to Hugging Face Spaces
08fc97e
"""
Retrieval primitives: vector (cosine via pgvector) and BM25 (Postgres ts_rank).
These are the two retrievers that get fused in api/hybrid.py. Each returns
a ranked list of `Hit` records pulled from the `chunks_with_source` view,
oldest-rank-first (rank 0 = best). Both share the same row shape so the
fusion layer doesn't need to special-case either one.
The embedding model is cached as a module-level singleton so repeated calls
in the same process reuse the loaded weights. Cold-load is ~3-5 s on CPU.
"""
from __future__ import annotations
import os
import re
from dataclasses import dataclass
from typing import Sequence
import psycopg
from pgvector.psycopg import register_vector
from sentence_transformers import SentenceTransformer
DEFAULT_EMBEDDING_MODEL = "pritamdeka/S-PubMedBert-MS-MARCO"
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")
# Tokens that are "long" by length but too generic to count as rare clinical
# entities — they appear in nearly every clinical/research chunk and would
# drown the lexical retriever in noise.
_GENERIC_LONG_TOKENS = frozenset({
"patient", "patients", "clinical", "disorder", "disorders", "depression",
"depressive", "anxiety", "criteria", "diagnosis", "treatment", "symptoms",
"research", "adolescents", "adolescent", "generalized", "augmentation",
"disease", "therapy", "results", "study", "studies", "moderate", "severe",
"history", "currently", "recommend", "recommended", "negative", "positive",
"psychiatric", "psychological", "medication", "medications",
})
_embedding_model: SentenceTransformer | None = None
@dataclass(frozen=True)
class Hit:
"""One retrieval hit — fields cover both retriever paths and rerank later."""
chunk_id: int
document_id: int
source_type: str
source_uri: str | None
section: str | None
title: str | None
chunk_text: str
score: float # cosine similarity (vector) or ts_rank (bm25)
def get_embedding_model() -> SentenceTransformer:
global _embedding_model
if _embedding_model is None:
name = os.environ.get("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
_embedding_model = SentenceTransformer(name)
return _embedding_model
def retrieve_vector(
conn: psycopg.Connection,
query: str,
k: int = 50,
source_types: Sequence[str] | None = None,
) -> list[Hit]:
"""Top-k by cosine similarity against the chunk embeddings.
Uses the `<=>` cosine-distance operator backed by the HNSW index.
Score returned is `1 - distance` so higher = better, matching the
intuitive direction expected by the fusion layer.
"""
register_vector(conn)
embedding = get_embedding_model().encode(query, normalize_embeddings=True)
sql, params = _build_vector_sql(embedding, k, source_types)
with conn.cursor() as cur:
cur.execute(sql, params)
return [_row_to_hit(row) for row in cur.fetchall()]
def retrieve_bm25(
conn: psycopg.Connection,
query: str,
k: int = 50,
source_types: Sequence[str] | None = None,
) -> list[Hit]:
"""Top-k by Postgres `ts_rank` over the auto-populated `tsv` column.
Tokens are extracted with a strict alphanumeric regex and joined with
OR semantics — `plainto_tsquery`'s implicit AND is too brittle for
natural-language clinical queries (e.g. "sertraline 50mg for MDD"
requires every literal token in one chunk, which usually fails).
OR keeps any token-overlap candidates flowing into RRF, which then
ranks them. The regex also keeps user input safely outside the
`to_tsquery` parser, which is strict about punctuation.
"""
ts_query = _to_or_tsquery(query)
if not ts_query:
return []
sql, params = _build_bm25_sql(ts_query, k, source_types)
with conn.cursor() as cur:
cur.execute(sql, params)
return [_row_to_hit(row) for row in cur.fetchall()]
def _to_or_tsquery(query: str) -> str:
tokens = {t.lower() for t in _TOKEN_RE.findall(query) if len(t) > 1}
return " | ".join(sorted(tokens))
def retrieve_lexical(
conn: psycopg.Connection,
query: str,
k: int = 50,
source_types: Sequence[str] | None = None,
) -> list[Hit]:
"""Top-k by literal-substring matching on rare query tokens.
Third RRF input alongside vector + BM25. Targets the failure mode where
a chunk literally contains a rare clinical entity (drug name, ICD code,
acronym) but the surrounding context buries it for both dense and
`ts_rank` retrievers.
Score = sum of matched-token lengths — gives longer/more-specific
tokens proportionally more weight than short noisy ones like "50mg".
Returns [] when the query has no tokens passing the rarity heuristic
(the other two retrievers handle that case fine).
"""
rare = rare_query_tokens(query)
if not rare:
return []
patterns = [f"%{t}%" for t in rare]
score_expr = " + ".join(
f"(CASE WHEN chunk_text ILIKE %s THEN {len(t)} ELSE 0 END)" for t in rare
)
where_any = " OR ".join("chunk_text ILIKE %s" for _ in rare)
src_clause, src_params = "", ()
if source_types:
src_clause = " AND source_type = ANY(%s)"
src_params = (list(source_types),)
sql = (
"SELECT chunk_id, document_id, source_type, source_uri, section, "
" title, chunk_text, "
f" ({score_expr})::float AS score "
"FROM chunks_with_source "
f"WHERE ({where_any}){src_clause} "
"ORDER BY score DESC, chunk_id ASC "
"LIMIT %s"
)
with conn.cursor() as cur:
cur.execute(sql, (*patterns, *patterns, *src_params, k))
return [_row_to_hit(row) for row in cur.fetchall()]
def rare_query_tokens(query: str) -> list[str]:
"""Extract tokens worth literal-matching: long alphabetic, acronyms, codes.
Three rules combined:
- alphabetic and len > 7, not in the generic-medical stoplist
(catches drug names like sertraline, paroxetine, fluoxetine)
- all-uppercase and len >= 3 (catches acronyms: OCD, SSRI, MDD, TRD)
- mixed letter+digit and len >= 3 (catches ICD codes like F41 / 6A20)
"""
rare: list[str] = []
seen: set[str] = set()
for raw in _TOKEN_RE.findall(query):
low = raw.lower()
if low in seen:
continue
has_digit = any(c.isdigit() for c in raw)
has_alpha = any(c.isalpha() for c in raw)
is_upper = raw.isupper() and len(raw) >= 3 and not has_digit
is_long = len(raw) > 7 and not has_digit and low not in _GENERIC_LONG_TOKENS
is_codeish = has_digit and has_alpha and len(raw) >= 3
if is_long or is_upper or is_codeish:
rare.append(low)
seen.add(low)
return rare
def _build_vector_sql(
embedding, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
where, params_pre = _source_filter(source_types)
sql = (
"SELECT chunk_id, document_id, source_type, source_uri, section, "
" title, chunk_text, 1 - (embedding <=> %s) AS score "
"FROM chunks_with_source"
f"{where} "
"ORDER BY embedding <=> %s "
"LIMIT %s"
)
# Placeholder order: SELECT embedding, optional WHERE source_type array,
# ORDER BY embedding, LIMIT.
return sql, (embedding, *params_pre, embedding, k)
def _build_bm25_sql(
ts_query: str, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
where, params_pre = _source_filter(source_types, leading_where=False)
base_where = "tsv @@ to_tsquery('english', %s)"
full_where = f"WHERE {base_where}" + (f" AND {where}" if where else "")
sql = (
"SELECT chunk_id, document_id, source_type, source_uri, section, "
" title, chunk_text, ts_rank(tsv, to_tsquery('english', %s)) AS score "
"FROM chunks_with_source "
f"{full_where} "
"ORDER BY ts_rank(tsv, to_tsquery('english', %s)) DESC "
"LIMIT %s"
)
return sql, (ts_query, ts_query, *params_pre, ts_query, k)
def _source_filter(
source_types: Sequence[str] | None, *, leading_where: bool = True
) -> tuple[str, tuple]:
if not source_types:
return ("", ())
clause = "source_type = ANY(%s)"
return (f" WHERE {clause}" if leading_where else clause, (list(source_types),))
def _row_to_hit(row) -> Hit:
return Hit(
chunk_id=row[0],
document_id=row[1],
source_type=row[2],
source_uri=row[3],
section=row[4],
title=row[5],
chunk_text=row[6],
score=float(row[7]),
)