from __future__ import annotations from dataclasses import asdict, dataclass import chromadb from .config import Settings from .pdf_processing import ChunkRecord @dataclass class RetrievedChunk: chunk_id: str text: str chapter_number: int chapter_name: str topic: str page_number: int source_file: str distance: float class ChromaMathStore: def __init__(self, settings: Settings) -> None: self.settings = settings self.client = chromadb.PersistentClient(path=str(settings.chroma_dir)) self.collection = self.client.get_or_create_collection( name=settings.collection_name, metadata={"description": "Class 12 maths chapter-aware RAG store"}, ) def reset(self) -> None: try: self.client.delete_collection(self.settings.collection_name) except Exception: pass self.collection = self.client.get_or_create_collection( name=self.settings.collection_name, metadata={"description": "Class 12 maths chapter-aware RAG store"}, ) def add_chunks(self, chunks: list[ChunkRecord], embeddings: list[list[float]]) -> None: self.collection.add( ids=[chunk.chunk_id for chunk in chunks], documents=[chunk.text for chunk in chunks], embeddings=embeddings, metadatas=[ { "chapter_number": chunk.chapter_number, "chapter_name": chunk.chapter_name, "topic": chunk.topic, "page_number": chunk.page_number, "source_file": chunk.source_file, } for chunk in chunks ], ) def count(self) -> int: return self.collection.count() def query(self, query_embedding: list[float], top_k: int) -> list[RetrievedChunk]: result = self.collection.query( query_embeddings=[query_embedding], n_results=top_k, ) documents = result.get("documents", [[]])[0] ids = result.get("ids", [[]])[0] metadatas = result.get("metadatas", [[]])[0] distances = result.get("distances", [[]])[0] retrieved: list[RetrievedChunk] = [] for chunk_id, text, metadata, distance in zip(ids, documents, metadatas, distances): retrieved.append( RetrievedChunk( chunk_id=chunk_id, text=text, chapter_number=int(metadata["chapter_number"]), chapter_name=str(metadata["chapter_name"]), topic=str(metadata["topic"]), page_number=int(metadata["page_number"]), source_file=str(metadata["source_file"]), distance=float(distance), ) ) return retrieved