from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List, Optional

from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS

logger = logging.getLogger(__name__)


@dataclass
class RAGConfig:
    """Tunable settings for :class:`RAGEngine`."""

    # Maximum chunk length handed to RecursiveCharacterTextSplitter.
    chunk_size: int = 1000
    # Overlap between consecutive chunks, handed to the splitter.
    chunk_overlap: int = 100
    # Model name used to construct HuggingFaceEmbeddings.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Forwarded to the embedding model as model_kwargs["device"].
    embedding_device: str = "cpu"
    # Default number of chunks returned by RAGEngine.retrieve().
    top_k: int = 4
    # Forwarded to the embedding model as encode_kwargs["normalize_embeddings"].
    normalize_embeddings: bool = True


@dataclass
class RetrievedContext:
    """Result of a single :meth:`RAGEngine.retrieve` call."""

    # The query string that was searched for.
    query: str
    # Retrieved chunks, in the order the vector store returned them.
    chunks: List[Document]
    # The chunks rendered into a single string by RAGEngine._format_context.
    combined_text: str

    @property
    def source_files(self) -> List[str]:
        """Return the distinct ``source`` metadata values in first-seen order.

        Chunks with no ``source`` metadata are reported as ``"unknown"``.
        """
        # dict keys keep insertion order, so this de-duplicates stably.
        return list(
            dict.fromkeys(
                doc.metadata.get("source", "unknown") for doc in self.chunks
            )
        )

    @property
    def chunk_count(self) -> int:
        """Return the number of retrieved chunks."""
        return len(self.chunks)


class RAGEngine:
    """Chunk documents, embed them into a FAISS index and retrieve by similarity.

    Typical use: ``build_index(docs)`` once, optionally ``add_documents(more)``,
    then ``retrieve(query)`` as often as needed. ``reset()`` drops the index.
    """

    def __init__(self, config: Optional[RAGConfig] = None) -> None:
        """Create an engine; no embeddings are loaded until they are needed.

        Args:
            config: Engine settings; a default :class:`RAGConfig` when omitted.
        """
        self.config = config or RAGConfig()
        self._vector_store: Optional[FAISS] = None
        # Created lazily by _get_embeddings() so construction stays cheap.
        self._embeddings: Optional[HuggingFaceEmbeddings] = None
        # Next chunk_id to hand out; keeps ids unique across add_documents().
        self._next_chunk_id: int = 0
        self._splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""],
            # Records each chunk's character offset in metadata["start_index"].
            add_start_index=True,
        )

    def build_index(self, documents: List[Document]) -> int:
        """Build a fresh FAISS index from *documents*, replacing any existing one.

        Args:
            documents: Source documents to chunk and embed.

        Returns:
            The number of chunks indexed.

        Raises:
            ValueError: If *documents* is empty or yields no chunks.
        """
        if not documents:
            raise ValueError("Cannot build index from an empty document list.")
        # A new index restarts chunk numbering from zero.
        chunks = self._chunk(documents, start_id=0)
        if not chunks:
            # FAISS cannot be built from zero vectors; fail with a clear message.
            raise ValueError("Documents produced no chunks; nothing to index.")
        self._vector_store = FAISS.from_documents(chunks, self._get_embeddings())
        self._next_chunk_id = len(chunks)
        logger.info(
            "Built FAISS index: %d document(s) -> %d chunk(s).",
            len(documents),
            len(chunks),
        )
        return len(chunks)

    def retrieve(self, query: str, top_k: Optional[int] = None) -> RetrievedContext:
        """Return the chunks most similar to *query*.

        Args:
            query: Text to search for.
            top_k: Number of chunks to return; ``config.top_k`` when ``None``.

        Returns:
            A :class:`RetrievedContext` with the chunks and their rendered text.

        Raises:
            RuntimeError: If no index has been built yet.
            ValueError: If *top_k* is less than 1.
        """
        if self._vector_store is None:
            raise RuntimeError("Call build_index() before retrieve().")
        # `is None` (not `or`) so an explicit 0 is rejected, not silently replaced.
        k = self.config.top_k if top_k is None else top_k
        if k < 1:
            raise ValueError(f"top_k must be a positive integer, got {k}.")
        docs = self._vector_store.similarity_search(query, k=k)
        return RetrievedContext(
            query=query,
            chunks=docs,
            combined_text=self._format_context(docs),
        )

    def add_documents(self, documents: List[Document]) -> int:
        """Chunk, embed and append *documents* to the index.

        If no index exists yet, this behaves like :meth:`build_index`.

        Args:
            documents: Additional source documents.

        Returns:
            The number of chunks added (0 when nothing was added to an
            existing index).

        Raises:
            ValueError: Propagated from :meth:`build_index` when no index exists
                and *documents* is empty or yields no chunks.
        """
        if self._vector_store is None:
            return self.build_index(documents)
        if not documents:
            return 0
        # Continue numbering so chunk_ids stay unique within the index.
        chunks = self._chunk(documents, start_id=self._next_chunk_id)
        if not chunks:
            return 0
        # The FAISS store already holds its embedding function.
        self._vector_store.add_documents(chunks)
        self._next_chunk_id += len(chunks)
        logger.info("Added %d chunk(s) to FAISS index.", len(chunks))
        return len(chunks)

    def reset(self) -> None:
        """Drop the current index; the loaded embedding model is kept."""
        self._vector_store = None
        self._next_chunk_id = 0

    def _chunk(self, documents: List[Document], start_id: int = 0) -> List[Document]:
        """Split *documents* and number the chunks from *start_id* upward."""
        chunks = self._splitter.split_documents(documents)
        for idx, chunk in enumerate(chunks, start=start_id):
            chunk.metadata["chunk_id"] = idx
        return chunks

    def _get_embeddings(self) -> HuggingFaceEmbeddings:
        """Return the embedding model, constructing it on first use."""
        if self._embeddings is None:
            self._embeddings = HuggingFaceEmbeddings(
                model_name=self.config.embedding_model,
                model_kwargs={"device": self.config.embedding_device},
                encode_kwargs={
                    "normalize_embeddings": self.config.normalize_embeddings
                },
            )
        return self._embeddings

    @staticmethod
    def _format_context(docs: List[Document]) -> str:
        """Render *docs* as headed text blocks joined by a horizontal rule.

        Each block starts with a ``[CHUNK i | source: ..., page: ...,
        chunk_id: ...]`` header (1-based ``i``); missing metadata shows as
        ``unknown`` / ``N/A``.
        """
        separator = "─" * 60
        parts = []
        for i, doc in enumerate(docs, start=1):
            meta = doc.metadata
            header = (
                f"[CHUNK {i} | source: {meta.get('source', 'unknown')}, "
                f"page: {meta.get('page', 'N/A')}, "
                f"chunk_id: {meta.get('chunk_id', 'N/A')}]"
            )
            parts.append(f"{header}\n{doc.page_content.strip()}")
        return f"\n{separator}\n".join(parts)