""" src/retrieval/vector_store.py ------------------------------ Extended ChromaStore with metadata filtering for scientific papers. """ from pathlib import Path from typing import Optional class ScientificChromaStore: def __init__(self, persist_dir="data/chroma_scientific", collection="scientific_corpus"): import chromadb Path(persist_dir).mkdir(parents=True, exist_ok=True) self.client = chromadb.PersistentClient(path=persist_dir) self.col = self.client.get_or_create_collection( name=collection, metadata={"hnsw:space": "cosine"} ) def index_chunks(self, chunks, embeddings): ids = [f"chunk_{i}" for i in range(len(chunks))] documents = [c.text for c in chunks] metadatas = [c.metadata for c in chunks] emb_list = embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings batch_size = 500 for start in range(0, len(ids), batch_size): end = start + batch_size self.col.upsert( ids=ids[start:end], documents=documents[start:end], embeddings=emb_list[start:end], metadatas=metadatas[start:end], ) print(f"✅ Indexed {len(ids)} chunks. Total: {self.col.count()}") def query(self, query_embedding, k=5, where=None): kwargs = dict(query_embeddings=[query_embedding], n_results=min(k, self.col.count())) if where: kwargs["where"] = where res = self.col.query(**kwargs) return {"documents": res["documents"][0], "metadatas": res["metadatas"][0], "distances": res["distances"][0]} def query_with_metadata_filter(self, query_embedding, k=5, year_min=None, sections=None, paper_ids=None): where_clauses = [] if year_min: where_clauses.append({"year": {"$gte": year_min}}) if sections: where_clauses.append({"section": {"$in": sections}}) if paper_ids: where_clauses.append({"paper_id": {"$in": paper_ids}}) where = None if len(where_clauses) == 1: where = where_clauses[0] elif len(where_clauses) > 1: where = {"$and": where_clauses} return self.query(query_embedding, k=k, where=where) def count(self): return self.col.count() def list_papers(self): all_meta = self.col.get(include=["metadatas"])["metadatas"] seen = set() papers = [] for m in all_meta: pid = m.get("paper_id", "") if pid not in seen: seen.add(pid) papers.append({"paper_id": pid, "title": m.get("title", ""), "authors_str": m.get("authors_str", ""), "year": m.get("year", "")}) return sorted(papers, key=lambda x: x["year"], reverse=True) ChromaStore = ScientificChromaStore