from __future__ import annotations from typing import List, Dict, Any, Optional, Tuple import chromadb class ChromaAdapter: """Wrapper around Chroma client for per-user-notebook collections. Uses a collection named `_` for isolation. """ def __init__(self, persist_directory: Optional[str] = None): # Use new Chroma API (1.5+): EphemeralClient or PersistentClient if persist_directory: self.client = chromadb.PersistentClient(path=persist_directory) else: self.client = chromadb.EphemeralClient() def _collection_name(self, user_id: str, notebook_id: str) -> str: return f"{user_id}_{notebook_id}" def get_or_create_collection(self, user_id: str, notebook_id: str): name = self._collection_name(user_id, notebook_id) try: return self.client.get_collection(name) except Exception: return self.client.create_collection(name) def upsert_chunks( self, user_id: str, notebook_id: str, chunks: List[Dict[str, Any]], embeddings: Optional[List[List[float]]] = None, ): col = self.get_or_create_collection(user_id, notebook_id) ids = [c["chunk_id"] for c in chunks] documents = [c.get("text", "") for c in chunks] metadatas = [ { "user_id": user_id, "notebook_id": notebook_id, "source_id": str(c.get("source_id", "")), "source_title": str(c.get("source_title", "")), "chunk_index": int(c.get("chunk_index", -1)), "page": int(c.get("page", -1)) if c.get("page") is not None else -1, "char_start": int(c.get("char_start", -1)) if c.get("char_start") is not None else -1, "char_end": int(c.get("char_end", -1)) if c.get("char_end") is not None else -1, "text_preview": str(c.get("text_preview", "")), } for c in chunks ] if embeddings is None: col.upsert(ids=ids, metadatas=metadatas, documents=documents) else: col.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) def delete_source(self, user_id: str, notebook_id: str, source_id: str) -> None: col = self.get_or_create_collection(user_id, notebook_id) col.delete(where={"source_id": str(source_id)}) def query( self, user_id: str, notebook_id: str, query_text: str, top_k: int = 5, source_id: Optional[str] = None, query_embedding: Optional[List[float]] = None, ) -> List[Tuple[str, float, Dict[str, Any]]]: col = self.get_or_create_collection(user_id, notebook_id) where = {"source_id": str(source_id)} if source_id is not None else None if query_embedding is None: res = col.query(query_texts=[query_text], n_results=top_k, where=where) else: res = col.query(query_embeddings=[query_embedding], n_results=top_k, where=where) ids = res.get("ids", [[]])[0] if res.get("ids") else [] documents = res.get("documents", [[]])[0] if res.get("documents") else [] metadatas = res.get("metadatas", [[]])[0] if res.get("metadatas") else [] distances = res.get("distances", [[]])[0] if res.get("distances") else [] rows: List[Tuple[str, float, Dict[str, Any]]] = [] for idx, chunk_id in enumerate(ids): doc = documents[idx] if idx < len(documents) else "" metadata = metadatas[idx] if idx < len(metadatas) and isinstance(metadatas[idx], dict) else {} distance = float(distances[idx]) if idx < len(distances) else 0.0 rows.append( ( str(chunk_id), distance, {"document": doc, "metadata": metadata}, ) ) return rows