| """ |
Retrieval primitives: vector (cosine via pgvector), BM25-style full text (Postgres ts_rank), and literal-substring lexical matching.
| |
These are the retrievers that get fused in api/hybrid.py. Each returns
a ranked list of `Hit` records pulled from the `chunks_with_source` view,
best-first (index 0 = best hit). All of them share the same row shape so the
fusion layer doesn't need to special-case any one retriever.
| |
| The embedding model is cached as a module-level singleton so repeated calls |
| in the same process reuse the loaded weights. Cold-load is ~3-5 s on CPU. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import re |
| from dataclasses import dataclass |
| from typing import Sequence |
|
|
| import psycopg |
| from pgvector.psycopg import register_vector |
| from sentence_transformers import SentenceTransformer |
|
|
# Sentence-transformers model loaded when EMBEDDING_MODEL is not provided.
DEFAULT_EMBEDDING_MODEL = "pritamdeka/S-PubMedBert-MS-MARCO"

# Runs of ASCII letters/digits; everything else (punctuation, whitespace,
# non-ASCII characters) acts as a separator. Because tokens can only contain
# [A-Za-z0-9], they never carry tsquery operators or LIKE wildcards.
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")

# Generic clinical vocabulary excluded from the "long alphabetic token" rule
# in rare_query_tokens(). That rule only checks membership for tokens longer
# than 7 characters, so entries of 7 characters or fewer are never consulted.
_GENERIC_LONG_TOKENS = frozenset(
    (
        "adolescent", "adolescents", "anxiety", "augmentation", "clinical",
        "criteria", "currently", "depression", "depressive", "diagnosis",
        "disease", "disorder", "disorders", "generalized", "history",
        "medication", "medications", "moderate", "negative", "patient",
        "patients", "positive", "psychiatric", "psychological", "recommend",
        "recommended", "research", "results", "severe", "studies", "study",
        "symptoms", "therapy", "treatment",
    )
)
|
|
# Process-wide model cache filled on the first get_embedding_model() call;
# stays None until then.
_embedding_model: SentenceTransformer | None = None
|
|
|
|
@dataclass(frozen=True)
class Hit:
    """One ranked retrieval result, shared by every retriever in this module.

    Immutable (``frozen=True``). The fields mirror the column order of the
    SELECT lists built in this module; ``_row_to_hit`` maps result-row columns
    0-7 onto them. The original note says these fields are also meant to be
    sufficient for a later rerank step.
    """

    # Columns read from the `chunks_with_source` view.
    chunk_id: int
    document_id: int
    source_type: str
    source_uri: str | None
    section: str | None
    title: str | None
    chunk_text: str
    # Retriever-specific relevance, higher = better in every retriever here
    # (1 - cosine distance, ts_rank, or summed matched-token length). The
    # scales differ, so scores are not comparable across retrievers.
    score: float
|
|
|
|
def get_embedding_model() -> SentenceTransformer:
    """Return the process-wide embedding model, loading it on first use.

    The model name comes from the ``EMBEDDING_MODEL`` environment variable,
    falling back to ``DEFAULT_EMBEDDING_MODEL``. An unset, empty, or
    whitespace-only value is treated as "use the default". This prevents a
    blank ``EMBEDDING_MODEL=`` line in an env file from handing an empty name
    to ``SentenceTransformer``.

    The variable is read only on the first call. Later calls return the
    cached instance regardless of the environment.
    """
    global _embedding_model
    if _embedding_model is None:
        configured = os.environ.get("EMBEDDING_MODEL", "").strip()
        name = configured or DEFAULT_EMBEDDING_MODEL
        _embedding_model = SentenceTransformer(name)
    return _embedding_model
|
|
|
|
def retrieve_vector(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k by cosine similarity against the chunk embeddings.

    Uses the `<=>` cosine-distance operator, intended to be served by an HNSW
    index. The returned score is `1 - distance`, so higher = better, matching
    the direction the fusion layer expects.

    pgvector's HNSW scan yields at most `hnsw.ef_search` candidates (40 by
    default), which would silently cap results below the default k=50. The
    setting is therefore raised to at least `k` inside a transaction, clamped
    to pgvector's maximum of 1000. A larger value already configured is never
    lowered.

    Args:
        conn: open psycopg connection; pgvector adapters are registered on it.
        query: natural-language query text to embed.
        k: maximum number of hits to return.
        source_types: optional whitelist applied to `source_type`.

    Returns:
        Up to `k` hits, best first.
    """
    register_vector(conn)
    embedding = get_embedding_model().encode(query, normalize_embeddings=True)
    sql, params = _build_vector_sql(embedding, k, source_types)
    # set_config(..., true) is transaction-local. The explicit transaction
    # keeps it in effect for the search below even in autocommit mode.
    # Inside a caller's open transaction, psycopg uses a savepoint, and the
    # setting lasts until that outer transaction ends. current_setting(...,
    # true) returns NULL before the pgvector library is loaded in this
    # session, hence the COALESCE to pgvector's default of 40.
    raise_ef_search = (
        "SELECT set_config('hnsw.ef_search', "
        "LEAST(GREATEST(COALESCE(NULLIF(current_setting('hnsw.ef_search', true), '')::int, 40), "
        "%s::int), 1000)::text, true)"
    )
    with conn.transaction(), conn.cursor() as cur:
        cur.execute(raise_ef_search, (k,))
        cur.execute(sql, params)
        return [_row_to_hit(row) for row in cur.fetchall()]
|
|
|
|
def retrieve_bm25(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k chunks ranked by Postgres `ts_rank` over the `tsv` column.

    The query is reduced to alphanumeric tokens that are OR-ed together,
    rather than AND-ed as `plainto_tsquery` would do. Requiring every token of
    a natural-language clinical question (e.g. "sertraline 50mg for MDD") to
    co-occur in one chunk rarely succeeds. OR instead lets any chunk with
    token overlap through, and RRF ranks them afterwards. Limiting tokens to
    `[A-Za-z0-9]` also keeps raw user punctuation away from the strict
    `to_tsquery` parser.

    Returns an empty list, without querying the database, when the query
    yields no usable tokens.
    """
    or_query = _to_or_tsquery(query)
    if not or_query:
        return []
    statement, bind_params = _build_bm25_sql(or_query, k, source_types)
    with conn.cursor() as cursor:
        cursor.execute(statement, bind_params)
        rows = cursor.fetchall()
    return [_row_to_hit(row) for row in rows]
|
|
|
|
def _to_or_tsquery(query: str) -> str:
    """Build an OR-joined `to_tsquery` string from *query*.

    Keeps the distinct, lowercased alphanumeric tokens that are at least two
    characters long. They are sorted so identical queries always produce
    identical SQL parameters. Returns "" when nothing survives.
    """
    unique_terms: set[str] = set()
    for match in _TOKEN_RE.findall(query):
        if len(match) >= 2:
            unique_terms.add(match.lower())
    ordered = sorted(unique_terms)
    return " | ".join(ordered)
|
|
|
|
def retrieve_lexical(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k by literal-substring matching on rare query tokens.

    This is the third RRF input, alongside vector and BM25. It targets the
    failure mode where a chunk literally contains a rare clinical entity
    (drug name, ICD code, acronym) but the surrounding context buries it for
    both the dense and the `ts_rank` retrievers.

    Score is the sum of matched-token lengths. This gives longer, more
    specific tokens proportionally more weight than short noisy ones like
    "50mg".

    Returns [] without querying the database when no query token passes the
    rarity heuristic; the other two retrievers handle that case fine.
    """
    rare = rare_query_tokens(query)
    if not rare:
        return []
    sql, params = _build_lexical_sql(rare, k, source_types)
    with conn.cursor() as cur:
        cur.execute(sql, params)
        return [_row_to_hit(row) for row in cur.fetchall()]


def _build_lexical_sql(
    rare: Sequence[str], k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Return ``(sql, params)`` for the literal-substring query of `retrieve_lexical`.

    Each token contributes one `ILIKE` test to the score (weighted by its
    length) and one to the WHERE disjunction, so the pattern list is bound
    twice, in placeholder order. Tokens from `rare_query_tokens` only contain
    [A-Za-z0-9], so they cannot introduce LIKE wildcards (`%`, `_`) or escape
    characters. Other callers must pre-escape their tokens.
    """
    patterns = [f"%{t}%" for t in rare]
    # len(t) is a Python int, so inlining it into the SQL text is safe.
    score_expr = " + ".join(
        f"(CASE WHEN chunk_text ILIKE %s THEN {len(t)} ELSE 0 END)" for t in rare
    )
    where_any = " OR ".join("chunk_text ILIKE %s" for _ in rare)
    # Same source filter helper the vector/BM25 builders use.
    src_filter, src_params = _source_filter(source_types, leading_where=False)
    src_clause = f" AND {src_filter}" if src_filter else ""
    sql = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        " title, chunk_text, "
        f" ({score_expr})::float AS score "
        "FROM chunks_with_source "
        f"WHERE ({where_any}){src_clause} "
        "ORDER BY score DESC, chunk_id ASC "
        "LIMIT %s"
    )
    return sql, (*patterns, *patterns, *src_params, k)
|
|
|
|
def rare_query_tokens(query: str) -> list[str]:
    """Return the distinct, lowercased query tokens worth matching literally.

    A token (a run of ASCII letters/digits) qualifies when any of these hold:

    - it has no digits, is longer than 7 characters, and is not in
      ``_GENERIC_LONG_TOKENS`` (drug names such as sertraline, paroxetine);
    - it has no digits, is entirely uppercase, and is at least 3 characters
      long (acronyms such as OCD, SSRI, MDD, TRD);
    - it mixes letters and digits and is at least 3 characters long
      (ICD-style codes such as F41 or 6A20, and also doses like 50mg).

    Results keep first-qualifying-occurrence order. Deduplication is
    case-insensitive, and a spelling that fails the rules does not block a
    later spelling that passes (``"mdd MDD"`` yields ``["mdd"]``).
    """
    kept: dict[str, None] = {}
    for token in _TOKEN_RE.findall(query):
        lowered = token.lower()
        if lowered in kept:
            continue
        long_enough = len(token) >= 3
        if any(ch.isdigit() for ch in token):
            # Code-ish tokens need at least one letter; pure numbers are ignored.
            qualifies = long_enough and any(ch.isalpha() for ch in token)
        else:
            acronym = long_enough and token.isupper()
            specific_word = len(token) > 7 and lowered not in _GENERIC_LONG_TOKENS
            qualifies = acronym or specific_word
        if qualifies:
            kept[lowered] = None
    return list(kept)
|
|
|
|
def _build_vector_sql(
    embedding, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Return ``(sql, params)`` for the cosine-distance top-k query.

    The query embedding is bound twice: once for the reported score and once
    for the ORDER BY. The optional source-filter parameter therefore sits
    between the two, matching placeholder order in the SQL text.
    """
    filter_sql, filter_params = _source_filter(source_types)
    distance = "embedding <=> %s"
    statement = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        f" title, chunk_text, 1 - ({distance}) AS score "
        "FROM chunks_with_source"
        f"{filter_sql} "
        f"ORDER BY {distance} "
        "LIMIT %s"
    )
    bound = (embedding, *filter_params, embedding, k)
    return statement, bound
|
|
|
|
def _build_bm25_sql(
    ts_query: str, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Return ``(sql, params)`` for the `ts_rank` top-k query.

    The tsquery text is bound three times: in the score column, the match
    predicate, and the ORDER BY. The optional source-filter parameter follows
    the match predicate, mirroring placeholder order in the SQL text.
    """
    tsquery_sql = "to_tsquery('english', %s)"
    extra_filter, filter_params = _source_filter(source_types, leading_where=False)
    conditions = [f"tsv @@ {tsquery_sql}"]
    if extra_filter:
        conditions.append(extra_filter)
    where_sql = "WHERE " + " AND ".join(conditions)
    rank_sql = f"ts_rank(tsv, {tsquery_sql})"
    statement = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        f" title, chunk_text, {rank_sql} AS score "
        "FROM chunks_with_source "
        f"{where_sql} "
        f"ORDER BY {rank_sql} DESC "
        "LIMIT %s"
    )
    return statement, (ts_query, ts_query, *filter_params, ts_query, k)
|
|
|
|
| def _source_filter( |
| source_types: Sequence[str] | None, *, leading_where: bool = True |
| ) -> tuple[str, tuple]: |
| if not source_types: |
| return ("", ()) |
| clause = "source_type = ANY(%s)" |
| return (f" WHERE {clause}" if leading_where else clause, (list(source_types),)) |
|
|
|
|
def _row_to_hit(row) -> Hit:
    """Convert one positional result row into a `Hit`.

    Columns 0-7 are read in the order every SELECT in this module lists them:
    chunk_id, document_id, source_type, source_uri, section, title,
    chunk_text, score. Any further columns are ignored, and the score is
    coerced to float.
    """
    (
        chunk_id,
        document_id,
        source_type,
        source_uri,
        section,
        title,
        chunk_text,
        raw_score,
    ) = (row[i] for i in range(8))
    return Hit(
        chunk_id=chunk_id,
        document_id=document_id,
        source_type=source_type,
        source_uri=source_uri,
        section=section,
        title=title,
        chunk_text=chunk_text,
        score=float(raw_score),
    )
|
|