| | from __future__ import annotations |
| |
|
| | import logging |
| | from dataclasses import dataclass |
| | from typing import List, Optional |
| |
|
| | from langchain_core.documents import Document |
| | from langchain_huggingface import HuggingFaceEmbeddings |
| | from langchain_text_splitters import RecursiveCharacterTextSplitter |
| | from langchain_community.vectorstores import FAISS |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class RAGConfig:
    """Tunable settings for the RAG engine: chunking, embedding, retrieval."""

    # Target size (in characters) of each text chunk produced by the splitter.
    chunk_size: int = 1000
    # Characters of overlap between consecutive chunks, to preserve context.
    chunk_overlap: int = 100
    # Hugging Face model id used to embed both chunks and queries.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Device string handed to the embedding model ("cpu", "cuda", ...).
    embedding_device: str = "cpu"
    # Default number of chunks returned by retrieve() when no top_k is given.
    top_k: int = 4
    # Forwarded as the encoder's normalize_embeddings option.
    normalize_embeddings: bool = True
| |
|
| |
|
@dataclass
class RetrievedContext:
    """Result of a retrieval call: the query, its matching chunks, and the
    chunks pre-joined into a single prompt-ready string."""

    query: str
    chunks: List[Document]
    combined_text: str

    @property
    def source_files(self) -> List[str]:
        """Unique chunk sources, in first-seen order ("unknown" if absent)."""
        sources = (doc.metadata.get("source", "unknown") for doc in self.chunks)
        # dict.fromkeys dedupes while preserving insertion order.
        return list(dict.fromkeys(sources))

    @property
    def chunk_count(self) -> int:
        """Number of retrieved chunks."""
        return len(self.chunks)
| |
|
| |
|
class RAGEngine:
    """Minimal retrieval pipeline: split documents into chunks, embed them
    into a FAISS index, and fetch the chunks most similar to a query.
    """

    def __init__(self, config: Optional[RAGConfig] = None) -> None:
        """Create an engine; the embedding model is not loaded until needed."""
        self.config = config or RAGConfig()
        self._vector_store: Optional[FAISS] = None
        # Created lazily in _get_embeddings(); loading the model is expensive.
        self._embeddings: Optional[HuggingFaceEmbeddings] = None
        # Running counter so chunk_ids stay unique across add_documents() calls
        # (a per-call enumerate would restart at 0 and collide).
        self._next_chunk_id = 0
        self._splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""],
            add_start_index=True,
        )

    def build_index(self, documents: List[Document]) -> int:
        """Chunk *documents* and build a fresh FAISS index from them.

        Returns the number of chunks indexed.
        Raises ValueError if *documents* is empty.
        """
        if not documents:
            raise ValueError("Cannot build index from an empty document list.")
        self._next_chunk_id = 0  # fresh index -> restart chunk numbering
        chunks = self._chunk(documents)
        self._vector_store = FAISS.from_documents(chunks, self._get_embeddings())
        return len(chunks)

    def retrieve(self, query: str, top_k: Optional[int] = None) -> RetrievedContext:
        """Return the *top_k* most similar chunks for *query*.

        Falls back to config.top_k when *top_k* is None.
        Raises RuntimeError if no index has been built yet.
        """
        if self._vector_store is None:
            raise RuntimeError("Call build_index() before retrieve().")
        # `top_k or ...` would silently discard an explicit top_k=0;
        # only a missing (None) argument should fall back to the config.
        k = top_k if top_k is not None else self.config.top_k
        docs = self._vector_store.similarity_search(query, k=k)
        return RetrievedContext(query=query, chunks=docs, combined_text=self._format_context(docs))

    def add_documents(self, documents: List[Document]) -> int:
        """Chunk *documents* and append them to the existing index, building
        one first if none exists. Returns the number of chunks added.
        """
        if self._vector_store is None:
            return self.build_index(documents)
        chunks = self._chunk(documents)
        # The store already holds its embedding function; the former
        # `embedding=` kwarg was a misleading no-op and has been removed.
        self._vector_store.add_documents(chunks)
        return len(chunks)

    def reset(self) -> None:
        """Drop the index; the cached embedding model is kept for reuse."""
        self._vector_store = None
        self._next_chunk_id = 0

    def _chunk(self, documents: List[Document]) -> List[Document]:
        """Split *documents* and tag every chunk with a unique chunk_id."""
        chunks = self._splitter.split_documents(documents)
        for chunk in chunks:
            chunk.metadata["chunk_id"] = self._next_chunk_id
            self._next_chunk_id += 1
        return chunks

    def _get_embeddings(self) -> HuggingFaceEmbeddings:
        """Build the embedding model on first use and cache it thereafter."""
        if self._embeddings is None:
            self._embeddings = HuggingFaceEmbeddings(
                model_name=self.config.embedding_model,
                model_kwargs={"device": self.config.embedding_device},
                encode_kwargs={"normalize_embeddings": self.config.normalize_embeddings},
            )
        return self._embeddings

    @staticmethod
    def _format_context(docs: List[Document]) -> str:
        """Join chunks into one prompt-ready string, each preceded by a
        header naming its source, page, and chunk_id."""
        separator = "─" * 60
        parts = []
        for i, doc in enumerate(docs, start=1):
            meta = doc.metadata
            header = (
                f"[CHUNK {i} | source: {meta.get('source', 'unknown')}, "
                f"page: {meta.get('page', 'N/A')}, "
                f"chunk_id: {meta.get('chunk_id', 'N/A')}]"
            )
            parts.append(f"{header}\n{doc.page_content.strip()}")
        return f"\n{separator}\n".join(parts)