NotebookLMClone / src /ingestion /vectorstore.py
github-actions[bot]
Sync from GitHub e48aa5f27523b35a22c1a01acbb2b835cdc28984
aacd162
from __future__ import annotations
from typing import List, Dict, Any, Optional, Tuple
import chromadb
class ChromaAdapter:
"""Wrapper around Chroma client for per-user-notebook collections.
Uses a collection named `<user_id>_<notebook_id>` for isolation.
"""
def __init__(self, persist_directory: Optional[str] = None):
# Use new Chroma API (1.5+): EphemeralClient or PersistentClient
if persist_directory:
self.client = chromadb.PersistentClient(path=persist_directory)
else:
self.client = chromadb.EphemeralClient()
def _collection_name(self, user_id: str, notebook_id: str) -> str:
return f"{user_id}_{notebook_id}"
def get_or_create_collection(self, user_id: str, notebook_id: str):
name = self._collection_name(user_id, notebook_id)
try:
return self.client.get_collection(name)
except Exception:
return self.client.create_collection(name)
def upsert_chunks(
self,
user_id: str,
notebook_id: str,
chunks: List[Dict[str, Any]],
embeddings: Optional[List[List[float]]] = None,
):
col = self.get_or_create_collection(user_id, notebook_id)
ids = [c["chunk_id"] for c in chunks]
documents = [c.get("text", "") for c in chunks]
metadatas = [
{
"user_id": user_id,
"notebook_id": notebook_id,
"source_id": str(c.get("source_id", "")),
"source_title": str(c.get("source_title", "")),
"chunk_index": int(c.get("chunk_index", -1)),
"page": int(c.get("page", -1)) if c.get("page") is not None else -1,
"char_start": int(c.get("char_start", -1)) if c.get("char_start") is not None else -1,
"char_end": int(c.get("char_end", -1)) if c.get("char_end") is not None else -1,
"text_preview": str(c.get("text_preview", "")),
}
for c in chunks
]
if embeddings is None:
col.upsert(ids=ids, metadatas=metadatas, documents=documents)
else:
col.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
def delete_source(self, user_id: str, notebook_id: str, source_id: str) -> None:
col = self.get_or_create_collection(user_id, notebook_id)
col.delete(where={"source_id": str(source_id)})
def query(
self,
user_id: str,
notebook_id: str,
query_text: str,
top_k: int = 5,
source_id: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
) -> List[Tuple[str, float, Dict[str, Any]]]:
col = self.get_or_create_collection(user_id, notebook_id)
where = {"source_id": str(source_id)} if source_id is not None else None
if query_embedding is None:
res = col.query(query_texts=[query_text], n_results=top_k, where=where)
else:
res = col.query(query_embeddings=[query_embedding], n_results=top_k, where=where)
ids = res.get("ids", [[]])[0] if res.get("ids") else []
documents = res.get("documents", [[]])[0] if res.get("documents") else []
metadatas = res.get("metadatas", [[]])[0] if res.get("metadatas") else []
distances = res.get("distances", [[]])[0] if res.get("distances") else []
rows: List[Tuple[str, float, Dict[str, Any]]] = []
for idx, chunk_id in enumerate(ids):
doc = documents[idx] if idx < len(documents) else ""
metadata = metadatas[idx] if idx < len(metadatas) and isinstance(metadatas[idx], dict) else {}
distance = float(distances[idx]) if idx < len(distances) else 0.0
rows.append(
(
str(chunk_id),
distance,
{"document": doc, "metadata": metadata},
)
)
return rows