# genAI-Project/src/retrieval/vector_store.py
# Last commit by OGB2000: bf77be6 ("Initial clean deployment")
"""
src/retrieval/vector_store.py
------------------------------
Extended ChromaStore with metadata filtering for scientific papers.
"""
from pathlib import Path
from typing import Optional
class ScientificChromaStore:
    """Persistent ChromaDB collection of paper chunks with metadata-filtered search.

    The collection is opened (or created) with ``{"hnsw:space": "cosine"}``.
    Chunk metadata keys read by this class are ``paper_id``, ``title``,
    ``authors_str``, ``year`` and ``section``.
    """

    def __init__(self, persist_dir="data/chroma_scientific", collection="scientific_corpus"):
        """Open the on-disk store, creating the directory and collection if needed.

        Args:
            persist_dir: Directory ChromaDB persists to; created if missing.
            collection: Name of the collection to get or create.
        """
        # Imported here rather than at module level (presumably so this module
        # can be imported without chromadb installed).
        import chromadb

        Path(persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.col = self.client.get_or_create_collection(
            name=collection, metadata={"hnsw:space": "cosine"}
        )

    def index_chunks(self, chunks, embeddings, *, ids=None, batch_size=500):
        """Upsert chunks and their embeddings into the collection in batches.

        Args:
            chunks: Iterable of objects exposing ``.text`` and ``.metadata``.
            embeddings: One embedding per chunk; anything with ``.tolist()``
                (e.g. a NumPy array) is converted to nested lists first.
            ids: Optional explicit IDs, one per chunk. Defaults to
                ``chunk_0 .. chunk_{n-1}``. Those defaults are positional, so a
                later call with a different chunk list overwrites the entries
                stored at the same positions.
            batch_size: Maximum number of records sent per ``upsert`` call.

        Raises:
            ValueError: If the number of ids or embeddings differs from the
                number of chunks, or ``batch_size`` is not positive.
        """
        chunks = list(chunks)
        n = len(chunks)
        ids = [f"chunk_{i}" for i in range(n)] if ids is None else list(ids)
        if len(ids) != n:
            raise ValueError(f"got {len(ids)} ids for {n} chunks")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        emb_list = embeddings.tolist() if hasattr(embeddings, "tolist") else list(embeddings)
        if len(emb_list) != n:
            raise ValueError(f"got {len(emb_list)} embeddings for {n} chunks")

        documents = [c.text for c in chunks]
        metadatas = [c.metadata for c in chunks]
        for start in range(0, n, batch_size):
            end = start + batch_size
            self.col.upsert(
                ids=ids[start:end],
                documents=documents[start:end],
                embeddings=emb_list[start:end],
                metadatas=metadatas[start:end],
            )
        print(f"✅ Indexed {len(ids)} chunks. Total: {self.col.count()}")

    def query(self, query_embedding, k=5, where=None):
        """Return up to ``k`` nearest chunks to ``query_embedding``.

        Args:
            query_embedding: A single embedding vector (list, or array-like
                with ``.tolist()``).
            k: Maximum number of results; capped at the collection size.
            where: Optional Chroma metadata filter; falsy values mean no filter.

        Returns:
            Dict with parallel lists ``documents``, ``metadatas`` and
            ``distances`` for this single query. All three are empty when the
            collection is empty or ``k`` is not positive.
        """
        n_results = min(k, self.col.count())
        if n_results <= 0:
            # Chroma raises for n_results < 1, so answer "no hits" directly.
            return {"documents": [], "metadatas": [], "distances": []}
        if hasattr(query_embedding, "tolist"):
            query_embedding = query_embedding.tolist()
        kwargs = dict(query_embeddings=[query_embedding], n_results=n_results)
        if where:
            kwargs["where"] = where
        res = self.col.query(**kwargs)
        # One query was sent, so take element [0] of each per-query result list.
        return {
            "documents": res["documents"][0],
            "metadatas": res["metadatas"][0],
            "distances": res["distances"][0],
        }

    @staticmethod
    def _as_list(values):
        """Normalise a filter value to a list; a bare string becomes a one-item list."""
        if isinstance(values, str):
            return [values]
        return list(values)

    def query_with_metadata_filter(self, query_embedding, k=5, year_min=None, sections=None, paper_ids=None):
        """Nearest-neighbour search restricted by optional metadata constraints.

        Args:
            query_embedding: Query vector, passed through to :meth:`query`.
            k: Maximum number of results.
            year_min: If not ``None``, keep chunks whose ``year`` is >= this.
            sections: Section name or collection of names; if non-empty, keep
                chunks whose ``section`` is one of them.
            paper_ids: Paper id or collection of ids; if non-empty, keep
                chunks whose ``paper_id`` is one of them.

        Returns:
            The same dict shape as :meth:`query`.
        """
        where_clauses = []
        if year_min is not None:
            where_clauses.append({"year": {"$gte": year_min}})
        if sections:
            where_clauses.append({"section": {"$in": self._as_list(sections)}})
        if paper_ids:
            where_clauses.append({"paper_id": {"$in": self._as_list(paper_ids)}})

        # A single clause is passed bare; "$and" is only used to combine two or more.
        where = None
        if len(where_clauses) == 1:
            where = where_clauses[0]
        elif len(where_clauses) > 1:
            where = {"$and": where_clauses}
        return self.query(query_embedding, k=k, where=where)

    def count(self):
        """Return the number of records in the collection."""
        return self.col.count()

    @staticmethod
    def _year_sort_key(paper):
        """Sort key: numeric years rank above missing/non-numeric ones, then by value."""
        try:
            return (1, float(paper["year"]))
        except (TypeError, ValueError):
            return (0, 0.0)

    def list_papers(self):
        """Return one summary dict per distinct ``paper_id`` in the collection.

        The first chunk seen for each paper supplies its ``title``,
        ``authors_str`` and ``year``; missing keys default to ``""``. Chunks
        with no metadata, or no ``paper_id``, are grouped under
        ``paper_id == ""``.

        Returns:
            List of dicts with keys ``paper_id``, ``title``, ``authors_str``
            and ``year``. Papers are ordered newest year first. Papers whose
            year is missing or non-numeric come last, in first-seen order.
        """
        all_meta = self.col.get(include=["metadatas"])["metadatas"] or []
        papers = {}  # paper_id -> summary; dicts keep first-seen order
        for meta in all_meta:
            meta = meta or {}  # Chroma may return None for records stored without metadata
            pid = meta.get("paper_id", "")
            if pid not in papers:
                papers[pid] = {
                    "paper_id": pid,
                    "title": meta.get("title", ""),
                    "authors_str": meta.get("authors_str", ""),
                    "year": meta.get("year", ""),
                }
        # The key never mixes int and str, so papers lacking a year cannot
        # raise TypeError. reverse=True keeps ties in first-seen order.
        return sorted(papers.values(), key=self._year_sort_key, reverse=True)
# Alias: code that refers to ``ChromaStore`` (the name the module docstring
# says this class extends) gets ScientificChromaStore. Presumably kept for
# existing callers of that name — confirm before removing.
ChromaStore = ScientificChromaStore