Abhinav Biju
fast/thinking toggle
cc2dc62
"""Hybrid retrieval over notebook-scoped indexed chunks.
Spec references:
- `specs/04_interfaces.md`: implements `retrieve()`.
- `specs/05_rag_and_citations.md`: hybrid BM25 plus vector retrieval with merged candidates.
- `specs/07_security.md`: notebook access remains isolated per user and notebook.
- `specs/10_test_plan.md`: deterministic retrieval logic suitable for testing.
- `specs/11_observability.md`: retrieval emits structured logging fields.
"""
from __future__ import annotations
import json
import logging
import math
import os
from functools import lru_cache
from pathlib import Path
from time import perf_counter
from typing import Any, TypedDict
from ingestion.embedder import EmbedderDependencyError, EmbedderError, embed_texts
from notebooklm_clone.notebooks import get_notebook
from notebooklm_clone.storage import notebook_root, safe_join
# Module-level logger used for the structured `retrieve` observability records
# and for diagnostic messages (reranker loading, query expansion).
LOGGER = logging.getLogger(__name__)
class RetrievalResult(TypedDict):
    """Returned retrieval record for one chunk candidate.

    `retrieve()` returns a list of these, sorted by descending ``score``.
    """

    chunk_id: str  # Chroma document id of the indexed chunk.
    source_id: str  # `source_id` chunk metadata, whitespace-stripped.
    source_name: str  # `source_name` chunk metadata, whitespace-stripped.
    text: str  # Indexed chunk text (the stored Chroma document).
    # Mean of normalized BM25 and vector scores in "Fast" mode; otherwise the
    # raw cross-encoder `predict()` output produced during reranking.
    score: float
    loc: Any  # `location_hints` metadata, JSON-decoded when it is a JSON string.
class RetrievalError(Exception):
    """Base exception for retrieval failures.

    Parent of the dependency, validation, and storage errors defined below.
    Invalid arguments to `retrieve()` raise `ValueError`, not this type.
    """


class RetrievalDependencyError(RetrievalError):
    """Raised when a required retrieval dependency is unavailable.

    Covers a missing ``chromadb`` or ``sentence-transformers`` package and
    failures reported by the query embedder.
    """


class RetrievalValidationError(RetrievalError):
    """Raised when indexed chunk metadata is invalid.

    For example, a chunk whose metadata lacks a non-blank ``source_id`` or
    ``source_name``. Invalid query inputs raise `ValueError` instead.
    """


class RetrievalStorageError(RetrievalError):
    """Raised when notebook-local retrieval data cannot be opened or read.

    Also raised when Chroma returns payloads that are missing, misaligned, or
    of an unexpected type.
    """
class _Candidate(TypedDict):
    """Internal merged candidate shape before final formatting.

    During `retrieve()` one is built for every chunk id that appears in the
    BM25 or vector score maps. It is then collapsed into a `RetrievalResult`
    whose score is the mean of the two normalized scores.
    """

    chunk_id: str
    source_id: str  # Stripped `source_id` from chunk metadata.
    source_name: str  # Stripped `source_name` from chunk metadata.
    text: str  # Indexed chunk text.
    loc: Any  # `location_hints` metadata after `_parse_loc`.
    bm25_score: float  # Normalized BM25 score; 0.0 when the chunk had none.
    vector_score: float  # Normalized vector score; 0.0 when the chunk had none.
def _log_retrieval(
    username: str,
    notebook_id: str,
    status: str,
    started_at: float,
) -> None:
    """Write the structured ``retrieve`` observability record.

    Args:
        username: Notebook owner, logged under the ``user`` field.
        notebook_id: Notebook the retrieval ran against.
        status: Outcome label; this module passes ``"success"`` or ``"error"``.
        started_at: ``perf_counter()`` reading taken when the operation began.
    """
    elapsed_ms: int = int(1000 * (perf_counter() - started_at))
    record_fields: dict[str, Any] = {
        "user": username,
        "notebook_id": notebook_id,
        "action": "retrieve",
        "duration_ms": elapsed_ms,
        "status": status,
    }
    LOGGER.info("retrieve", extra=record_fields)
def _tokenize(text: str) -> list[str]:
"""Tokenize text deterministically into lowercase alphanumeric terms."""
tokens: list[str] = []
current: list[str] = []
for character in text.lower():
if character.isalnum():
current.append(character)
continue
if current:
tokens.append("".join(current))
current = []
if current:
tokens.append("".join(current))
return tokens
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize positive scores to the `[0, 1]` interval deterministically."""
positive_scores: list[float] = [score for score in scores.values() if score > 0.0]
if not positive_scores:
return {chunk_id: 0.0 for chunk_id in scores}
max_score: float = max(positive_scores)
if max_score <= 0.0:
return {chunk_id: 0.0 for chunk_id in scores}
return {
chunk_id: (score / max_score) if score > 0.0 else 0.0
for chunk_id, score in scores.items()
}
def _parse_loc(value: Any) -> Any:
"""Parse stored location metadata when it was serialized as JSON."""
if not isinstance(value, str):
return value
try:
return json.loads(value)
except json.JSONDecodeError:
return value
def _chroma_path(username: str, notebook_id: str) -> Path:
    """Return the notebook-scoped Chroma persistence directory, creating it if needed.

    The directory is ``safe_join(notebook_root(username, notebook_id), "chroma")``.

    Raises:
        RetrievalStorageError: If the directory cannot be created.
    """
    target_dir: Path = safe_join(notebook_root(username, notebook_id), "chroma")
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise RetrievalStorageError(f"Failed to prepare Chroma path: {target_dir}") from exc
    return target_dir
def _get_collection(username: str, notebook_id: str) -> Any:
    """Open the notebook-local Chroma collection named after ``notebook_id``.

    A ``PersistentClient`` is rooted at the notebook's own ``chroma`` directory,
    and the collection is opened with ``get_or_create_collection``.

    Raises:
        RetrievalDependencyError: If ``chromadb`` is not installed.
        RetrievalStorageError: If the directory cannot be prepared or the
            client/collection cannot be opened.
    """
    try:
        import chromadb
    except ImportError as exc:
        raise RetrievalDependencyError(
            "Retrieval requires the 'chromadb' package to be installed."
        ) from exc
    persist_directory: str = str(_chroma_path(username, notebook_id))
    try:
        client = chromadb.PersistentClient(path=persist_directory)
        collection = client.get_or_create_collection(name=notebook_id)
    except Exception as exc:
        raise RetrievalStorageError(
            f"Failed to open Chroma collection for notebook: {notebook_id}"
        ) from exc
    return collection
def _load_collection_documents(collection: Any) -> tuple[list[str], list[str], list[dict[str, Any]]]:
"""Load indexed notebook documents for BM25 scoring."""
try:
payload: dict[str, Any] = collection.get(include=["documents", "metadatas"])
except Exception as exc:
raise RetrievalStorageError("Failed to read notebook collection contents.") from exc
ids: Any = payload.get("ids")
documents: Any = payload.get("documents")
metadatas: Any = payload.get("metadatas")
if not isinstance(ids, list) or not isinstance(documents, list) or not isinstance(metadatas, list):
raise RetrievalStorageError("Chroma collection returned invalid retrieval payloads.")
if not (len(ids) == len(documents) == len(metadatas)):
raise RetrievalStorageError("Chroma collection returned misaligned retrieval payloads.")
validated_ids: list[str] = []
validated_documents: list[str] = []
validated_metadatas: list[dict[str, Any]] = []
for index, item_id in enumerate(ids):
if not isinstance(item_id, str):
raise RetrievalStorageError(f"Indexed chunk id at position {index} is invalid.")
if not isinstance(documents[index], str):
raise RetrievalStorageError(f"Indexed document at position {index} is invalid.")
if not isinstance(metadatas[index], dict):
raise RetrievalStorageError(f"Indexed metadata at position {index} is invalid.")
validated_ids.append(item_id)
validated_documents.append(documents[index])
validated_metadatas.append(metadatas[index])
return validated_ids, validated_documents, validated_metadatas
def _bm25_scores(documents: dict[str, str], query: str) -> dict[str, float]:
    """Compute deterministic Okapi BM25 scores of ``query`` over chunk texts.

    Uses ``k1 = 1.5``, ``b = 0.75`` and the smoothed IDF
    ``ln(1 + (N - df + 0.5) / (df + 0.5))``. Both texts and query go through
    `_tokenize`. Each occurrence of a repeated query term contributes
    separately.

    Args:
        documents: Mapping of chunk id to chunk text.
        query: Raw query text.

    Returns:
        A score for every chunk id in ``documents``. Chunks without a matching
        term score 0.0. The result is empty when ``documents`` is empty and
        the query has tokens.
    """
    query_tokens: list[str] = _tokenize(query)
    if not query_tokens:
        return {chunk_id: 0.0 for chunk_id in documents}
    doc_tokens: dict[str, list[str]] = {
        chunk_id: _tokenize(text) for chunk_id, text in documents.items()
    }
    document_count: int = len(doc_tokens)
    if document_count == 0:
        return {}
    average_length: float = (
        sum(len(tokens) for tokens in doc_tokens.values()) / document_count
    )
    if average_length == 0.0:
        # Every chunk tokenized to nothing, so no term can match.
        return {chunk_id: 0.0 for chunk_id in documents}

    # Per-chunk term counts and per-term document frequency, built in one pass.
    term_frequencies: dict[str, dict[str, int]] = {}
    document_frequency: dict[str, int] = {}
    for chunk_id, tokens in doc_tokens.items():
        counts: dict[str, int] = {}
        for token in tokens:
            counts[token] = counts.get(token, 0) + 1
        term_frequencies[chunk_id] = counts
        for token in counts:
            document_frequency[token] = document_frequency.get(token, 0) + 1

    k1: float = 1.5
    b: float = 0.75
    # IDF depends only on the term, so compute it once per distinct query term
    # instead of once per (chunk, term) pair. Since df <= N, the smoothed form
    # is always positive.
    inverse_document_frequency: dict[str, float] = {}
    for token in query_tokens:
        if token not in inverse_document_frequency:
            df: int = document_frequency.get(token, 0)
            inverse_document_frequency[token] = math.log(
                1.0 + ((document_count - df + 0.5) / (df + 0.5))
            )

    scores: dict[str, float] = {}
    for chunk_id, tokens in doc_tokens.items():
        chunk_counts: dict[str, int] = term_frequencies[chunk_id]
        # Document-length normalization depends only on the chunk.
        length_penalty: float = k1 * (1.0 - b + b * (len(tokens) / average_length))
        score: float = 0.0
        for token in query_tokens:
            frequency: int = chunk_counts.get(token, 0)
            if frequency == 0:
                continue
            score += inverse_document_frequency[token] * (
                (frequency * (k1 + 1.0)) / (frequency + length_penalty)
            )
        scores[chunk_id] = score
    return scores
def _vector_scores(collection: Any, query: str, limit: int) -> dict[str, float]:
"""Query vector similarity from the notebook-scoped Chroma collection."""
if limit <= 0:
return {}
try:
query_embedding: list[float] = embed_texts([query])[0]
except (EmbedderDependencyError, EmbedderError) as exc:
raise RetrievalDependencyError("Failed to generate retrieval query embedding.") from exc
try:
payload: dict[str, Any] = collection.query(
query_embeddings=[query_embedding],
n_results=limit,
include=["distances"],
)
except Exception as exc:
raise RetrievalStorageError("Failed to query notebook vector index.") from exc
ids_nested: Any = payload.get("ids")
distances_nested: Any = payload.get("distances")
if not isinstance(ids_nested, list) or not ids_nested:
return {}
if not isinstance(distances_nested, list) or not distances_nested:
raise RetrievalStorageError("Chroma query returned invalid distance payloads.")
ids: Any = ids_nested[0]
distances: Any = distances_nested[0]
if not isinstance(ids, list) or not isinstance(distances, list):
raise RetrievalStorageError("Chroma query returned invalid nested payloads.")
if len(ids) != len(distances):
raise RetrievalStorageError("Chroma query returned misaligned ids and distances.")
scores: dict[str, float] = {}
for index, chunk_id in enumerate(ids):
distance: Any = distances[index]
if not isinstance(chunk_id, str) or not isinstance(distance, (int, float)):
raise RetrievalStorageError("Chroma query returned invalid vector results.")
scores[chunk_id] = 1.0 / (1.0 + max(float(distance), 0.0))
return scores
@lru_cache(maxsize=1)
def _load_reranker() -> Any:
    """Load the cross-encoder reranker model once per process.

    The model name comes from ``NOTEBOOKLM_RERANKER_MODEL`` (whitespace
    stripped) and defaults to ``cross-encoder/ms-marco-MiniLM-L-6-v2``. The
    first successfully loaded model is cached by ``lru_cache``, so later
    changes to the environment variable have no effect in this process.

    Raises:
        RetrievalDependencyError: If ``sentence-transformers`` is not installed.
    """
    reranker_model: str = os.getenv(
        "NOTEBOOKLM_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2"
    ).strip()
    try:
        from sentence_transformers import CrossEncoder
    except ImportError as exc:
        raise RetrievalDependencyError(
            "Reranking requires the 'sentence-transformers' package."
        ) from exc
    else:
        LOGGER.info("Loading reranker model: %s", reranker_model)
        return CrossEncoder(reranker_model)
def _rerank(
query: str,
candidates: list[RetrievalResult],
k: int,
) -> list[RetrievalResult]:
"""Re-score candidates with a cross-encoder and return top-k."""
if not candidates:
return []
reranker = _load_reranker()
pairs: list[list[str]] = [[query, c["text"]] for c in candidates]
scores = reranker.predict(pairs)
for i, candidate in enumerate(candidates):
candidate["score"] = float(scores[i])
candidates.sort(key=lambda item: (-item["score"], item["chunk_id"]))
return candidates[:k]
def _expand_query(query: str) -> list[str]:
"""Generate alternative query phrasings using the LLM.
Returns a list containing the original query plus expansions.
Falls back to just [query] if LLM is unavailable or expansion is disabled.
"""
if os.getenv("NOTEBOOKLM_QUERY_EXPANSION", "on").strip().lower() == "off":
return [query]
api_key: str = os.getenv("OPENAI_API_KEY", "").strip()
model: str = os.getenv("NOTEBOOKLM_CHAT_MODEL", "gpt-4o-mini").strip()
if not api_key or not model:
return [query]
try:
from openai import OpenAI
client = OpenAI(api_key=api_key)
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": (
"Generate 2 alternative phrasings of the user's search query. "
"Return only the alternative queries, one per line, no numbering."
),
},
{"role": "user", "content": query},
],
temperature=0.7,
max_tokens=150,
)
content: str = (response.choices[0].message.content or "").strip()
expansions: list[str] = [
line.strip() for line in content.splitlines() if line.strip()
]
LOGGER.info("Query expanded: %s -> %s", query, expansions)
return [query] + expansions[:2]
except Exception as exc:
LOGGER.warning("Query expansion failed, using original: %s", exc)
return [query]
def _multi_query_scores(
chunk_documents: dict[str, str],
collection: Any,
queries: list[str],
n_docs: int,
) -> tuple[dict[str, float], dict[str, float]]:
"""Run BM25 + vector for each query variant, merge with max-per-chunk."""
merged_bm25: dict[str, float] = {}
merged_vector: dict[str, float] = {}
for q in queries:
bm25_raw = _bm25_scores(chunk_documents, q)
vector_raw = _vector_scores(collection, q, n_docs)
for cid, score in bm25_raw.items():
merged_bm25[cid] = max(merged_bm25.get(cid, 0.0), score)
for cid, score in vector_raw.items():
merged_vector[cid] = max(merged_vector.get(cid, 0.0), score)
return merged_bm25, merged_vector
def _rerank_pool_size(k: int) -> int:
    """Return how many hybrid-ranked candidates are handed to the cross-encoder.

    Reads ``NOTEBOOKLM_RERANK_TOP_N`` (default 10). An unparsable value logs a
    warning and falls back to 10 instead of failing the whole retrieval. The
    pool is never smaller than ``k``. Otherwise a small top-N (or a request
    with ``k`` above 10) would silently return fewer than ``k`` results in
    reranking mode even when more candidates exist.
    """
    default_top_n: int = 10
    raw_value: str = os.getenv("NOTEBOOKLM_RERANK_TOP_N", str(default_top_n))
    try:
        top_n: int = int(raw_value)
    except ValueError:
        LOGGER.warning(
            "Invalid NOTEBOOKLM_RERANK_TOP_N value %r; using %d.", raw_value, default_top_n
        )
        top_n = default_top_n
    return max(top_n, k)


def _build_candidates(
    merged_ids: list[str],
    chunk_documents: dict[str, str],
    chunk_metadata: dict[str, dict[str, Any]],
    bm25_normalized: dict[str, float],
    vector_normalized: dict[str, float],
) -> list[_Candidate]:
    """Join scored chunk ids with their indexed text and validated metadata.

    Raises:
        RetrievalStorageError: If a scored chunk id has no indexed text or metadata.
        RetrievalValidationError: If ``source_id`` or ``source_name`` is missing or blank.
    """
    candidates: list[_Candidate] = []
    for chunk_id in merged_ids:
        metadata: dict[str, Any] | None = chunk_metadata.get(chunk_id)
        text: str | None = chunk_documents.get(chunk_id)
        if metadata is None or text is None:
            raise RetrievalStorageError(f"Missing indexed content for chunk: {chunk_id}")
        source_id: Any = metadata.get("source_id")
        source_name: Any = metadata.get("source_name")
        if not isinstance(source_id, str) or not source_id.strip():
            raise RetrievalValidationError(
                f"Indexed metadata missing valid source_id for chunk: {chunk_id}"
            )
        if not isinstance(source_name, str) or not source_name.strip():
            raise RetrievalValidationError(
                f"Indexed metadata missing valid source_name for chunk: {chunk_id}"
            )
        candidates.append(
            {
                "chunk_id": chunk_id,
                "source_id": source_id.strip(),
                "source_name": source_name.strip(),
                "text": text,
                "loc": _parse_loc(metadata.get("location_hints")),
                "bm25_score": bm25_normalized.get(chunk_id, 0.0),
                "vector_score": vector_normalized.get(chunk_id, 0.0),
            }
        )
    return candidates


def _rank_candidates(candidates: list[_Candidate]) -> list[RetrievalResult]:
    """Score each candidate as the mean of its normalized BM25 and vector scores.

    Results are sorted by descending score, with ties broken by ``chunk_id``
    for determinism.
    """
    ranked_results: list[RetrievalResult] = [
        {
            "chunk_id": candidate["chunk_id"],
            "source_id": candidate["source_id"],
            "source_name": candidate["source_name"],
            "text": candidate["text"],
            "score": (candidate["bm25_score"] + candidate["vector_score"]) / 2.0,
            "loc": candidate["loc"],
        }
        for candidate in candidates
    ]
    ranked_results.sort(key=lambda item: (-item["score"], item["chunk_id"]))
    return ranked_results


def retrieve(
    username: str,
    notebook_id: str,
    query: str,
    k: int,
    rag_mode: str = "Reasoning",
) -> list[RetrievalResult]:
    """Retrieve top notebook chunks with hybrid scoring, query expansion, and reranking.

    Pipeline:
    1. Verify notebook access and load every indexed chunk.
    2. Score each query variant with BM25 and vector similarity, keeping each
       chunk's best score per signal.
    3. Normalize both signals to ``[0, 1]``, average them, and sort descending.

    ``rag_mode == "Fast"`` returns that hybrid top ``k`` directly. Any other
    mode also expands the query with the LLM and reranks the hybrid top-N pool
    with the cross-encoder. The pool size comes from
    ``NOTEBOOKLM_RERANK_TOP_N`` (default 10) and is never less than ``k``.

    Spec references:
    - `specs/04_interfaces.md`: implements `retrieve()`.
    - `specs/05_rag_and_citations.md`: BM25 retrieval, vector retrieval, merge, dedupe,
      normalize, and return top-k sorted descending.
    - `specs/07_security.md`: retrieval is scoped to one notebook owned by one user.
    - `specs/11_observability.md`: logs `user`, `notebook_id`, `action`, `duration_ms`, and `status`.

    Args:
        username: Notebook owner; ownership and existence are checked via `get_notebook`.
        notebook_id: Notebook whose Chroma collection is searched.
        query: Non-empty user query.
        k: Maximum number of results to return; must be positive.
        rag_mode: ``"Fast"`` for hybrid-only ranking. Any other value (default
            ``"Reasoning"``) enables query expansion and reranking.

    Returns:
        Up to ``k`` results sorted by descending ``score`` (ties broken by
        ``chunk_id``). The list is empty when the notebook has no indexed chunks.

    Raises:
        ValueError: If `query` is empty or `k` is not positive.
        RetrievalDependencyError: If retrieval dependencies are unavailable.
        RetrievalStorageError: If notebook-local retrieval data cannot be opened.
        RetrievalValidationError: If indexed metadata is malformed.
    """
    started_at: float = perf_counter()
    try:
        if not isinstance(query, str) or not query.strip():
            raise ValueError("query must be a non-empty string.")
        if k <= 0:
            raise ValueError("k must be greater than 0.")
        # Verifies notebook ownership and existence before any retrieval work.
        get_notebook(username, notebook_id)
        collection = _get_collection(username, notebook_id)
        ids, documents, metadatas = _load_collection_documents(collection)
        if not ids:
            _log_retrieval(username, notebook_id, "success", started_at)
            return []
        chunk_documents: dict[str, str] = dict(zip(ids, documents))
        chunk_metadata: dict[str, dict[str, Any]] = dict(zip(ids, metadatas))

        # One predicate drives both expansion and reranking, so every non-Fast
        # mode gets the same full pipeline.
        fast_mode: bool = rag_mode == "Fast"
        queries: list[str] = [query] if fast_mode else _expand_query(query)
        bm25_raw, vector_raw = _multi_query_scores(
            chunk_documents, collection, queries, len(ids)
        )
        candidates: list[_Candidate] = _build_candidates(
            sorted(set(bm25_raw) | set(vector_raw)),
            chunk_documents,
            chunk_metadata,
            _normalize_scores(bm25_raw),
            _normalize_scores(vector_raw),
        )
        ranked_results: list[RetrievalResult] = _rank_candidates(candidates)

        result: list[RetrievalResult]
        if fast_mode:
            result = ranked_results[:k]
        else:
            # Rerank only a bounded pool to control latency, but never fewer than k.
            result = _rerank(query, ranked_results[: _rerank_pool_size(k)], k)
        _log_retrieval(username, notebook_id, "success", started_at)
        return result
    except Exception:
        _log_retrieval(username, notebook_id, "error", started_at)
        raise