Spaces:
Running
Running
| """Hybrid retrieval over notebook-scoped indexed chunks. | |
| Spec references: | |
| - `specs/04_interfaces.md`: implements `retrieve()`. | |
| - `specs/05_rag_and_citations.md`: hybrid BM25 plus vector retrieval with merged candidates. | |
| - `specs/07_security.md`: notebook access remains isolated per user and notebook. | |
| - `specs/10_test_plan.md`: deterministic retrieval logic suitable for testing. | |
| - `specs/11_observability.md`: retrieval emits structured logging fields. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import math | |
| import os | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from time import perf_counter | |
| from typing import Any, TypedDict | |
| from ingestion.embedder import EmbedderDependencyError, EmbedderError, embed_texts | |
| from notebooklm_clone.notebooks import get_notebook | |
| from notebooklm_clone.storage import notebook_root, safe_join | |
| LOGGER = logging.getLogger(__name__) | |
class RetrievalResult(TypedDict):
    """Public shape of one ranked chunk as returned by `retrieve()`."""

    # Id of the indexed chunk inside the notebook collection.
    chunk_id: str
    # Value of the chunk's `source_id` metadata field (whitespace-stripped).
    source_id: str
    # Value of the chunk's `source_name` metadata field (whitespace-stripped).
    source_name: str
    # Chunk text exactly as stored in the collection.
    text: str
    # Ranking score: the averaged normalized BM25/vector score, or the
    # cross-encoder score once the candidate has been reranked.
    score: float
    # `location_hints` metadata, JSON-decoded when it was stored as a JSON string.
    loc: Any
class RetrievalError(Exception):
    """Root of the retrieval exception hierarchy; catch this to handle any retrieval failure."""


class RetrievalDependencyError(RetrievalError):
    """Signals that a package or service retrieval depends on could not be used."""


class RetrievalValidationError(RetrievalError):
    """Signals that query inputs or indexed payloads failed validation."""


class RetrievalStorageError(RetrievalError):
    """Signals that notebook-local retrieval data could not be opened or read."""
class _Candidate(TypedDict):
    """Intermediate per-chunk record holding both normalized scores before they are combined."""

    # Id of the indexed chunk inside the notebook collection.
    chunk_id: str
    # Source identifier taken from the chunk metadata.
    source_id: str
    # Source display name taken from the chunk metadata.
    source_name: str
    # Chunk text exactly as stored in the collection.
    text: str
    # Parsed `location_hints` metadata.
    loc: Any
    # BM25 score after normalization to [0, 1]; 0.0 when the chunk had no BM25 score.
    bm25_score: float
    # Vector similarity after normalization to [0, 1]; 0.0 when the chunk had no vector hit.
    vector_score: float
def _log_retrieval(
    username: str,
    notebook_id: str,
    status: str,
    started_at: float,
) -> None:
    """Write one structured INFO record describing a finished retrieval call.

    Args:
        username: Value logged under the `user` field.
        notebook_id: Value logged under the `notebook_id` field.
        status: Outcome label logged under `status` (callers pass "success" or "error").
        started_at: `perf_counter()` reading taken when the call began.
    """
    elapsed_seconds: float = perf_counter() - started_at
    fields: dict[str, Any] = {
        "user": username,
        "notebook_id": notebook_id,
        "action": "retrieve",
        # Truncated (not rounded) to whole milliseconds.
        "duration_ms": int(elapsed_seconds * 1000),
        "status": status,
    }
    LOGGER.info("retrieve", extra=fields)
def _tokenize(text: str) -> list[str]:
    """Split text into lowercase alphanumeric terms, in order of appearance.

    Every character for which `str.isalnum()` is false acts as a separator, so
    punctuation, whitespace, and underscores never appear inside a token.

    Args:
        text: Raw text to tokenize.

    Returns:
        The maximal runs of alphanumeric characters of `text.lower()`.
    """
    # Blank out every separator, then let whitespace splitting collapse the
    # runs; alphanumeric characters are never whitespace, so tokens stay intact.
    spaced: str = "".join(
        character if character.isalnum() else " " for character in text.lower()
    )
    return spaced.split()
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
    """Scale scores into `[0, 1]` by dividing by the largest positive score.

    Args:
        scores: Raw score per chunk id.

    Returns:
        A new mapping with the same keys. Non-positive scores become 0.0; when
        no score is positive, every value is 0.0.
    """
    peak: float = max((value for value in scores.values() if value > 0.0), default=0.0)
    if peak <= 0.0:
        # Nothing positive to scale against.
        return dict.fromkeys(scores, 0.0)
    return {
        chunk_id: (value / peak if value > 0.0 else 0.0)
        for chunk_id, value in scores.items()
    }
def _parse_loc(value: Any) -> Any:
    """Decode location metadata that was stored as a JSON string.

    Args:
        value: Raw metadata value.

    Returns:
        The decoded JSON value when `value` is a string containing valid JSON;
        otherwise `value` unchanged (non-strings and undecodable strings alike).
    """
    if isinstance(value, str):
        try:
            decoded: Any = json.loads(value)
        except json.JSONDecodeError:
            # Plain (non-JSON) strings are passed through as-is.
            decoded = value
        return decoded
    return value
def _chroma_path(username: str, notebook_id: str) -> Path:
    """Resolve, and create if needed, the notebook's `chroma` directory.

    Args:
        username: Owner passed through to `notebook_root`.
        notebook_id: Notebook passed through to `notebook_root`.

    Returns:
        The `chroma` path joined under the notebook root via `safe_join`.

    Raises:
        RetrievalStorageError: If the directory cannot be created.
    """
    chroma_root: Path = safe_join(notebook_root(username, notebook_id), "chroma")
    try:
        chroma_root.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise RetrievalStorageError(f"Failed to prepare Chroma path: {chroma_root}") from exc
    return chroma_root
def _get_collection(username: str, notebook_id: str) -> Any:
    """Open (creating it if absent) the Chroma collection named after the notebook.

    Args:
        username: Owner used to locate the notebook's Chroma directory.
        notebook_id: Notebook id; also used as the collection name.

    Returns:
        The collection object returned by `get_or_create_collection`.

    Raises:
        RetrievalDependencyError: If `chromadb` cannot be imported.
        RetrievalStorageError: If the Chroma directory cannot be prepared or the
            client/collection cannot be opened.
    """
    # Imported lazily so a missing package surfaces as a retrieval error.
    try:
        import chromadb
    except ImportError as exc:
        raise RetrievalDependencyError(
            "Retrieval requires the 'chromadb' package to be installed."
        ) from exc

    chroma_root: Path = _chroma_path(username, notebook_id)
    try:
        return chromadb.PersistentClient(path=str(chroma_root)).get_or_create_collection(
            name=notebook_id
        )
    except Exception as exc:
        raise RetrievalStorageError(
            f"Failed to open Chroma collection for notebook: {notebook_id}"
        ) from exc
def _load_collection_documents(collection: Any) -> tuple[list[str], list[str], list[dict[str, Any]]]:
    """Fetch every indexed chunk of the collection and validate the payload shape.

    Args:
        collection: Object exposing `get(include=[...])` that returns a mapping
            with `ids`, `documents`, and `metadatas` entries.

    Returns:
        Three aligned lists: chunk ids, chunk texts, and chunk metadata dicts.

    Raises:
        RetrievalStorageError: If the read fails, the three entries are not
            equally long lists, or any id/document/metadata has the wrong type.
    """
    try:
        payload: dict[str, Any] = collection.get(include=["documents", "metadatas"])
    except Exception as exc:
        raise RetrievalStorageError("Failed to read notebook collection contents.") from exc

    ids, documents, metadatas = (
        payload.get(field) for field in ("ids", "documents", "metadatas")
    )
    if not all(isinstance(column, list) for column in (ids, documents, metadatas)):
        raise RetrievalStorageError("Chroma collection returned invalid retrieval payloads.")
    if len(documents) != len(ids) or len(metadatas) != len(documents):
        raise RetrievalStorageError("Chroma collection returned misaligned retrieval payloads.")

    # Validate row by row so the first offending position is the one reported.
    for index, (item_id, document, metadata) in enumerate(zip(ids, documents, metadatas)):
        if not isinstance(item_id, str):
            raise RetrievalStorageError(f"Indexed chunk id at position {index} is invalid.")
        if not isinstance(document, str):
            raise RetrievalStorageError(f"Indexed document at position {index} is invalid.")
        if not isinstance(metadata, dict):
            raise RetrievalStorageError(f"Indexed metadata at position {index} is invalid.")

    # Return fresh lists so callers never alias the payload's own containers.
    return list(ids), list(documents), list(metadatas)
def _bm25_scores(documents: dict[str, str], query: str) -> dict[str, float]:
    """Score every chunk against the query with BM25 (k1=1.5, b=0.75).

    Args:
        documents: Chunk text keyed by chunk id.
        query: Raw query string; tokenized with `_tokenize`.

    Returns:
        A BM25 score per chunk id. All scores are 0.0 when the query has no
        tokens or every chunk is empty; the result is `{}` when a query with
        tokens is scored against no documents. Repeated query tokens
        contribute once per repetition.
    """
    query_terms: list[str] = _tokenize(query)
    if not query_terms:
        return dict.fromkeys(documents, 0.0)

    tokens_by_chunk: dict[str, list[str]] = {
        chunk_id: _tokenize(text) for chunk_id, text in documents.items()
    }
    corpus_size: int = len(tokens_by_chunk)
    if corpus_size == 0:
        return {}

    mean_length: float = sum(len(tokens) for tokens in tokens_by_chunk.values()) / corpus_size
    if mean_length == 0.0:
        return dict.fromkeys(documents, 0.0)

    # Per-chunk term counts, plus the number of chunks containing each term.
    counts_by_chunk: dict[str, dict[str, int]] = {}
    chunks_containing: dict[str, int] = {}
    for chunk_id, tokens in tokens_by_chunk.items():
        term_counts: dict[str, int] = {}
        for term in tokens:
            term_counts[term] = term_counts.get(term, 0) + 1
        counts_by_chunk[chunk_id] = term_counts
        for term in term_counts:
            chunks_containing[term] = chunks_containing.get(term, 0) + 1

    # IDF depends only on the term, so compute it once per query term.
    idf_by_term: dict[str, float] = {}
    for term in query_terms:
        containing: int = chunks_containing.get(term, 0)
        idf_by_term[term] = math.log(
            1.0 + ((corpus_size - containing + 0.5) / (containing + 0.5))
        )

    k1: float = 1.5
    b: float = 0.75
    scores: dict[str, float] = {}
    for chunk_id, term_counts in counts_by_chunk.items():
        # Length normalisation is constant for the chunk.
        length_penalty: float = k1 * (
            1.0 - b + b * (len(tokens_by_chunk[chunk_id]) / mean_length)
        )
        total: float = 0.0
        for term in query_terms:
            occurrences: int = term_counts.get(term, 0)
            if occurrences == 0:
                continue
            total += idf_by_term[term] * (
                (occurrences * (k1 + 1.0)) / (occurrences + length_penalty)
            )
        scores[chunk_id] = total
    return scores
def _vector_scores(collection: Any, query: str, limit: int) -> dict[str, float]:
    """Embed the query and turn the nearest-neighbour distances into similarities.

    Args:
        collection: Object exposing `query(query_embeddings=..., n_results=...,
            include=[...])` that returns nested `ids` and `distances` lists.
        query: Query text to embed with `embed_texts`.
        limit: Maximum number of neighbours to request.

    Returns:
        `1 / (1 + max(distance, 0))` per returned chunk id. Empty when `limit`
        is not positive or the index returns no ids.

    Raises:
        RetrievalDependencyError: If the query embedding cannot be produced.
        RetrievalStorageError: If the query fails or its payload is malformed.
    """
    if limit <= 0:
        return {}

    try:
        query_embedding: list[float] = embed_texts([query])[0]
    except (EmbedderDependencyError, EmbedderError) as exc:
        raise RetrievalDependencyError("Failed to generate retrieval query embedding.") from exc

    try:
        payload: dict[str, Any] = collection.query(
            query_embeddings=[query_embedding],
            n_results=limit,
            include=["distances"],
        )
    except Exception as exc:
        raise RetrievalStorageError("Failed to query notebook vector index.") from exc

    id_batches: Any = payload.get("ids")
    distance_batches: Any = payload.get("distances")
    if not (isinstance(id_batches, list) and id_batches):
        return {}
    if not (isinstance(distance_batches, list) and distance_batches):
        raise RetrievalStorageError("Chroma query returned invalid distance payloads.")

    # One query embedding was sent, so only the first result batch is relevant.
    matched_ids: Any = id_batches[0]
    matched_distances: Any = distance_batches[0]
    if not (isinstance(matched_ids, list) and isinstance(matched_distances, list)):
        raise RetrievalStorageError("Chroma query returned invalid nested payloads.")
    if len(matched_ids) != len(matched_distances):
        raise RetrievalStorageError("Chroma query returned misaligned ids and distances.")

    similarities: dict[str, float] = {}
    for chunk_id, distance in zip(matched_ids, matched_distances):
        if not (isinstance(chunk_id, str) and isinstance(distance, (int, float))):
            raise RetrievalStorageError("Chroma query returned invalid vector results.")
        # Negative distances are clamped to zero so the similarity never exceeds 1.
        similarities[chunk_id] = 1.0 / (1.0 + max(float(distance), 0.0))
    return similarities
@lru_cache(maxsize=1)
def _load_reranker() -> Any:
    """Load the cross-encoder reranker model once per process.

    The model is cached after the first successful load, so later calls reuse
    the same instance instead of re-initialising it on every retrieval. A
    failed load is not cached and is retried on the next call.

    The model name comes from `NOTEBOOKLM_RERANKER_MODEL` and is read only on
    the first successful load; a blank value falls back to the default model.

    Returns:
        The `sentence_transformers.CrossEncoder` instance.

    Raises:
        RetrievalDependencyError: If `sentence-transformers` is not installed.
    """
    default_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    model_name: str = (
        os.getenv("NOTEBOOKLM_RERANKER_MODEL", default_model).strip() or default_model
    )
    # Imported lazily so a missing package surfaces as a retrieval error.
    try:
        from sentence_transformers import CrossEncoder
    except ImportError as exc:
        raise RetrievalDependencyError(
            "Reranking requires the 'sentence-transformers' package."
        ) from exc
    LOGGER.info("Loading reranker model: %s", model_name)
    return CrossEncoder(model_name)
| def _rerank( | |
| query: str, | |
| candidates: list[RetrievalResult], | |
| k: int, | |
| ) -> list[RetrievalResult]: | |
| """Re-score candidates with a cross-encoder and return top-k.""" | |
| if not candidates: | |
| return [] | |
| reranker = _load_reranker() | |
| pairs: list[list[str]] = [[query, c["text"]] for c in candidates] | |
| scores = reranker.predict(pairs) | |
| for i, candidate in enumerate(candidates): | |
| candidate["score"] = float(scores[i]) | |
| candidates.sort(key=lambda item: (-item["score"], item["chunk_id"])) | |
| return candidates[:k] | |
def _expand_query(query: str) -> list[str]:
    """Ask the chat model for up to two rephrasings of the query.

    Expansion is best-effort: it is skipped when `NOTEBOOKLM_QUERY_EXPANSION`
    is "off", when `OPENAI_API_KEY` or `NOTEBOOKLM_CHAT_MODEL` is blank, and
    whenever the request fails for any reason (the failure is logged).

    Args:
        query: Original user query.

    Returns:
        A list starting with the original query, followed by at most two
        non-empty lines from the model's reply.
    """
    expansion_flag: str = os.getenv("NOTEBOOKLM_QUERY_EXPANSION", "on").strip().lower()
    if expansion_flag == "off":
        return [query]

    api_key: str = os.getenv("OPENAI_API_KEY", "").strip()
    model: str = os.getenv("NOTEBOOKLM_CHAT_MODEL", "gpt-4o-mini").strip()
    if not (api_key and model):
        return [query]

    system_prompt: str = (
        "Generate 2 alternative phrasings of the user's search query. "
        "Return only the alternative queries, one per line, no numbering."
    )
    try:
        from openai import OpenAI

        response = OpenAI(api_key=api_key).chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query},
            ],
            temperature=0.7,
            max_tokens=150,
        )
        reply: str = (response.choices[0].message.content or "").strip()
        stripped_lines = (line.strip() for line in reply.splitlines())
        expansions: list[str] = [line for line in stripped_lines if line]
        LOGGER.info("Query expanded: %s -> %s", query, expansions)
        return [query, *expansions[:2]]
    except Exception as exc:
        # Deliberately broad: expansion must never break retrieval.
        LOGGER.warning("Query expansion failed, using original: %s", exc)
        return [query]
def _multi_query_scores(
    chunk_documents: dict[str, str],
    collection: Any,
    queries: list[str],
    n_docs: int,
) -> tuple[dict[str, float], dict[str, float]]:
    """Score every query variant and keep each chunk's best score per signal.

    Args:
        chunk_documents: Chunk text keyed by chunk id, used for BM25.
        collection: Vector collection passed to `_vector_scores`.
        queries: Query variants to score.
        n_docs: Neighbour limit passed to `_vector_scores`.

    Returns:
        `(bm25, vector)` mappings from chunk id to the maximum score seen
        across all variants, floored at 0.0. Both are empty for no queries.
    """
    best_bm25: dict[str, float] = {}
    best_vector: dict[str, float] = {}
    for variant in queries:
        # Both signals are computed (BM25 first) before either is merged.
        signal_pairs = (
            (best_bm25, _bm25_scores(chunk_documents, variant)),
            (best_vector, _vector_scores(collection, variant, n_docs)),
        )
        for best_so_far, fresh_scores in signal_pairs:
            for chunk_id, value in fresh_scores.items():
                best_so_far[chunk_id] = max(best_so_far.get(chunk_id, 0.0), value)
    return best_bm25, best_vector
def _rerank_pool_size(k: int) -> int:
    """Return how many hybrid-ranked candidates are handed to the reranker.

    Reads `NOTEBOOKLM_RERANK_TOP_N` (default 10). Unparseable or non-positive
    values fall back to the default, and the pool is never smaller than `k`,
    so reranking cannot shrink the result below the requested size.
    """
    default_size: int = 10
    raw_value: str = os.getenv("NOTEBOOKLM_RERANK_TOP_N", "").strip()
    try:
        configured: int = int(raw_value) if raw_value else default_size
    except ValueError:
        LOGGER.warning(
            "Ignoring invalid NOTEBOOKLM_RERANK_TOP_N=%r; using %d", raw_value, default_size
        )
        configured = default_size
    if configured <= 0:
        configured = default_size
    return max(k, configured)


def retrieve(
    username: str,
    notebook_id: str,
    query: str,
    k: int,
    rag_mode: str = "Reasoning",
) -> list[RetrievalResult]:
    """Retrieve top notebook chunks with hybrid scoring, query expansion, and reranking.

    Pipeline: verify the notebook, load all indexed chunks, score them with
    BM25 and vector similarity (per query variant), normalize each signal to
    `[0, 1]`, average the two, sort descending, then optionally rerank.

    Spec references:
    - `specs/04_interfaces.md`: implements `retrieve()`.
    - `specs/05_rag_and_citations.md`: BM25 retrieval, vector retrieval, merge, dedupe,
      normalize, and return top-k sorted descending.
    - `specs/07_security.md`: retrieval is scoped to one notebook owned by one user.
    - `specs/11_observability.md`: logs `user`, `notebook_id`, `action`, `duration_ms`, and `status`.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook whose collection is searched.
        query: Non-empty search query.
        k: Maximum number of results to return; must be positive.
        rag_mode: "Reasoning" enables query expansion and reranking; "Fast"
            skips both; any other value skips expansion but still reranks.

    Returns:
        Up to `k` results, best first. Empty when the notebook has no chunks.

    Raises:
        ValueError: If `query` is empty or `k` is not positive.
        RetrievalDependencyError: If retrieval dependencies are unavailable.
        RetrievalStorageError: If notebook-local retrieval data cannot be opened.
        RetrievalValidationError: If indexed metadata is malformed.
    """
    started_at: float = perf_counter()
    try:
        if not isinstance(query, str) or not query.strip():
            raise ValueError("query must be a non-empty string.")
        if k <= 0:
            raise ValueError("k must be greater than 0.")

        # Verifies notebook ownership and existence before any retrieval work.
        get_notebook(username, notebook_id)

        collection = _get_collection(username, notebook_id)
        ids, documents, metadatas = _load_collection_documents(collection)
        if not ids:
            _log_retrieval(username, notebook_id, "success", started_at)
            return []

        chunk_documents: dict[str, str] = dict(zip(ids, documents))
        chunk_metadata: dict[str, dict[str, Any]] = dict(zip(ids, metadatas))

        # Query expansion: generate alternative phrasings and merge their scores.
        queries: list[str] = _expand_query(query) if rag_mode == "Reasoning" else [query]
        bm25_raw, vector_raw = _multi_query_scores(
            chunk_documents, collection, queries, len(ids)
        )
        bm25_normalized: dict[str, float] = _normalize_scores(bm25_raw)
        vector_normalized: dict[str, float] = _normalize_scores(vector_raw)

        # Sorted ids keep candidate construction (and error reporting) deterministic.
        candidates: list[_Candidate] = []
        for chunk_id in sorted(set(bm25_raw) | set(vector_raw)):
            metadata: dict[str, Any] | None = chunk_metadata.get(chunk_id)
            text: str | None = chunk_documents.get(chunk_id)
            if metadata is None or text is None:
                raise RetrievalStorageError(f"Missing indexed content for chunk: {chunk_id}")

            source_id: Any = metadata.get("source_id")
            source_name: Any = metadata.get("source_name")
            if not isinstance(source_id, str) or not source_id.strip():
                raise RetrievalValidationError(
                    f"Indexed metadata missing valid source_id for chunk: {chunk_id}"
                )
            if not isinstance(source_name, str) or not source_name.strip():
                raise RetrievalValidationError(
                    f"Indexed metadata missing valid source_name for chunk: {chunk_id}"
                )

            candidates.append(
                {
                    "chunk_id": chunk_id,
                    "source_id": source_id.strip(),
                    "source_name": source_name.strip(),
                    "text": text,
                    "loc": _parse_loc(metadata.get("location_hints")),
                    "bm25_score": bm25_normalized.get(chunk_id, 0.0),
                    "vector_score": vector_normalized.get(chunk_id, 0.0),
                }
            )

        # Hybrid score: plain average of the two normalized signals.
        ranked_results: list[RetrievalResult] = [
            {
                "chunk_id": candidate["chunk_id"],
                "source_id": candidate["source_id"],
                "source_name": candidate["source_name"],
                "text": candidate["text"],
                "score": (candidate["bm25_score"] + candidate["vector_score"]) / 2.0,
                "loc": candidate["loc"],
            }
            for candidate in candidates
        ]
        ranked_results.sort(key=lambda item: (-item["score"], item["chunk_id"]))

        result: list[RetrievalResult]
        if rag_mode == "Fast":
            result = ranked_results[:k]
        else:
            # Rerank only a bounded pool to control latency; the pool is at
            # least `k` so reranking never returns fewer results than asked for.
            rerank_pool: list[RetrievalResult] = ranked_results[: _rerank_pool_size(k)]
            result = _rerank(query, rerank_pool, k)

        _log_retrieval(username, notebook_id, "success", started_at)
        return result
    except Exception:
        _log_retrieval(username, notebook_id, "error", started_at)
        raise