Spaces:
Sleeping
Sleeping
| """ | |
| src/retrieval/vector_store.py | |
| ------------------------------ | |
| Extended ChromaStore with metadata filtering for scientific papers. | |
| """ | |
| from pathlib import Path | |
| from typing import Optional | |
class ScientificChromaStore:
    """Persistent Chroma collection of scientific-paper chunks with metadata filters.

    Wraps a ``chromadb.PersistentClient`` collection created with cosine HNSW
    space. It adds:

    * batched upserts of chunks;
    * nearest-neighbour queries filtered by ``year``, ``section`` and
      ``paper_id`` metadata;
    * a per-paper listing built from chunk metadata.
    """

    def __init__(self, persist_dir="data/chroma_scientific", collection="scientific_corpus"):
        """Open (or create) the on-disk collection.

        Args:
            persist_dir: Directory used for Chroma's persistent storage; created
                (with parents) if it does not exist.
            collection: Name of the collection to open or create.
        """
        # chromadb is only imported when a store is actually constructed.
        import chromadb

        Path(persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=str(persist_dir))
        self.col = self.client.get_or_create_collection(
            name=collection, metadata={"hnsw:space": "cosine"}
        )

    def index_chunks(self, chunks, embeddings, ids=None, batch_size=500):
        """Upsert chunks and their embeddings into the collection in batches.

        Args:
            chunks: Sequence of objects exposing ``.text`` and ``.metadata``.
            embeddings: One embedding per chunk. Anything with ``.tolist()``
                (e.g. a NumPy array) is converted to nested lists first.
            ids: Optional explicit ids, one per chunk. Defaults to
                ``chunk_0 .. chunk_{n-1}``. Those defaults restart at 0 on every
                call, so a second call with a different corpus overwrites the
                first one's records unless distinct ids are supplied here.
            batch_size: Maximum number of records sent per ``upsert`` call.

        Raises:
            ValueError: If ``ids`` or ``embeddings`` differ in length from
                ``chunks``, or ``batch_size`` is not positive.
        """
        n_chunks = len(chunks)
        ids = [f"chunk_{i}" for i in range(n_chunks)] if ids is None else list(ids)
        emb_list = embeddings.tolist() if hasattr(embeddings, "tolist") else list(embeddings)

        # Fail early with a clear message instead of partway through the batches.
        if len(ids) != n_chunks:
            raise ValueError(f"Got {len(ids)} ids for {n_chunks} chunks")
        if len(emb_list) != n_chunks:
            raise ValueError(f"Got {len(emb_list)} embeddings for {n_chunks} chunks")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        documents = [c.text for c in chunks]
        metadatas = [c.metadata for c in chunks]
        for start in range(0, n_chunks, batch_size):
            end = start + batch_size
            self.col.upsert(
                ids=ids[start:end],
                documents=documents[start:end],
                embeddings=emb_list[start:end],
                metadatas=metadatas[start:end],
            )
        print(f"✅ Indexed {len(ids)} chunks. Total: {self.col.count()}")

    def query(self, query_embedding, k=5, where=None):
        """Return up to ``k`` chunks nearest to ``query_embedding``.

        Args:
            query_embedding: A single embedding vector (a list, or anything with
                ``.tolist()``).
            k: Maximum number of results; capped at the collection size.
            where: Optional Chroma metadata filter. ``None`` or an empty dict
                means no filter.

        Returns:
            Dict with parallel lists under ``"documents"``, ``"metadatas"`` and
            ``"distances"`` for this single query. All three lists are empty
            when the collection is empty or ``k`` is not positive.
        """
        n_results = min(k, self.col.count())
        if n_results <= 0:
            # Chroma rejects a non-positive n_results, so short-circuit instead.
            return {"documents": [], "metadatas": [], "distances": []}

        # Match index_chunks, which also converts array-likes to plain lists.
        if hasattr(query_embedding, "tolist"):
            query_embedding = query_embedding.tolist()

        kwargs = dict(query_embeddings=[query_embedding], n_results=n_results)
        if where:
            kwargs["where"] = where
        res = self.col.query(**kwargs)
        # Chroma returns one result list per query embedding; only one was sent.
        return {
            "documents": res["documents"][0],
            "metadatas": res["metadatas"][0],
            "distances": res["distances"][0],
        }

    def query_with_metadata_filter(self, query_embedding, k=5, year_min=None, sections=None, paper_ids=None):
        """Query with optional filters on ``year``, ``section`` and ``paper_id`` metadata.

        Args:
            query_embedding: The query vector, forwarded to :meth:`query`.
            k: Maximum number of results.
            year_min: If not ``None``, keep only chunks with ``year >= year_min``.
            sections: A section name or an iterable of names; keep chunks whose
                ``section`` is among them. ``None`` or empty means no section
                filter.
            paper_ids: A paper id or an iterable of ids; keep chunks whose
                ``paper_id`` is among them. ``None`` or empty means no paper
                filter.

        Returns:
            The same dict shape as :meth:`query`.
        """
        where_clauses = []
        if year_min is not None:
            where_clauses.append({"year": {"$gte": year_min}})
        section_list = self._as_list(sections)
        if section_list:
            where_clauses.append({"section": {"$in": section_list}})
        paper_list = self._as_list(paper_ids)
        if paper_list:
            where_clauses.append({"paper_id": {"$in": paper_list}})

        if not where_clauses:
            where = None
        elif len(where_clauses) == 1:
            where = where_clauses[0]
        else:
            # `$and` is only needed to combine two or more clauses.
            where = {"$and": where_clauses}
        return self.query(query_embedding, k=k, where=where)

    @staticmethod
    def _as_list(values):
        """Normalize a filter argument to a list.

        ``None`` and ``""`` become ``[]``, a single string becomes ``[string]``,
        and any other iterable goes through ``list()``.
        """
        if values is None:
            return []
        if isinstance(values, str):
            # A bare string must not be exploded into characters for `$in`.
            return [values] if values else []
        return list(values)

    def count(self):
        """Return the number of records in the collection."""
        return self.col.count()

    def list_papers(self):
        """Return one summary dict per distinct ``paper_id``, newest year first.

        Each summary comes from the first chunk seen for that paper. Its keys
        are ``paper_id``, ``title``, ``authors_str`` and ``year``; missing
        fields default to ``""``. Papers whose year is missing or non-numeric
        sort after all dated papers and keep their original relative order.
        """
        all_meta = self.col.get(include=["metadatas"])["metadatas"] or []
        seen = set()
        papers = []
        for meta in all_meta:
            meta = meta or {}  # tolerate chunks stored without metadata (None entries)
            pid = meta.get("paper_id", "")
            if pid in seen:
                continue
            seen.add(pid)
            papers.append({
                "paper_id": pid,
                "title": meta.get("title", ""),
                "authors_str": meta.get("authors_str", ""),
                "year": meta.get("year", ""),
            })
        return sorted(papers, key=lambda p: self._year_sort_key(p["year"]), reverse=True)

    @staticmethod
    def _year_sort_key(year):
        """Sort key for a metadata year.

        Numeric years compare numerically and rank above missing or non-numeric
        ones. This avoids the ``TypeError`` raised when comparing ``int`` with
        the ``""`` default.
        """
        try:
            return (1, float(year))
        except (TypeError, ValueError):
            return (0, 0.0)


# Alias under the shorter name; presumably kept so existing `ChromaStore` imports keep working.
ChromaStore = ScientificChromaStore