File size: 8,605 Bytes
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
Retrieval primitives: vector (cosine via pgvector), full-text (Postgres
ts_rank, labelled "BM25" here although ts_rank is not Okapi BM25), and
lexical (literal substring match on rare query tokens).

These are the three retrievers that get fused in api/hybrid.py. Each returns
a ranked list of `Hit` records pulled from the `chunks_with_source` view,
best-first (rank 0 = best). All three share the same row shape so the
fusion layer doesn't need to special-case any of them.

The embedding model is cached as a module-level singleton so repeated calls
in the same process reuse the loaded weights. Cold-load is ~3-5 s on CPU.
"""

from __future__ import annotations

import os
import re
from dataclasses import dataclass
from typing import Sequence

import psycopg
from pgvector.psycopg import register_vector
from sentence_transformers import SentenceTransformer

# Sentence-transformers model loaded by get_embedding_model() when the
# EMBEDDING_MODEL environment variable is not set.
DEFAULT_EMBEDDING_MODEL = "pritamdeka/S-PubMedBert-MS-MARCO"
# ASCII alphanumeric runs; the shared tokenizer for the BM25 and lexical
# paths. It admits no punctuation and no underscore, so its tokens are safe
# to splice into a `to_tsquery` expression or an ILIKE pattern unescaped.
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")

# Tokens that are "long" by length but too generic to count as rare clinical
# entities — they appear in nearly every clinical/research chunk and would
# drown the lexical retriever in noise.
# Only the long-token rule in rare_query_tokens() consults this set, and that
# rule fires only for tokens longer than 7 characters. Entries of 7 or fewer
# characters (e.g. "patient", "therapy", "study") therefore have no effect.
# They are harmless but could be pruned.
_GENERIC_LONG_TOKENS = frozenset({
    "patient", "patients", "clinical", "disorder", "disorders", "depression",
    "depressive", "anxiety", "criteria", "diagnosis", "treatment", "symptoms",
    "research", "adolescents", "adolescent", "generalized", "augmentation",
    "disease", "therapy", "results", "study", "studies", "moderate", "severe",
    "history", "currently", "recommend", "recommended", "negative", "positive",
    "psychiatric", "psychological", "medication", "medications",
})

# Process-wide model cache: None until get_embedding_model() first loads it.
_embedding_model: SentenceTransformer | None = None


@dataclass(frozen=True)
class Hit:
    """One ranked row from the `chunks_with_source` view.

    Shared by every retriever in this module so the fusion layer sees a
    single record shape. Instances are immutable (``frozen=True``). Field
    names mirror the columns selected by the retrievers' SQL.
    """
    chunk_id: int
    document_id: int
    source_type: str
    source_uri: str | None
    section: str | None
    title: str | None
    chunk_text: str
    # Higher is better for every retriever: cosine similarity (vector),
    # ts_rank (bm25), or sum of matched rare-token lengths (lexical).
    score: float


def get_embedding_model() -> SentenceTransformer:
    """Return the process-wide embedding model, loading it on first call.

    The model name comes from the ``EMBEDDING_MODEL`` environment variable,
    falling back to ``DEFAULT_EMBEDDING_MODEL``. The loaded instance is kept
    in the module-level ``_embedding_model`` and returned as-is on every
    later call; the environment variable is not re-read. No lock is taken,
    so concurrent first calls from several threads may each construct a
    model.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    model_name = os.environ.get("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
    _embedding_model = SentenceTransformer(model_name)
    return _embedding_model


def retrieve_vector(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k chunks by cosine similarity between query and chunk embeddings.

    The query is embedded with the cached model (normalised output) and
    compared with pgvector's `<=>` cosine-distance operator. The ORDER BY on
    raw distance is the shape pgvector's approximate indexes (e.g. HNSW) can
    serve. Each hit's score is `1 - distance`, so higher means more similar.
    That is the same direction the BM25 and lexical retrievers use.

    Args:
        conn: open psycopg connection. `register_vector` is applied to it on
            every call so the pgvector adapters are in place.
        query: raw user query text.
        k: maximum number of hits requested.
        source_types: optional whitelist of `source_type` values; falsy means
            no filter.

    Returns:
        Hits ordered nearest-first.

    NOTE(review): pgvector's docs say an HNSW scan yields at most
    `hnsw.ef_search` rows (40 by default). With the default k=50 this may
    return fewer than k hits unless the session raises that setting. Confirm
    how the deployment configures it.
    """
    register_vector(conn)
    model = get_embedding_model()
    query_vec = model.encode(query, normalize_embeddings=True)
    statement, bind_params = _build_vector_sql(query_vec, k, source_types)
    with conn.cursor() as cur:
        cur.execute(statement, bind_params)
        rows = cur.fetchall()
    return [_row_to_hit(r) for r in rows]


def retrieve_bm25(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k chunks by Postgres `ts_rank` over the `tsv` column.

    The query is reduced to strict alphanumeric tokens (via `_to_or_tsquery`)
    that are OR-ed together rather than AND-ed.

    `plainto_tsquery`'s implicit AND is too brittle for natural-language
    clinical queries. For example, "sertraline 50mg for MDD" would need every
    literal token in one chunk, which usually fails. OR lets any token
    overlap through, and rank fusion orders the candidates afterwards.

    Restricting tokens to [A-Za-z0-9] also keeps raw user punctuation away
    from the strict `to_tsquery` parser.

    Args:
        conn: open psycopg connection.
        query: raw user query text.
        k: maximum number of hits returned.
        source_types: optional whitelist of `source_type` values; falsy means
            no filter.

    Returns:
        Hits ordered best-first by ts_rank; [] if the query yields no usable
        tokens (no SQL is issued in that case).
    """
    ts_query = _to_or_tsquery(query)
    if ts_query:
        statement, bind_params = _build_bm25_sql(ts_query, k, source_types)
        with conn.cursor() as cur:
            cur.execute(statement, bind_params)
            rows = cur.fetchall()
        return list(map(_row_to_hit, rows))
    return []


def _to_or_tsquery(query: str) -> str:
    """Build an OR-joined `to_tsquery` string from the query's tokens.

    Tokens come from `_TOKEN_RE`. Single-character tokens are dropped, and
    the rest are lowercased, de-duplicated and sorted so the output is
    deterministic, e.g. "mdd | sertraline". Returns "" when nothing
    survives.
    """
    distinct: set[str] = set()
    for token in _TOKEN_RE.findall(query):
        if len(token) > 1:
            distinct.add(token.lower())
    return " | ".join(sorted(distinct))


def retrieve_lexical(
    conn: psycopg.Connection,
    query: str,
    k: int = 50,
    source_types: Sequence[str] | None = None,
) -> list[Hit]:
    """Top-k chunks by literal substring matching on rare query tokens.

    Third rank-fusion input alongside vector and BM25. Targets the failure
    mode where a chunk literally contains a rare clinical entity (drug name,
    ICD code, acronym) but the surrounding context buries it for both the
    dense and `ts_rank` retrievers.

    Score = sum of the lengths of the matched tokens. Longer, more specific
    tokens therefore weigh more than short noisy ones like "50mg". Ties are
    broken by ascending chunk_id for a stable order.

    Args:
        conn: open psycopg connection.
        query: raw user query text; tokens are chosen by `rare_query_tokens`.
        k: maximum number of hits returned.
        source_types: optional whitelist of `source_type` values; falsy means
            no filter.

    Returns:
        Hits ordered best-first. Returns [] without issuing SQL when no query
        token passes the rarity heuristic; the other retrievers still cover
        that case.
    """
    rare = rare_query_tokens(query)
    if not rare:
        return []
    sql, params = _build_lexical_sql(rare, k, source_types)
    with conn.cursor() as cur:
        cur.execute(sql, params)
        return [_row_to_hit(row) for row in cur.fetchall()]


def _build_lexical_sql(
    tokens: Sequence[str], k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Build the SQL and bind params for `retrieve_lexical`.

    Each token adds one weighted CASE term to the score and one OR-ed ILIKE
    predicate to the WHERE clause. Token text travels only as bind
    parameters. The only inlined value is the weight `len(token)`, an int.
    The source-type filter uses the same `_source_filter` helper as the
    vector and BM25 builders.
    """
    # Escape ILIKE metacharacters so each token matches literally. Backslash
    # is Postgres's default LIKE escape character. rare_query_tokens only
    # emits [A-Za-z0-9] runs today, so this is a no-op for now, but the
    # builder stays safe if the tokenizer ever widens.
    patterns = tuple(
        "%"
        + t.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        + "%"
        for t in tokens
    )
    score_expr = " + ".join(
        f"(CASE WHEN chunk_text ILIKE %s THEN {len(t)} ELSE 0 END)" for t in tokens
    )
    where_any = " OR ".join("chunk_text ILIKE %s" for _ in tokens)
    src_where, src_params = _source_filter(source_types, leading_where=False)
    src_clause = f" AND {src_where}" if src_where else ""
    sql = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        "       title, chunk_text, "
        f"       ({score_expr})::float AS score "
        "FROM chunks_with_source "
        f"WHERE ({where_any}){src_clause} "
        "ORDER BY score DESC, chunk_id ASC "
        "LIMIT %s"
    )
    # Placeholder order: score CASEs, WHERE ORs, optional source filter, LIMIT.
    return sql, (*patterns, *patterns, *src_params, k)


def rare_query_tokens(query: str) -> list[str]:
    """Pick the query tokens worth literal-matching.

    A token (from `_TOKEN_RE`) is kept if any of these holds:
      - digit-free, longer than 7 characters, and not in the generic-medical
        stoplist (drug names such as sertraline, paroxetine, fluoxetine);
      - digit-free, all-uppercase, at least 3 characters (acronyms such as
        OCD, SSRI, MDD, TRD);
      - mixes letters and digits, at least 3 characters (codes such as
        F41 / 6A20).

    Returns the kept tokens lowercased and de-duplicated, in order of first
    appearance in the query.
    """
    selected: dict[str, None] = {}
    for token in _TOKEN_RE.findall(query):
        key = token.lower()
        if key in selected:
            continue
        digits = any(ch.isdigit() for ch in token)
        letters = any(ch.isalpha() for ch in token)
        long_word = (
            not digits and len(token) > 7 and key not in _GENERIC_LONG_TOKENS
        )
        acronym = not digits and len(token) >= 3 and token.isupper()
        code_like = digits and letters and len(token) >= 3
        if long_word or acronym or code_like:
            selected[key] = None
    return list(selected)


def _build_vector_sql(
    embedding, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Build the nearest-neighbour SQL and bind params for `retrieve_vector`.

    The query embedding is bound twice: once to compute the reported score
    (`1 - distance`) and once in the ORDER BY on raw `<=>` distance.
    """
    where_sql, filter_params = _source_filter(source_types)
    select_part = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        "       title, chunk_text, 1 - (embedding <=> %s) AS score "
    )
    from_part = f"FROM chunks_with_source{where_sql} "
    order_part = "ORDER BY embedding <=> %s LIMIT %s"
    # Bind order follows the %s placeholders left to right: SELECT score,
    # optional source filter, ORDER BY, LIMIT.
    bind = (embedding,) + filter_params + (embedding, k)
    return select_part + from_part + order_part, bind


def _build_bm25_sql(
    ts_query: str, k: int, source_types: Sequence[str] | None
) -> tuple[str, tuple]:
    """Build the full-text ranking SQL and bind params for `retrieve_bm25`.

    `ts_query` is an OR-joined `to_tsquery` string. It is bound three times:
    for the reported ts_rank score, for the `@@` match filter, and for the
    ORDER BY.

    Ties on ts_rank are broken by `chunk_id ASC`, the same tiebreak
    `retrieve_lexical` uses. Without it, equal-scoring chunks come back in
    an unspecified order, which makes downstream rank fusion
    non-reproducible.
    """
    where, params_pre = _source_filter(source_types, leading_where=False)
    base_where = "tsv @@ to_tsquery('english', %s)"
    full_where = f"WHERE {base_where}" + (f" AND {where}" if where else "")
    sql = (
        "SELECT chunk_id, document_id, source_type, source_uri, section, "
        "       title, chunk_text, ts_rank(tsv, to_tsquery('english', %s)) AS score "
        "FROM chunks_with_source "
        f"{full_where} "
        "ORDER BY ts_rank(tsv, to_tsquery('english', %s)) DESC, chunk_id ASC "
        "LIMIT %s"
    )
    # Placeholder order: SELECT ts_rank, WHERE @@, optional source filter,
    # ORDER BY ts_rank, LIMIT.
    return sql, (ts_query, ts_query, *params_pre, ts_query, k)


def _source_filter(
    source_types: Sequence[str] | None, *, leading_where: bool = True
) -> tuple[str, tuple]:
    if not source_types:
        return ("", ())
    clause = "source_type = ANY(%s)"
    return (f" WHERE {clause}" if leading_where else clause, (list(source_types),))


def _row_to_hit(row) -> Hit:
    """Convert one positional result row into a `Hit`.

    Expects the column order every retriever's SELECT list uses: chunk_id,
    document_id, source_type, source_uri, section, title, chunk_text,
    score. The first seven columns map straight onto `Hit`'s leading fields.
    The score (column 7) is coerced to float.
    """
    return Hit(*row[:7], score=float(row[7]))