"""Chunk a markdown knowledge base, embed the chunks with a sentence-transformer,
and serve top-k similarity lookups from a cached FAISS index."""

import hashlib
import os
import pickle
import re
import textwrap
from pathlib import Path
from typing import Optional

import numpy as np
from sentence_transformers import SentenceTransformer
import faiss

KNOWLEDGE_BASE_PATH = Path(__file__).parent / "data" / "knowledge_base.md"
INDEX_CACHE_PATH = Path(__file__).parent / "data" / "faiss_index.pkl"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 300  # characters per chunk (before the overlap prefix is added)
CHUNK_OVERLAP = 50  # characters of the previous chunk repeated at the start of the next
TOP_K = 4  # number of chunks to retrieve

# Bump whenever the chunking logic or the cache layout changes so old caches are rebuilt.
_CACHE_FORMAT_VERSION = 2

_PARAGRAPH_RE = re.compile(r"\n\n+")
_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
_MIN_CHUNK_CHARS = 20  # chunks of at most this many characters are discarded


def _split_oversized(para: str, size: int) -> list[str]:
    """Pack the sentences of *para* into pieces of at most *size* characters.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace and
    re-joined with single spaces. A sentence that alone exceeds *size* is
    hard-wrapped with ``textwrap.wrap`` (which also turns any newlines inside
    it into spaces). Always returns at least one piece for a non-empty *para*.
    """
    pieces: list[str] = []
    buf = ""
    for sent in _SENTENCE_RE.split(para):
        if len(sent) > size:
            # No usable sentence boundary: wrap on whitespace (breaking inside
            # words if necessary) so no piece exceeds the size budget.
            if buf:
                pieces.append(buf)
            *head, buf = textwrap.wrap(sent, width=size, break_on_hyphens=False)
            pieces.extend(head)
        elif buf and len(buf) + 1 + len(sent) <= size:  # +1 for the joining space
            buf = f"{buf} {sent}"
        else:
            if buf:
                pieces.append(buf)
            buf = sent
    if buf:
        pieces.append(buf)
    return pieces


def _overlap_tail(chunk: str, overlap: int) -> str:
    """Return about the last *overlap* characters of *chunk*, starting on a word boundary.

    Expects ``overlap > 0``. If *chunk* is no longer than *overlap*, the whole
    chunk is returned.
    """
    if len(chunk) <= overlap:
        return chunk
    start = len(chunk) - overlap
    if not chunk[start - 1].isspace():
        # The cut landed inside a word; skip to the next whitespace so the
        # overlap does not begin with a word fragment (unless there is none).
        match = re.search(r"\s", chunk[start:])
        if match:
            start += match.end()
    return chunk[start:].strip()


def _chunk_text(text: str, size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split text into overlapping chunks, preserving paragraph boundaries where possible.

    Paragraphs (separated by one or more blank lines) are packed greedily into
    chunks of at most *size* characters, joined by ``"\\n\\n"``. A paragraph
    longer than *size* is split at sentence ends, and a sentence longer than
    *size* is hard-wrapped. Chunks of 20 characters or fewer are dropped.

    When *overlap* > 0, every chunk after the first is prefixed with roughly
    the last *overlap* characters of the chunk before it (snapped to a word
    boundary), so a final chunk can be up to ``size + overlap + 1`` characters.
    ``overlap <= 0`` disables the prefix.

    Raises:
        ValueError: if *size* is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")

    chunks: list[str] = []
    current = ""
    for para in (p.strip() for p in _PARAGRAPH_RE.split(text)):
        if not para:
            continue
        separator = 2 if current else 0  # the "\n\n" joiner counts toward the budget
        if len(current) + separator + len(para) <= size:
            current = f"{current}\n\n{para}" if current else para
            continue
        if current:
            chunks.append(current)
        if len(para) > size:
            # Emit every complete piece; keep the last one open so a following
            # short paragraph can still be packed onto it.
            *complete, current = _split_oversized(para, size)
            chunks.extend(complete)
        else:
            current = para
    if current:
        chunks.append(current)

    chunks = [c for c in chunks if len(c) > _MIN_CHUNK_CHARS]
    if overlap <= 0 or not chunks:
        return chunks

    overlapped = [chunks[0]]
    for prev, cur in zip(chunks, chunks[1:]):
        tail = _overlap_tail(prev, overlap)
        overlapped.append(f"{tail} {cur}" if tail else cur)
    return overlapped


def _index_fingerprint(text: str) -> str:
    """Hash everything the cached index depends on, so any change forces a rebuild.

    Covers the cache format version, embedding model name, chunking settings
    and the knowledge-base text itself.
    """
    digest = hashlib.sha256()
    for part in (str(_CACHE_FORMAT_VERSION), EMBED_MODEL, str(CHUNK_SIZE), str(CHUNK_OVERLAP), text):
        digest.update(part.encode("utf-8"))
        digest.update(b"\0")  # field separator so adjacent parts cannot run together
    return digest.hexdigest()


class RAGPipeline:
    """Embed the knowledge base into a cosine-similarity FAISS index and retrieve chunks.

    The chunk texts and serialized index are cached at ``INDEX_CACHE_PATH``.
    The cache is reused only while its stored fingerprint matches the current
    knowledge base, embedding model and chunking settings. If the knowledge
    base file is missing, any readable cache is used as-is.
    """

    def __init__(self) -> None:
        print("[RAG] Loading embedding model...")
        self.model = SentenceTransformer(EMBED_MODEL)

        # Tolerate a missing knowledge base so a deployment with only the cache still works.
        text = (
            KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
            if KNOWLEDGE_BASE_PATH.exists()
            else None
        )
        fingerprint = None if text is None else _index_fingerprint(text)

        if INDEX_CACHE_PATH.exists():
            print("[RAG] Loading cached FAISS index...")
            if self._try_load_cache(fingerprint):
                return
        print("[RAG] Building FAISS index from knowledge base...")
        self._build_index(text, fingerprint)

    def _build_index(self, text: Optional[str] = None, fingerprint: Optional[str] = None) -> None:
        """Chunk and embed the knowledge base, build the index, and write the cache.

        *text* and *fingerprint* are read/computed when not supplied. Raises
        ``FileNotFoundError`` if the knowledge base has to be read and is missing.
        """
        if text is None:
            text = KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
        if fingerprint is None:
            fingerprint = _index_fingerprint(text)

        self.chunks = _chunk_text(text)
        print(f"[RAG] Indexed {len(self.chunks)} chunks")

        if self.chunks:
            embeddings = self.model.encode(self.chunks, show_progress_bar=False)
            embeddings = np.array(embeddings, dtype="float32")
            faiss.normalize_L2(embeddings)
            self.index = faiss.IndexFlatIP(embeddings.shape[1])  # inner-product = cosine after L2 norm
            self.index.add(embeddings)
        else:
            # With no chunks there are no embedding rows to read the width from,
            # so ask the model for it and build an empty index.
            self.index = faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension())

        self._save_cache(fingerprint)

    def _save_cache(self, fingerprint: str) -> None:
        """Write chunks, serialized index and fingerprint to ``INDEX_CACHE_PATH``.

        Writes to a temporary file first and then swaps it in with ``os.replace``,
        so a crash mid-write cannot leave a truncated cache behind.
        """
        INDEX_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = INDEX_CACHE_PATH.with_name(INDEX_CACHE_PATH.name + ".tmp")
        payload = {
            "chunks": self.chunks,
            "index": faiss.serialize_index(self.index),
            "fingerprint": fingerprint,
        }
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, INDEX_CACHE_PATH)

    def _load_index(self) -> Optional[str]:
        """Load chunks and index from the cache file; return its stored fingerprint.

        Caches written before fingerprints were recorded return ``None``.
        Raises whatever opening, unpickling or deserializing raises.
        """
        # SECURITY: pickle.load can execute arbitrary code from a tampered file.
        # This cache must only ever be written by this module, in a trusted directory.
        with open(INDEX_CACHE_PATH, "rb") as f:
            data = pickle.load(f)
        self.chunks = data["chunks"]
        self.index = faiss.deserialize_index(data["index"])
        return data.get("fingerprint")

    def _try_load_cache(self, expected_fingerprint: Optional[str]) -> bool:
        """Load the cache if it is readable, self-consistent and current.

        Returns ``True`` on success. Returns ``False`` when the cache must be
        rebuilt. Passing ``None`` as *expected_fingerprint* skips the staleness
        check.
        """
        try:
            cached_fingerprint = self._load_index()
        except Exception as exc:  # the cache is disposable: any read failure just means rebuild
            print(f"[RAG] Could not read cached FAISS index ({exc!r}); rebuilding...")
            return False
        if self.index.ntotal != len(self.chunks):
            print("[RAG] Cached FAISS index does not match its chunks; rebuilding...")
            return False
        if expected_fingerprint is not None and cached_fingerprint != expected_fingerprint:
            print("[RAG] Cached FAISS index is stale; rebuilding...")
            return False
        return True

    def retrieve(self, query: str, top_k: int = TOP_K) -> str:
        """Return the *top_k* chunks most similar to *query*, joined by ``---`` separators.

        Returns an empty string when the index is empty or ``top_k <= 0``.
        """
        # Never ask FAISS for more neighbours than exist (or for a non-positive k).
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return ""
        query_emb = self.model.encode([query], show_progress_bar=False)
        query_emb = np.array(query_emb, dtype="float32")
        faiss.normalize_L2(query_emb)
        _, indices = self.index.search(query_emb, k)
        # FAISS pads missing neighbours with -1.
        return "\n\n---\n\n".join(self.chunks[i] for i in indices[0] if i >= 0)