"""Hybrid retrieval over notebook-scoped indexed chunks. Spec references: - `specs/04_interfaces.md`: implements `retrieve()`. - `specs/05_rag_and_citations.md`: hybrid BM25 plus vector retrieval with merged candidates. - `specs/07_security.md`: notebook access remains isolated per user and notebook. - `specs/10_test_plan.md`: deterministic retrieval logic suitable for testing. - `specs/11_observability.md`: retrieval emits structured logging fields. """ from __future__ import annotations import json import logging import math import os from functools import lru_cache from pathlib import Path from time import perf_counter from typing import Any, TypedDict from ingestion.embedder import EmbedderDependencyError, EmbedderError, embed_texts from notebooklm_clone.notebooks import get_notebook from notebooklm_clone.storage import notebook_root, safe_join LOGGER = logging.getLogger(__name__) class RetrievalResult(TypedDict): """Returned retrieval record for one chunk candidate.""" chunk_id: str source_id: str source_name: str text: str score: float loc: Any class RetrievalError(Exception): """Base exception for retrieval failures.""" class RetrievalDependencyError(RetrievalError): """Raised when a required retrieval dependency is unavailable.""" class RetrievalValidationError(RetrievalError): """Raised when query inputs or indexed payloads are invalid.""" class RetrievalStorageError(RetrievalError): """Raised when notebook-local retrieval data cannot be opened.""" class _Candidate(TypedDict): """Internal merged candidate shape before final formatting.""" chunk_id: str source_id: str source_name: str text: str loc: Any bm25_score: float vector_score: float def _log_retrieval( username: str, notebook_id: str, status: str, started_at: float, ) -> None: """Emit an observability log record for retrieval operations.""" duration_ms: int = int((perf_counter() - started_at) * 1000) LOGGER.info( "retrieve", extra={ "user": username, "notebook_id": notebook_id, "action": "retrieve", "duration_ms": duration_ms, "status": status, }, ) def _tokenize(text: str) -> list[str]: """Tokenize text deterministically into lowercase alphanumeric terms.""" tokens: list[str] = [] current: list[str] = [] for character in text.lower(): if character.isalnum(): current.append(character) continue if current: tokens.append("".join(current)) current = [] if current: tokens.append("".join(current)) return tokens def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: """Normalize positive scores to the `[0, 1]` interval deterministically.""" positive_scores: list[float] = [score for score in scores.values() if score > 0.0] if not positive_scores: return {chunk_id: 0.0 for chunk_id in scores} max_score: float = max(positive_scores) if max_score <= 0.0: return {chunk_id: 0.0 for chunk_id in scores} return { chunk_id: (score / max_score) if score > 0.0 else 0.0 for chunk_id, score in scores.items() } def _parse_loc(value: Any) -> Any: """Parse stored location metadata when it was serialized as JSON.""" if not isinstance(value, str): return value try: return json.loads(value) except json.JSONDecodeError: return value def _chroma_path(username: str, notebook_id: str) -> Path: """Return the notebook-scoped Chroma persistence directory.""" root: Path = notebook_root(username, notebook_id) chroma_root: Path = safe_join(root, "chroma") try: chroma_root.mkdir(parents=True, exist_ok=True) except OSError as exc: raise RetrievalStorageError(f"Failed to prepare Chroma path: {chroma_root}") from exc return chroma_root def _get_collection(username: str, notebook_id: str) -> Any: """Open the notebook-local Chroma collection.""" try: import chromadb except ImportError as exc: raise RetrievalDependencyError( "Retrieval requires the 'chromadb' package to be installed." ) from exc chroma_root: Path = _chroma_path(username, notebook_id) try: client = chromadb.PersistentClient(path=str(chroma_root)) return client.get_or_create_collection(name=notebook_id) except Exception as exc: raise RetrievalStorageError( f"Failed to open Chroma collection for notebook: {notebook_id}" ) from exc def _load_collection_documents(collection: Any) -> tuple[list[str], list[str], list[dict[str, Any]]]: """Load indexed notebook documents for BM25 scoring.""" try: payload: dict[str, Any] = collection.get(include=["documents", "metadatas"]) except Exception as exc: raise RetrievalStorageError("Failed to read notebook collection contents.") from exc ids: Any = payload.get("ids") documents: Any = payload.get("documents") metadatas: Any = payload.get("metadatas") if not isinstance(ids, list) or not isinstance(documents, list) or not isinstance(metadatas, list): raise RetrievalStorageError("Chroma collection returned invalid retrieval payloads.") if not (len(ids) == len(documents) == len(metadatas)): raise RetrievalStorageError("Chroma collection returned misaligned retrieval payloads.") validated_ids: list[str] = [] validated_documents: list[str] = [] validated_metadatas: list[dict[str, Any]] = [] for index, item_id in enumerate(ids): if not isinstance(item_id, str): raise RetrievalStorageError(f"Indexed chunk id at position {index} is invalid.") if not isinstance(documents[index], str): raise RetrievalStorageError(f"Indexed document at position {index} is invalid.") if not isinstance(metadatas[index], dict): raise RetrievalStorageError(f"Indexed metadata at position {index} is invalid.") validated_ids.append(item_id) validated_documents.append(documents[index]) validated_metadatas.append(metadatas[index]) return validated_ids, validated_documents, validated_metadatas def _bm25_scores(documents: dict[str, str], query: str) -> dict[str, float]: """Compute deterministic BM25 scores over `chunk_text` values.""" query_tokens: list[str] = _tokenize(query) if not query_tokens: return {chunk_id: 0.0 for chunk_id in documents} doc_tokens: dict[str, list[str]] = { chunk_id: _tokenize(text) for chunk_id, text in documents.items() } document_count: int = len(doc_tokens) if document_count == 0: return {} average_length: float = sum(len(tokens) for tokens in doc_tokens.values()) / document_count if average_length == 0.0: return {chunk_id: 0.0 for chunk_id in documents} document_frequency: dict[str, int] = {} term_frequencies: dict[str, dict[str, int]] = {} for chunk_id, tokens in doc_tokens.items(): counts: dict[str, int] = {} for token in tokens: counts[token] = counts.get(token, 0) + 1 term_frequencies[chunk_id] = counts for token in counts: document_frequency[token] = document_frequency.get(token, 0) + 1 k1: float = 1.5 b: float = 0.75 scores: dict[str, float] = {} for chunk_id, tokens in doc_tokens.items(): doc_length: int = len(tokens) score: float = 0.0 counts: dict[str, int] = term_frequencies[chunk_id] for token in query_tokens: frequency: int = counts.get(token, 0) if frequency == 0: continue df: int = document_frequency.get(token, 0) inverse_document_frequency: float = math.log( 1.0 + ((document_count - df + 0.5) / (df + 0.5)) ) denominator: float = frequency + k1 * ( 1.0 - b + b * (doc_length / average_length) ) score += inverse_document_frequency * ( (frequency * (k1 + 1.0)) / denominator ) scores[chunk_id] = score return scores def _vector_scores(collection: Any, query: str, limit: int) -> dict[str, float]: """Query vector similarity from the notebook-scoped Chroma collection.""" if limit <= 0: return {} try: query_embedding: list[float] = embed_texts([query])[0] except (EmbedderDependencyError, EmbedderError) as exc: raise RetrievalDependencyError("Failed to generate retrieval query embedding.") from exc try: payload: dict[str, Any] = collection.query( query_embeddings=[query_embedding], n_results=limit, include=["distances"], ) except Exception as exc: raise RetrievalStorageError("Failed to query notebook vector index.") from exc ids_nested: Any = payload.get("ids") distances_nested: Any = payload.get("distances") if not isinstance(ids_nested, list) or not ids_nested: return {} if not isinstance(distances_nested, list) or not distances_nested: raise RetrievalStorageError("Chroma query returned invalid distance payloads.") ids: Any = ids_nested[0] distances: Any = distances_nested[0] if not isinstance(ids, list) or not isinstance(distances, list): raise RetrievalStorageError("Chroma query returned invalid nested payloads.") if len(ids) != len(distances): raise RetrievalStorageError("Chroma query returned misaligned ids and distances.") scores: dict[str, float] = {} for index, chunk_id in enumerate(ids): distance: Any = distances[index] if not isinstance(chunk_id, str) or not isinstance(distance, (int, float)): raise RetrievalStorageError("Chroma query returned invalid vector results.") scores[chunk_id] = 1.0 / (1.0 + max(float(distance), 0.0)) return scores @lru_cache(maxsize=1) def _load_reranker() -> Any: """Load cross-encoder reranker model once per process.""" model_name: str = os.getenv( "NOTEBOOKLM_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2" ).strip() try: from sentence_transformers import CrossEncoder except ImportError as exc: raise RetrievalDependencyError( "Reranking requires the 'sentence-transformers' package." ) from exc LOGGER.info("Loading reranker model: %s", model_name) return CrossEncoder(model_name) def _rerank( query: str, candidates: list[RetrievalResult], k: int, ) -> list[RetrievalResult]: """Re-score candidates with a cross-encoder and return top-k.""" if not candidates: return [] reranker = _load_reranker() pairs: list[list[str]] = [[query, c["text"]] for c in candidates] scores = reranker.predict(pairs) for i, candidate in enumerate(candidates): candidate["score"] = float(scores[i]) candidates.sort(key=lambda item: (-item["score"], item["chunk_id"])) return candidates[:k] def _expand_query(query: str) -> list[str]: """Generate alternative query phrasings using the LLM. Returns a list containing the original query plus expansions. Falls back to just [query] if LLM is unavailable or expansion is disabled. """ if os.getenv("NOTEBOOKLM_QUERY_EXPANSION", "on").strip().lower() == "off": return [query] api_key: str = os.getenv("OPENAI_API_KEY", "").strip() model: str = os.getenv("NOTEBOOKLM_CHAT_MODEL", "gpt-4o-mini").strip() if not api_key or not model: return [query] try: from openai import OpenAI client = OpenAI(api_key=api_key) response = client.chat.completions.create( model=model, messages=[ { "role": "system", "content": ( "Generate 2 alternative phrasings of the user's search query. " "Return only the alternative queries, one per line, no numbering." ), }, {"role": "user", "content": query}, ], temperature=0.7, max_tokens=150, ) content: str = (response.choices[0].message.content or "").strip() expansions: list[str] = [ line.strip() for line in content.splitlines() if line.strip() ] LOGGER.info("Query expanded: %s -> %s", query, expansions) return [query] + expansions[:2] except Exception as exc: LOGGER.warning("Query expansion failed, using original: %s", exc) return [query] def _multi_query_scores( chunk_documents: dict[str, str], collection: Any, queries: list[str], n_docs: int, ) -> tuple[dict[str, float], dict[str, float]]: """Run BM25 + vector for each query variant, merge with max-per-chunk.""" merged_bm25: dict[str, float] = {} merged_vector: dict[str, float] = {} for q in queries: bm25_raw = _bm25_scores(chunk_documents, q) vector_raw = _vector_scores(collection, q, n_docs) for cid, score in bm25_raw.items(): merged_bm25[cid] = max(merged_bm25.get(cid, 0.0), score) for cid, score in vector_raw.items(): merged_vector[cid] = max(merged_vector.get(cid, 0.0), score) return merged_bm25, merged_vector def retrieve( username: str, notebook_id: str, query: str, k: int, rag_mode: str = "Reasoning", ) -> list[RetrievalResult]: """Retrieve top notebook chunks with hybrid scoring, query expansion, and reranking. Spec references: - `specs/04_interfaces.md`: implements `retrieve()`. - `specs/05_rag_and_citations.md`: BM25 retrieval, vector retrieval, merge, dedupe, normalize, and return top-k sorted descending. - `specs/07_security.md`: retrieval is scoped to one notebook owned by one user. - `specs/11_observability.md`: logs `user`, `notebook_id`, `action`, `duration_ms`, and `status`. Raises: ValueError: If `query` is empty or `k` is not positive. RetrievalDependencyError: If retrieval dependencies are unavailable. RetrievalStorageError: If notebook-local retrieval data cannot be opened. RetrievalValidationError: If indexed metadata is malformed. """ started_at: float = perf_counter() try: if not isinstance(query, str) or not query.strip(): raise ValueError("query must be a non-empty string.") if k <= 0: raise ValueError("k must be greater than 0.") # Verifies notebook ownership and existence before any retrieval work. get_notebook(username, notebook_id) collection = _get_collection(username, notebook_id) ids, documents, metadatas = _load_collection_documents(collection) if not ids: _log_retrieval(username, notebook_id, "success", started_at) return [] chunk_documents: dict[str, str] = { chunk_id: document for chunk_id, document in zip(ids, documents) } chunk_metadata: dict[str, dict[str, Any]] = { chunk_id: metadata for chunk_id, metadata in zip(ids, metadatas) } # Query expansion: generate alt phrasings and merge scores queries: list[str] = _expand_query(query) if rag_mode == "Reasoning" else [query] bm25_raw, vector_raw = _multi_query_scores( chunk_documents, collection, queries, len(ids) ) bm25_normalized: dict[str, float] = _normalize_scores(bm25_raw) vector_normalized: dict[str, float] = _normalize_scores(vector_raw) merged_ids: list[str] = sorted(set(bm25_raw) | set(vector_raw)) candidates: list[_Candidate] = [] for chunk_id in merged_ids: metadata: dict[str, Any] | None = chunk_metadata.get(chunk_id) text: str | None = chunk_documents.get(chunk_id) if metadata is None or text is None: raise RetrievalStorageError(f"Missing indexed content for chunk: {chunk_id}") source_id: Any = metadata.get("source_id") source_name: Any = metadata.get("source_name") if not isinstance(source_id, str) or not source_id.strip(): raise RetrievalValidationError( f"Indexed metadata missing valid source_id for chunk: {chunk_id}" ) if not isinstance(source_name, str) or not source_name.strip(): raise RetrievalValidationError( f"Indexed metadata missing valid source_name for chunk: {chunk_id}" ) candidates.append( { "chunk_id": chunk_id, "source_id": source_id.strip(), "source_name": source_name.strip(), "text": text, "loc": _parse_loc(metadata.get("location_hints")), "bm25_score": bm25_normalized.get(chunk_id, 0.0), "vector_score": vector_normalized.get(chunk_id, 0.0), } ) ranked_results: list[RetrievalResult] = [] for candidate in candidates: combined_score: float = (candidate["bm25_score"] + candidate["vector_score"]) / 2.0 ranked_results.append( { "chunk_id": candidate["chunk_id"], "source_id": candidate["source_id"], "source_name": candidate["source_name"], "text": candidate["text"], "score": combined_score, "loc": candidate["loc"], } ) ranked_results.sort(key=lambda item: (-item["score"], item["chunk_id"])) if rag_mode == "Fast": result: list[RetrievalResult] = ranked_results[:k] else: # Rerank only top-N candidates to control latency (default: 10) _rerank_n: int = int(os.getenv("NOTEBOOKLM_RERANK_TOP_N", "10")) rerank_pool: list[RetrievalResult] = ranked_results[:_rerank_n] result: list[RetrievalResult] = _rerank(query, rerank_pool, k) _log_retrieval(username, notebook_id, "success", started_at) return result except Exception: _log_retrieval(username, notebook_id, "error", started_at) raise