# math-chatbot-v2: src/edurag_math_bot/vector_store.py
# ChromaDB-backed vector store for the chapter-aware Class 12 maths RAG bot.
# Last commit: 7fab45b by pranshu dhiman ("Deploy MathSutra Space").
from __future__ import annotations

import logging
from dataclasses import asdict, dataclass

import chromadb

from .config import Settings
from .pdf_processing import ChunkRecord
@dataclass
class RetrievedChunk:
    """One search hit returned by ``ChromaMathStore.query``.

    Carries the chunk text, the metadata fields written for the chunk at
    ingest time, and the distance the vector store reported for the hit.
    """

    chunk_id: str  # id the chunk was stored under in the collection
    text: str  # stored document text of the chunk
    chapter_number: int
    chapter_name: str
    topic: str
    page_number: int  # page number recorded for the chunk at ingest time
    source_file: str  # source file name recorded for the chunk at ingest time
    distance: float  # distance to the query embedding; smaller means closer
class ChromaMathStore:
    """Persistent ChromaDB collection holding embedded maths-textbook chunks.

    Embeddings are computed by the caller and passed in. This class stores
    them with each chunk's text and chapter/page metadata, and runs
    nearest-neighbour queries against them.
    """

    # Metadata attached to the collection when it is created.
    _COLLECTION_METADATA = {"description": "Class 12 maths chapter-aware RAG store"}

    # chromadb rejects a single add() larger than the client's maximum batch
    # size, so ingestion is split into batches of at most this many rows.
    DEFAULT_ADD_BATCH_SIZE = 1000

    def __init__(self, settings: Settings) -> None:
        """Open (creating if needed) the collection configured in *settings*.

        Args:
            settings: Supplies ``chroma_dir`` (location of the persistent
                client) and ``collection_name``.
        """
        self.settings = settings
        self.client = chromadb.PersistentClient(path=str(settings.chroma_dir))
        self.collection = self._open_collection()

    def _open_collection(self):
        """Return the configured collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name=self.settings.collection_name,
            # Pass a copy so the shared class-level dict can never be mutated.
            metadata=dict(self._COLLECTION_METADATA),
        )

    def reset(self) -> None:
        """Drop the collection (if present) and recreate it empty."""
        name = self.settings.collection_name
        try:
            self.client.delete_collection(name)
        except Exception as exc:  # noqa: BLE001 - error type differs across chromadb versions
            # Best effort: the collection may simply not exist yet. Record the
            # reason instead of swallowing it silently, then recreate anyway.
            logging.getLogger(__name__).debug(
                "Could not delete collection %r before reset: %s", name, exc
            )
        self.collection = self._open_collection()

    @staticmethod
    def _chunk_metadata(chunk: ChunkRecord) -> dict[str, object]:
        """Build the metadata stored with a chunk (read back by ``query``)."""
        return {
            "chapter_number": chunk.chapter_number,
            "chapter_name": chunk.chapter_name,
            "topic": chunk.topic,
            "page_number": chunk.page_number,
            "source_file": chunk.source_file,
        }

    def add_chunks(
        self,
        chunks: list[ChunkRecord],
        embeddings: list[list[float]],
        *,
        batch_size: int = DEFAULT_ADD_BATCH_SIZE,
    ) -> None:
        """Insert *chunks* together with their precomputed *embeddings*.

        Args:
            chunks: Records to store; each ``chunk_id`` becomes the Chroma id.
            embeddings: One vector per chunk, in the same order as *chunks*.
            batch_size: Maximum number of rows sent per ``collection.add`` call.

        Raises:
            ValueError: If *chunks* and *embeddings* differ in length, or
                *batch_size* is smaller than 1.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings; "
                "they must correspond one-to-one."
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # With no chunks the range is empty, so nothing is sent. chromadb
        # rejects an add() with an empty id list.
        for start in range(0, len(chunks), batch_size):
            stop = start + batch_size
            batch = chunks[start:stop]
            self.collection.add(
                ids=[chunk.chunk_id for chunk in batch],
                documents=[chunk.text for chunk in batch],
                embeddings=embeddings[start:stop],
                metadatas=[self._chunk_metadata(chunk) for chunk in batch],
            )

    def count(self) -> int:
        """Return the number of chunks currently stored in the collection."""
        return self.collection.count()

    @staticmethod
    def _first_row(result, key: str) -> list:
        """Return ``result[key][0]``, or ``[]`` if the field is missing, None or empty."""
        rows = result.get(key)
        if not rows or rows[0] is None:
            return []
        return list(rows[0])

    def query(self, query_embedding: list[float], top_k: int) -> list[RetrievedChunk]:
        """Return up to *top_k* stored chunks nearest to *query_embedding*.

        Hits keep the order chromadb returns them in (nearest first). A
        non-positive *top_k* returns an empty list without querying.
        """
        if top_k <= 0:
            return []
        result = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )
        # Exactly one query embedding was sent, so each field has one row.
        ids = self._first_row(result, "ids")
        documents = self._first_row(result, "documents")
        metadatas = self._first_row(result, "metadatas")
        distances = self._first_row(result, "distances")
        return [
            RetrievedChunk(
                chunk_id=chunk_id,
                text=text,
                chapter_number=int(metadata["chapter_number"]),
                chapter_name=str(metadata["chapter_name"]),
                topic=str(metadata["topic"]),
                page_number=int(metadata["page_number"]),
                source_file=str(metadata["source_file"]),
                distance=float(distance),
            )
            for chunk_id, text, metadata, distance in zip(ids, documents, metadatas, distances)
        ]