Spaces:
Paused
Paused
| import os | |
| import re | |
| import numpy as np | |
| from pathlib import Path | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import pickle | |
# All on-disk artifacts live in the "data" directory next to this module.
_DATA_DIR = Path(__file__).parent / "data"

# Markdown source that gets chunked and indexed.
KNOWLEDGE_BASE_PATH = _DATA_DIR / "knowledge_base.md"
# Pickled cache holding the chunk list and the serialized FAISS index.
INDEX_CACHE_PATH = _DATA_DIR / "faiss_index.pkl"

# Sentence-transformers model used for both chunk and query embeddings.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

CHUNK_SIZE = 300  # characters per chunk
CHUNK_OVERLAP = 50  # overlap between chunks
TOP_K = 4  # number of chunks to retrieve
def _fits(buffer: str, piece: str, separator: str, size: int) -> bool:
    """Return True if *piece* can be joined onto *buffer* without exceeding *size*."""
    # The separator is only inserted when the buffer already holds text.
    joined_len = len(buffer) + len(piece) + (len(separator) if buffer else 0)
    return joined_len <= size


def _overlap_tail(chunk: str, overlap: int) -> str:
    """Return up to *overlap* trailing characters of *chunk*, starting on a word boundary."""
    if overlap <= 0:
        return ""
    if len(chunk) <= overlap:
        return chunk
    tail = chunk[-overlap:]
    if not (chunk[-overlap - 1].isspace() or tail[0].isspace()):
        # The cut landed inside a word: drop the partial leading word so the
        # overlap never starts with a word fragment.
        parts = tail.split(None, 1)
        tail = parts[1] if len(parts) == 2 else ""
    return tail.strip()


def _chunk_text(text: str, size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split text into overlapping chunks, preserving paragraph boundaries where possible.

    Paragraphs (separated by blank lines) are packed greedily into pieces of at
    most ``size`` characters, separators included. A paragraph longer than
    ``size`` is split on sentence boundaries instead; a single sentence longer
    than ``size`` is kept whole rather than cut mid-sentence.

    Each returned chunk after the first piece is prefixed with up to ``overlap``
    trailing characters of the piece that precedes it in the text, so context
    spanning a chunk boundary is present in both chunks. The overlap is added
    on top of ``size``; pass ``overlap=0`` to disable it.

    Args:
        text: Raw document text.
        size: Maximum number of characters of new content per chunk.
        overlap: Maximum number of characters repeated from the previous piece.

    Returns:
        Chunks in document order. Pieces of 20 characters or fewer are dropped.
    """
    stripped = (raw.strip() for raw in re.split(r"\n\n+", text))
    paragraphs = [para for para in stripped if para]

    pieces: list[str] = []
    current = ""
    for para in paragraphs:
        if _fits(current, para, "\n\n", size):
            current = f"{current}\n\n{para}" if current else para
            continue

        # The paragraph does not fit: close the piece collected so far.
        if current:
            pieces.append(current)

        if len(para) <= size:
            current = para
            continue

        # Single paragraph longer than the chunk size: split by sentences.
        # Whatever remains in `current` afterwards may still absorb the
        # following paragraph.
        current = ""
        for sent in re.split(r"(?<=[.!?])\s+", para):
            if _fits(current, sent, " ", size):
                current = f"{current} {sent}" if current else sent
            else:
                if current:
                    pieces.append(current)
                current = sent
    if current:
        pieces.append(current)

    chunks: list[str] = []
    for i, piece in enumerate(pieces):
        # Very short pieces are dropped, as before; the length test uses the
        # piece's own text so the overlap prefix cannot rescue them.
        if len(piece) <= 20:
            continue
        tail = _overlap_tail(pieces[i - 1], overlap) if i else ""
        chunks.append(f"{tail} {piece}" if tail else piece)
    return chunks
class RAGPipeline:
    """Retrieve knowledge-base passages relevant to a query via FAISS search.

    Construction loads the embedding model, then either restores the chunk
    list and FAISS index from the on-disk cache or builds them from the
    knowledge-base file and writes the cache.
    """

    def __init__(self):
        print("[RAG] Loading embedding model...")
        self.model = SentenceTransformer(EMBED_MODEL)
        if self._cache_is_fresh():
            print("[RAG] Loading cached FAISS index...")
            try:
                self._load_index()
                return
            except (pickle.UnpicklingError, EOFError, KeyError, TypeError) as err:
                # A truncated or malformed cache must not block startup.
                print(f"[RAG] Cached index unreadable ({err!r}); rebuilding...")
        print("[RAG] Building FAISS index from knowledge base...")
        self._build_index()

    @staticmethod
    def _cache_is_fresh() -> bool:
        """Return True if a cache file exists and is not older than the knowledge base."""
        try:
            cache_mtime = INDEX_CACHE_PATH.stat().st_mtime
        except OSError:
            return False
        try:
            source_mtime = KNOWLEDGE_BASE_PATH.stat().st_mtime
        except OSError:
            # Nothing to rebuild from, so the cache is the only usable source.
            return True
        return cache_mtime >= source_mtime

    def _build_index(self):
        """Chunk and embed the knowledge base, build the index, and cache both.

        Raises:
            ValueError: If the knowledge base yields no chunks to index.
        """
        text = KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
        self.chunks = _chunk_text(text)
        if not self.chunks:
            # Fail with a clear message instead of an opaque error on shape[1].
            raise ValueError(f"No indexable text found in {KNOWLEDGE_BASE_PATH}")
        print(f"[RAG] Indexed {len(self.chunks)} chunks")

        embeddings = self.model.encode(self.chunks, show_progress_bar=False)
        embeddings = np.array(embeddings, dtype="float32")
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)  # inner-product = cosine after L2 norm
        self.index.add(embeddings)

        # Cache index + chunks. Write to a temporary file and rename it so an
        # interrupted write never leaves a truncated cache behind.
        INDEX_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        payload = {"chunks": self.chunks, "index": faiss.serialize_index(self.index)}
        tmp_path = INDEX_CACHE_PATH.with_name(INDEX_CACHE_PATH.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        tmp_path.replace(INDEX_CACHE_PATH)

    def _load_index(self):
        """Restore the chunk list and FAISS index from the cache file."""
        # NOTE: pickle.load can execute arbitrary code from the file. This is
        # only acceptable because the cache is written by _build_index; never
        # point INDEX_CACHE_PATH at data from an untrusted source.
        with open(INDEX_CACHE_PATH, "rb") as f:
            data = pickle.load(f)
        self.chunks = data["chunks"]
        self.index = faiss.deserialize_index(data["index"])

    def retrieve(self, query: str, top_k: int = TOP_K) -> str:
        """Return the chunks most similar to *query*, joined by a separator.

        Args:
            query: Free-text query to embed and search for.
            top_k: Maximum number of chunks to return.

        Returns:
            The matching chunks in the order FAISS returns them, joined by
            ``"\\n\\n---\\n\\n"``. Empty string if ``top_k`` is not positive
            or the index is empty.
        """
        # Never ask for more results than the index holds.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return ""
        query_emb = self.model.encode([query], show_progress_bar=False)
        query_emb = np.array(query_emb, dtype="float32")
        faiss.normalize_L2(query_emb)
        _, indices = self.index.search(query_emb, k)
        # FAISS pads missing results with -1; skip those slots.
        results = [self.chunks[i] for i in indices[0] if i >= 0]
        return "\n\n---\n\n".join(results)