MyChatbot / app /rag.py
Parsa2025AI's picture
rag system
8d3b124 verified
import hashlib
import os
import pickle
import re
from pathlib import Path

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
KNOWLEDGE_BASE_PATH = Path(__file__).parent / "data" / "knowledge_base.md"
INDEX_CACHE_PATH = Path(__file__).parent / "data" / "faiss_index.pkl"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 300 # characters per chunk
CHUNK_OVERLAP = 50 # overlap between chunks
TOP_K = 4 # number of chunks to retrieve
def _chunk_text(text: str, size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
"""Split text into overlapping chunks, preserving paragraph boundaries where possible."""
paragraphs = [p.strip() for p in re.split(r"\n\n+", text) if p.strip()]
chunks = []
current = ""
for para in paragraphs:
if len(current) + len(para) <= size:
current = current + "\n\n" + para if current else para
else:
if current:
chunks.append(current.strip())
# If single paragraph is longer than chunk size, split by sentences
if len(para) > size:
sentences = re.split(r"(?<=[.!?])\s+", para)
buf = ""
for sent in sentences:
if len(buf) + len(sent) <= size:
buf = buf + " " + sent if buf else sent
else:
if buf:
chunks.append(buf.strip())
buf = sent
if buf:
current = buf
else:
current = para
if current:
chunks.append(current.strip())
return [c for c in chunks if len(c) > 20]
class RAGPipeline:
    """Embed the knowledge base into a FAISS index and retrieve the chunks most similar to a query.

    The index is cached at INDEX_CACHE_PATH. The cache also stores a fingerprint of the
    knowledge-base text, EMBED_MODEL, CHUNK_SIZE, CHUNK_OVERLAP and _CACHE_VERSION. When
    any of these differ from the current values, the cache is ignored and the index is
    rebuilt from KNOWLEDGE_BASE_PATH. The same happens when the cache cannot be read.
    """

    # Bump whenever chunking/indexing logic changes so caches written by older code are rebuilt.
    _CACHE_VERSION = 2

    def __init__(self):
        print("[RAG] Loading embedding model...")
        self.model = SentenceTransformer(EMBED_MODEL)
        if INDEX_CACHE_PATH.exists():
            print("[RAG] Loading cached FAISS index...")
            if self._load_index():
                return
        print("[RAG] Building FAISS index from knowledge base...")
        self._build_index()

    @classmethod
    def _fingerprint(cls, text: str) -> str:
        """Return a SHA-256 hex digest of everything that determines the index contents."""
        digest = hashlib.sha256()
        for part in (str(cls._CACHE_VERSION), EMBED_MODEL, str(CHUNK_SIZE), str(CHUNK_OVERLAP), text):
            digest.update(part.encode("utf-8"))
            digest.update(b"\0")  # delimiter so adjacent parts cannot run together ambiguously
        return digest.hexdigest()

    def _build_index(self):
        """Chunk and embed the knowledge base into self.chunks / self.index, then cache them."""
        text = KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
        self.chunks = _chunk_text(text)
        print(f"[RAG] Indexed {len(self.chunks)} chunks")
        if self.chunks:
            embeddings = self.model.encode(self.chunks, show_progress_bar=False)
            embeddings = np.array(embeddings, dtype="float32")
            faiss.normalize_L2(embeddings)
            self.index = faiss.IndexFlatIP(embeddings.shape[1])  # inner-product = cosine after L2 norm
            self.index.add(embeddings)
        else:
            # No chunks means no embeddings to read the width from; ask the model instead.
            self.index = faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension())
        self._save_index(self._fingerprint(text))

    def _save_index(self, fingerprint: str):
        """Write chunks, the serialized index and `fingerprint` to INDEX_CACHE_PATH."""
        INDEX_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "chunks": self.chunks,
            "index": faiss.serialize_index(self.index),
            "fingerprint": fingerprint,
        }
        # Write to a sibling temp file, then rename over the cache. A crash mid-write
        # therefore never leaves a truncated cache for the next start-up.
        tmp_path = INDEX_CACHE_PATH.with_name(INDEX_CACHE_PATH.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        tmp_path.replace(INDEX_CACHE_PATH)

    def _load_index(self) -> bool:
        """Load the cached index into self.chunks / self.index.

        Returns:
            True on success. False if the cache is unreadable, inconsistent or stale;
            the caller should then rebuild.
        """
        try:
            with open(INDEX_CACHE_PATH, "rb") as f:
                # SECURITY: pickle.load can execute arbitrary code. This is acceptable only
                # because _save_index writes this file; never point it at untrusted data.
                data = pickle.load(f)
            chunks = data["chunks"]
            index = faiss.deserialize_index(data["index"])
            cached_fingerprint = data.get("fingerprint")
        except Exception as exc:  # any failure here just means "no usable cache"
            print(f"[RAG] Cached index unreadable ({exc!r}); ignoring it.")
            return False
        if index.ntotal != len(chunks):
            print("[RAG] Cached index does not match its chunks; ignoring it.")
            return False
        try:
            text = KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            text = None  # source unavailable: freshness cannot be checked, so trust the cache
        if text is not None and cached_fingerprint != self._fingerprint(text):
            print("[RAG] Knowledge base or settings changed; ignoring stale cache.")
            return False
        self.chunks = chunks
        self.index = index
        return True

    def retrieve(self, query: str, top_k: int = TOP_K) -> str:
        """Return up to `top_k` chunks most similar to `query`, best first.

        The chunks are joined by a blank line, a "---" line and another blank line.
        Returns an empty string when the index is empty or `top_k` is not positive.
        """
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return ""
        query_emb = self.model.encode([query], show_progress_bar=False)
        query_emb = np.array(query_emb, dtype="float32")
        faiss.normalize_L2(query_emb)
        _scores, indices = self.index.search(query_emb, k)
        # FAISS pads missing results with -1; skip them.
        return "\n\n---\n\n".join(self.chunks[i] for i in indices[0] if i >= 0)