Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import asdict, dataclass | |
| import chromadb | |
| from .config import Settings | |
| from .pdf_processing import ChunkRecord | |
@dataclass
class RetrievedChunk:
    """A single search hit produced by ``ChromaMathStore.query``.

    The decorator is required: ``ChromaMathStore.query`` builds instances
    with keyword arguments, which a plain class would reject with TypeError.
    """

    chunk_id: str  # id the chunk was stored under in the Chroma collection
    text: str  # the stored document text of the chunk
    chapter_number: int
    chapter_name: str
    topic: str
    page_number: int
    source_file: str
    # Raw distance reported by the Chroma query; presumably lower means a
    # closer match for the collection's metric — confirm against its config.
    distance: float
class ChromaMathStore:
    """Chapter-aware maths chunk store backed by one persistent Chroma collection.

    The collection is named by ``settings.collection_name``. It lives in a
    ``chromadb.PersistentClient`` rooted at ``settings.chroma_dir``.
    """

    def __init__(self, settings: Settings) -> None:
        """Open the persistent client and get (or create) the collection."""
        self.settings = settings
        self.client = chromadb.PersistentClient(path=str(settings.chroma_dir))
        self.collection = self._open_collection()

    def _open_collection(self):
        """Return the configured collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name=self.settings.collection_name,
            metadata={"description": "Class 12 maths chapter-aware RAG store"},
        )

    def reset(self) -> None:
        """Remove all stored chunks by deleting and recreating the collection.

        A failed delete is tolerated when the collection simply did not exist,
        because the recreated collection is then empty.

        Raises:
            RuntimeError: if the delete failed and the recreated collection
                still contains data, i.e. the reset did not happen.
        """
        delete_error: Exception | None = None
        try:
            self.client.delete_collection(self.settings.collection_name)
        except Exception as exc:
            # The exception type for "collection does not exist" differs
            # between chromadb releases, so catch broadly here and decide
            # below whether the failure actually matters.
            delete_error = exc
        self.collection = self._open_collection()
        if delete_error is not None and self.collection.count() > 0:
            raise RuntimeError(
                f"Could not reset Chroma collection {self.settings.collection_name!r}"
            ) from delete_error

    def add_chunks(self, chunks: list[ChunkRecord], embeddings: list[list[float]]) -> None:
        """Store *chunks* with their embeddings; ``embeddings[i]`` belongs to ``chunks[i]``.

        An empty batch is a no-op.

        Raises:
            ValueError: if the number of chunks and embeddings differ.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings; "
                "they must match one-to-one."
            )
        if not chunks:
            # Chroma rejects an add() with no ids, so skip the call entirely.
            return
        self.collection.add(
            ids=[chunk.chunk_id for chunk in chunks],
            documents=[chunk.text for chunk in chunks],
            embeddings=embeddings,
            metadatas=[
                {
                    "chapter_number": chunk.chapter_number,
                    "chapter_name": chunk.chapter_name,
                    "topic": chunk.topic,
                    "page_number": chunk.page_number,
                    "source_file": chunk.source_file,
                }
                for chunk in chunks
            ],
        )

    def count(self) -> int:
        """Return the number of items currently stored in the collection."""
        return self.collection.count()

    def query(self, query_embedding: list[float], top_k: int) -> list[RetrievedChunk]:
        """Return up to *top_k* stored chunks nearest to *query_embedding*.

        Returns an empty list when *top_k* is not positive or the collection
        is empty, instead of passing an invalid ``n_results`` to Chroma.
        Hits keep the order in which Chroma returns them.
        """
        if top_k <= 0:
            return []
        available = self.collection.count()
        if available == 0:
            return []
        result = self.collection.query(
            query_embeddings=[query_embedding],
            # Never request more neighbours than the collection holds.
            n_results=min(top_k, available),
        )
        # One inner list per query embedding; exactly one embedding was sent.
        # `or [[]]` also guards against a key that is present but None.
        ids = (result.get("ids") or [[]])[0]
        documents = (result.get("documents") or [[]])[0]
        metadatas = (result.get("metadatas") or [[]])[0]
        distances = (result.get("distances") or [[]])[0]
        return [
            RetrievedChunk(
                chunk_id=chunk_id,
                text=text,
                chapter_number=int(metadata["chapter_number"]),
                chapter_name=str(metadata["chapter_name"]),
                topic=str(metadata["topic"]),
                page_number=int(metadata["page_number"]),
                source_file=str(metadata["source_file"]),
                distance=float(distance),
            )
            for chunk_id, text, metadata, distance in zip(ids, documents, metadatas, distances)
        ]