explainer-env / research /retrieval.py
kgdrathan's picture
Upload folder using huggingface_hub
8fa7af1 verified
"""Small retrieval helpers: tokenization, chunking, and embedding ranking."""
from __future__ import annotations
import math
import re
from pathlib import Path
from .types import ResearchChunk
# Default character cap applied by ``trim_text``.
SECTION_MAX_CHARS = 900

# Default number of chunks returned by ``rank_chunks_for_query``.
MAX_RETURNED_CHUNKS = 5

# Model identifier passed to fastembed's ``TextEmbedding``.
EMBEDDING_MODEL_NAME = "BAAI/bge-small-en-v1.5"

# Model cache location: ``.cache/fastembed`` under the directory one level
# above the directory that contains this file.
EMBEDDING_CACHE_DIR = Path(__file__).resolve().parents[1].joinpath(".cache", "fastembed")

# Lazily created model instance; filled in by ``_get_embedding_model``.
_EMBEDDING_MODEL = None

# Words that ``tokenize`` discards.
_STOP_WORDS = frozenset(
    (
        "the a an "
        "is are was were be been being "
        "have has had do does did "
        "will would could should "
        "to of in for on with at by from as "
        "and but or "
        "this that these those it its"
    ).split()
)
def tokenize(text: str) -> list[str]:
    """Break *text* into lowercase word tokens with stop words filtered out.

    The input is lowercased and split into maximal runs of regex word
    characters. A token is kept only when it is not listed in
    ``_STOP_WORDS`` and is longer than one character. Order and duplicates
    are preserved.
    """
    kept: list[str] = []
    for candidate in re.findall(r"\w+", text.lower()):
        if candidate in _STOP_WORDS or len(candidate) <= 1:
            continue
        kept.append(candidate)
    return kept
def trim_text(text: str, max_chars: int = SECTION_MAX_CHARS) -> str:
    """Normalize whitespace in *text* and cut it to ``max_chars`` characters.

    Every run of whitespace is collapsed to a single space and the ends are
    stripped. The result is then sliced to its first ``max_chars`` characters
    and stripped once more, so a cut that lands right after a space leaves no
    trailing blank.
    """
    collapsed = re.sub(r"\s+", " ", text)
    normalized = collapsed.strip()
    truncated = normalized[:max_chars]
    return truncated.strip()
def chunk_markdown(text: str, fallback_title: str) -> list[tuple[str, str]]:
    """Split markdown-ish text into titled chunks.

    A line that starts with ``#`` begins a new section whose title is that
    line with its leading ``#`` characters and surrounding whitespace
    removed. An empty title falls back to ``fallback_title``, which also
    titles any text that appears before the first heading. Sections whose
    body is empty after stripping are dropped.

    Lines inside a fenced code block are always body text, so a
    ``# comment`` in a code sample no longer starts a bogus section. A fence
    opens on a line that, after optional leading whitespace, starts with
    three backticks or three tildes, and closes on the next line that starts
    with the same marker. An unterminated fence runs to the end of the text.

    Args:
        text: The document to split.
        fallback_title: Title used when no heading text is available.

    Returns:
        ``(title, body)`` pairs in document order.
    """
    chunks: list[tuple[str, str]] = []
    heading = fallback_title
    lines: list[str] = []
    # Marker ("```" or "~~~") of the fenced block we are inside, else None.
    open_fence = None

    for line in text.splitlines():
        stripped = line.lstrip()

        if open_fence is not None:
            # Inside a code fence nothing is a heading; only look for the
            # matching closing marker.
            if stripped.startswith(open_fence):
                open_fence = None
            lines.append(line)
            continue

        if stripped.startswith(("```", "~~~")):
            open_fence = stripped[:3]
            lines.append(line)
            continue

        if line.startswith("#"):
            # A heading closes the section collected so far.
            body = "\n".join(lines).strip()
            if body:
                chunks.append((heading, body))
            heading = line.lstrip("#").strip() or fallback_title
            lines = []
        else:
            lines.append(line)

    # Flush whatever follows the last heading.
    body = "\n".join(lines).strip()
    if body:
        chunks.append((heading, body))
    return chunks
def rank_chunks_for_query(
    query: str,
    intent: str,
    chunks: list[ResearchChunk],
    top_k: int = MAX_RETURNED_CHUNKS,
    embedding_model=None,
) -> list[ResearchChunk]:
    """Return the final top chunks for query+intent.

    The pipeline is: source results -> text chunks -> embedding similarity
    against query+intent -> final top-k chunks.

    Args:
        query: User query text.
        intent: Intent text; joined to ``query`` with a space to form the
            text that chunks are compared against.
        chunks: Candidate chunks. Each scored chunk has its ``score``
            attribute overwritten, and each returned chunk has its ``rank``
            attribute set (1-based), so the inputs are mutated in place.
        top_k: Maximum number of chunks to return. A value of zero or less
            returns an empty list without running the embedding model.
        embedding_model: Object exposing ``embed(texts)``; when ``None`` the
            shared module-level model is used.

    Returns:
        At most ``top_k`` chunks ordered by descending cosine similarity;
        ties keep their input order. When query and intent are both blank
        the first ``top_k`` chunks are returned in input order, unscored.

    Raises:
        RuntimeError: If the model yields a different number of vectors
            than texts it was given.
    """
    if not chunks:
        return []
    # A negative limit would otherwise reach the slices below and silently
    # mean "all but the last N". ``None`` is tolerated as "no limit", as
    # slicing always allowed.
    if top_k is not None and top_k <= 0:
        return []

    query_text = f"{query} {intent}".strip()
    if not query_text:
        # Nothing to compare against: keep the incoming order.
        return _assign_ranks(chunks[:top_k])

    # Compare against None explicitly so a caller-supplied model is never
    # discarded just because it happens to evaluate as falsy.
    model = _get_embedding_model() if embedding_model is None else embedding_model

    # The query goes first so its vector can be peeled off the front.
    texts = [query_text, *(_chunk_embedding_text(chunk) for chunk in chunks)]
    vectors = list(model.embed(texts))
    if len(vectors) != len(texts):
        raise RuntimeError("Embedding model returned an unexpected number of vectors")

    query_vec, *chunk_vecs = vectors
    for chunk, vec in zip(chunks, chunk_vecs):
        chunk.score = _cosine(query_vec, vec)

    # ``sorted`` is stable, so equally scored chunks keep their input order.
    ranked = sorted(chunks, key=lambda chunk: chunk.score, reverse=True)
    return _assign_ranks(ranked[:top_k])
def preload_embedding_model() -> None:
    """Load the embedding model and run one throwaway embedding at startup.

    Intended to be called before serving traffic so the first real request
    does not pay for model download/initialization.
    """
    warmup_batch = ["startup warmup"]
    # Consume the whole result so the embedding really executes; the goal is
    # to have model files and the runtime session ready, not just configured.
    for _ in _get_embedding_model().embed(warmup_batch):
        pass
def _get_embedding_model():
    """Return the shared embedding model, building it on the first call.

    The instance is stored in the module-level ``_EMBEDDING_MODEL`` and
    reused by every later call.
    """
    global _EMBEDDING_MODEL
    if _EMBEDDING_MODEL is not None:
        return _EMBEDDING_MODEL

    # Imported here rather than at module level, so the import only runs
    # when a model is first requested.
    from fastembed import TextEmbedding

    # Make sure the cache directory exists before handing it to fastembed.
    EMBEDDING_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    model = TextEmbedding(
        model_name=EMBEDDING_MODEL_NAME,
        cache_dir=str(EMBEDDING_CACHE_DIR),
    )
    _EMBEDDING_MODEL = model
    return model
def _chunk_embedding_text(chunk: ResearchChunk) -> str:
    """Build the text embedded for *chunk*: title, newline, body, stripped."""
    combined = "{}\n{}".format(chunk.title, chunk.text)
    return combined.strip()
def _assign_ranks(chunks: list[ResearchChunk]) -> list[ResearchChunk]:
    """Number *chunks* 1..N in list order via ``rank`` and return the same list."""
    position = 0
    for chunk in chunks:
        position += 1
        chunk.rank = position
    return chunks
def _cosine(a, b) -> float:
    """Return the cosine similarity of vectors *a* and *b*.

    Components are converted with ``float`` before use. If either vector has
    zero length the result is ``0.0`` instead of a division by zero.
    """

    def magnitude(vec) -> float:
        # Euclidean length of one vector.
        return math.sqrt(sum(float(v) * float(v) for v in vec))

    norm_a, norm_b = magnitude(a), magnitude(b)
    if not (norm_a and norm_b):
        return 0.0
    dot = sum(float(p) * float(q) for p, q in zip(a, b))
    return dot / (norm_a * norm_b)