# voiceAI/src/modules/rag_engine.py
# NOTE: Hugging Face Space page residue converted to comments
# (author: ahanbose, commit ea38730 "Update src/modules/rag_engine.py").
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
logger = logging.getLogger(__name__)
@dataclass
class RAGConfig:
    """Tunable parameters for the RAG pipeline (chunking, embedding, retrieval)."""

    # Maximum characters per chunk produced by the text splitter.
    chunk_size: int = 1000
    # Characters shared between consecutive chunks to preserve context at boundaries.
    chunk_overlap: int = 100
    # HuggingFace sentence-transformers model used to embed chunks and queries.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Device string passed to the embedding model (e.g. "cpu" or "cuda").
    embedding_device: str = "cpu"
    # Default number of chunks returned by RAGEngine.retrieve().
    top_k: int = 4
    # Whether the encoder L2-normalizes embeddings (cosine-similarity-friendly).
    normalize_embeddings: bool = True
@dataclass
class RetrievedContext:
    """Result of one retrieval call: the query, its matching chunks, and the formatted context."""

    query: str
    chunks: List[Document]
    combined_text: str

    @property
    def source_files(self) -> List[str]:
        """Distinct source filenames of the retrieved chunks, in first-seen order."""
        names = (doc.metadata.get("source", "unknown") for doc in self.chunks)
        # dict.fromkeys deduplicates while preserving insertion order.
        return list(dict.fromkeys(names))

    @property
    def chunk_count(self) -> int:
        """Number of chunks retrieved for this query."""
        return len(self.chunks)
class RAGEngine:
    """FAISS-backed retrieval engine: chunks documents, embeds them, retrieves top-k context.

    Typical flow: ``build_index()`` (or ``add_documents()``) once, then
    ``retrieve()`` per query.
    """

    def __init__(self, config: Optional[RAGConfig] = None) -> None:
        self.config = config or RAGConfig()
        self._vector_store: Optional[FAISS] = None
        # Embedding model is created lazily — loading it is expensive.
        self._embeddings: Optional[HuggingFaceEmbeddings] = None
        # Running counter so chunk_id stays unique across repeated
        # add_documents() calls (enumerate() from 0 would collide).
        self._next_chunk_id = 0
        self._splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""],
            add_start_index=True,
        )

    def build_index(self, documents: List[Document]) -> int:
        """Build a fresh FAISS index from *documents*.

        Replaces any existing index and restarts chunk numbering.
        Returns the number of chunks indexed.
        Raises ValueError if *documents* is empty.
        """
        if not documents:
            raise ValueError("Cannot build index from an empty document list.")
        self._next_chunk_id = 0  # fresh index -> restart chunk numbering
        chunks = self._chunk(documents)
        self._vector_store = FAISS.from_documents(chunks, self._get_embeddings())
        logger.info("Built FAISS index with %d chunks.", len(chunks))
        return len(chunks)

    def retrieve(self, query: str, top_k: Optional[int] = None) -> RetrievedContext:
        """Return the *top_k* chunks most similar to *query* as a RetrievedContext.

        Falls back to ``config.top_k`` when *top_k* is None.
        Raises RuntimeError if no index has been built yet.
        """
        if self._vector_store is None:
            raise RuntimeError("Call build_index() before retrieve().")
        # Explicit None check: `top_k or default` would silently swallow top_k=0.
        k = self.config.top_k if top_k is None else top_k
        docs = self._vector_store.similarity_search(query, k=k)
        return RetrievedContext(query=query, chunks=docs, combined_text=self._format_context(docs))

    def add_documents(self, documents: List[Document]) -> int:
        """Chunk *documents* and add them to the index, building it if absent.

        Returns the number of chunks added.
        """
        if self._vector_store is None:
            return self.build_index(documents)
        chunks = self._chunk(documents)
        # FAISS.add_documents reuses the store's own embedding function;
        # the previous `embedding=` kwarg was silently ignored.
        self._vector_store.add_documents(chunks)
        return len(chunks)

    def reset(self) -> None:
        """Drop the current index; the embedding model stays cached for reuse."""
        self._vector_store = None
        self._next_chunk_id = 0

    def _chunk(self, documents: List[Document]) -> List[Document]:
        """Split *documents* and tag each chunk with a globally unique chunk_id."""
        chunks = self._splitter.split_documents(documents)
        for chunk in chunks:
            chunk.metadata["chunk_id"] = self._next_chunk_id
            self._next_chunk_id += 1
        return chunks

    def _get_embeddings(self) -> HuggingFaceEmbeddings:
        """Lazily create and cache the HuggingFace embedding model."""
        if self._embeddings is None:
            self._embeddings = HuggingFaceEmbeddings(
                model_name=self.config.embedding_model,
                model_kwargs={"device": self.config.embedding_device},
                encode_kwargs={"normalize_embeddings": self.config.normalize_embeddings},
            )
        return self._embeddings

    @staticmethod
    def _format_context(docs: List[Document]) -> str:
        """Join chunks into one prompt-ready string, each under a provenance header."""
        separator = "─" * 60
        parts = []
        for i, doc in enumerate(docs, start=1):
            meta = doc.metadata
            header = (
                f"[CHUNK {i} | source: {meta.get('source', 'unknown')}, "
                f"page: {meta.get('page', 'N/A')}, "
                f"chunk_id: {meta.get('chunk_id', 'N/A')}]"
            )
            parts.append(f"{header}\n{doc.page_content.strip()}")
        return f"\n{separator}\n".join(parts)