File size: 3,509 Bytes
74b04c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import re
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
import faiss
import pickle

# Source document and the pickled index cache both live in ./data next to this module.
KNOWLEDGE_BASE_PATH = Path(__file__).parent.joinpath("data", "knowledge_base.md")
INDEX_CACHE_PATH = Path(__file__).parent.joinpath("data", "faiss_index.pkl")

# Model name handed to SentenceTransformer when the pipeline is constructed.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Defaults for _chunk_text(size=..., overlap=...), both measured in characters.
CHUNK_SIZE = 300
CHUNK_OVERLAP = 50

# Default number of chunks RAGPipeline.retrieve() asks the index for.
TOP_K = 4


def _overlap_tail(previous: str, overlap: int) -> str:
    """Return up to *overlap* trailing characters of *previous*, starting on a word boundary."""
    if overlap <= 0 or not previous:
        return ""
    if len(previous) <= overlap:
        return previous

    tail = previous[-overlap:]
    # If the cut landed inside a word, drop the partial leading word so the
    # carried-over context starts cleanly. A tail that is one unbroken token
    # is kept as-is rather than discarded.
    cut_mid_word = not previous[-overlap - 1].isspace() and not tail[0].isspace()
    if cut_mid_word:
        pieces = tail.split(None, 1)
        if len(pieces) == 2:
            tail = pieces[1]
    return tail.strip()


def _chunk_text(text: str, size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split text into overlapping chunks, preserving paragraph boundaries where possible.

    Paragraphs (separated by blank lines) are packed greedily into chunks of at
    most ``size`` characters, joiner included. A paragraph longer than ``size``
    is packed sentence by sentence instead. A single sentence longer than
    ``size`` is not split further and becomes an oversized chunk on its own.

    Every emitted chunk except the first is then prefixed with the tail of the
    chunk that precedes it in the document (at most ``overlap`` characters,
    trimmed to a word boundary) followed by a single space, so context that
    straddles a boundary is retrievable from either side. An emitted chunk can
    therefore be up to ``size + overlap + 1`` characters long.

    Args:
        text: Raw document text.
        size: Maximum length, in characters, of a chunk before overlap is added.
        overlap: Number of trailing characters of the previous chunk to carry
            into the next one. Values ``<= 0`` disable overlap.

    Returns:
        Chunks in document order. Chunks whose own text (before the overlap
        prefix) is 20 characters or shorter are dropped.
    """
    min_chars = 20  # chunks at or below this length are treated as noise
    para_sep = "\n\n"
    sent_sep = " "

    def fits(buffer: str, piece: str, sep: str) -> bool:
        # Count the joiner too; otherwise a packed chunk can overshoot `size`.
        joiner = len(sep) if buffer else 0
        return len(buffer) + joiner + len(piece) <= size

    paragraphs = [p.strip() for p in re.split(r"\n\n+", text) if p.strip()]

    base_chunks: list[str] = []
    current = ""

    for para in paragraphs:
        if fits(current, para, para_sep):
            current = f"{current}{para_sep}{para}" if current else para
            continue

        if current:
            base_chunks.append(current.strip())

        if len(para) <= size:
            current = para
            continue

        # Paragraph alone exceeds the budget: pack it sentence by sentence.
        # Resetting `current` here guarantees the chunk flushed above can
        # never be emitted a second time.
        current = ""
        for sent in re.split(r"(?<=[.!?])\s+", para):
            if fits(current, sent, sent_sep):
                current = f"{current}{sent_sep}{sent}" if current else sent
            else:
                if current:
                    base_chunks.append(current.strip())
                current = sent
        # The last sentence buffer stays in `current` so following short
        # paragraphs may still be packed together with it.

    if current:
        base_chunks.append(current.strip())

    chunks: list[str] = []
    previous = ""
    for chunk in base_chunks:
        if len(chunk) > min_chars:
            tail = _overlap_tail(previous, overlap)
            chunks.append(f"{tail} {chunk}" if tail else chunk)
        # Overlap always comes from the textually adjacent chunk, even when
        # that chunk itself was too short to be emitted.
        previous = chunk

    return chunks


class RAGPipeline:
    """Retrieve knowledge-base passages relevant to a query via a FAISS index.

    On construction the embedding model named by ``EMBED_MODEL`` is loaded, then
    the index is either restored from ``INDEX_CACHE_PATH`` or built from
    ``KNOWLEDGE_BASE_PATH`` and cached. A cache that cannot be read, or that was
    built from a different knowledge base / model / chunk configuration, is
    discarded and rebuilt.

    Attributes set during construction:
        model: ``SentenceTransformer`` used to embed both chunks and queries.
        chunks: Chunk texts, positionally aligned with the rows of ``index``.
        index: ``faiss.IndexFlatIP`` over L2-normalised chunk embeddings.
    """

    def __init__(self):
        print("[RAG] Loading embedding model...")
        self.model = SentenceTransformer(EMBED_MODEL)

        if INDEX_CACHE_PATH.exists() and self._load_cached_index():
            return

        print("[RAG] Building FAISS index from knowledge base...")
        self._build_index()

    @staticmethod
    def _source_fingerprint() -> str:
        """Return a digest of everything the index depends on, or "" if unavailable.

        The digest covers the raw knowledge-base bytes plus the embedding model
        name and the chunking constants. An empty string means the knowledge
        base could not be read, in which case freshness cannot be judged.
        """
        import hashlib  # local import keeps the module's import block untouched

        try:
            raw = KNOWLEDGE_BASE_PATH.read_bytes()
        except OSError:
            return ""
        digest = hashlib.sha256(raw)
        digest.update(f"|{EMBED_MODEL}|{CHUNK_SIZE}|{CHUNK_OVERLAP}".encode("utf-8"))
        return digest.hexdigest()

    def _load_cached_index(self) -> bool:
        """Try to restore the cached index; return True if it is usable and current."""
        print("[RAG] Loading cached FAISS index...")
        try:
            self._load_index()
        except Exception as exc:
            # Unpickling a truncated or foreign file can raise almost anything
            # (EOFError, UnpicklingError, KeyError, faiss RuntimeError, ...).
            # The cache is only an optimisation, so any failure means rebuild.
            print(f"[RAG] Cached index unusable ({exc!r}); rebuilding...")
            return False

        current = self._source_fingerprint()
        # Caches written before fingerprints existed carry none and are rebuilt
        # once. If the knowledge base is unreadable the cache is all we have,
        # so it is used as-is.
        if current and self._cache_fingerprint != current:
            print("[RAG] Cached index is out of date; rebuilding...")
            return False
        return True

    def _build_index(self):
        """Chunk and embed the knowledge base, build the index, and cache it.

        Raises:
            OSError: If the knowledge base cannot be read.
            ValueError: If the knowledge base yields no chunks to index.
        """
        text = KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
        self.chunks = _chunk_text(text)
        if not self.chunks:
            # Without this guard, embeddings.shape[1] below fails with an
            # opaque IndexError.
            raise ValueError(f"No indexable chunks found in {KNOWLEDGE_BASE_PATH}")
        print(f"[RAG] Indexed {len(self.chunks)} chunks")

        embeddings = self.model.encode(self.chunks, show_progress_bar=False)
        embeddings = np.array(embeddings, dtype="float32")
        faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)   # inner-product = cosine after L2 norm
        self.index.add(embeddings)

        self._cache_fingerprint = self._source_fingerprint()
        payload = {
            "chunks": self.chunks,
            "index": faiss.serialize_index(self.index),
            "fingerprint": self._cache_fingerprint,
        }

        # Write to a sibling temp file and rename, so an interrupted write never
        # leaves a truncated cache behind. Failing to persist is not fatal:
        # the in-memory index is already usable.
        tmp_path = INDEX_CACHE_PATH.with_name(INDEX_CACHE_PATH.name + ".tmp")
        try:
            INDEX_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "wb") as f:
                pickle.dump(payload, f)
            os.replace(tmp_path, INDEX_CACHE_PATH)
        except OSError as exc:
            print(f"[RAG] Could not write index cache ({exc}); continuing without it")

    def _load_index(self):
        """Restore ``chunks`` and ``index`` from the pickle at ``INDEX_CACHE_PATH``.

        Raises:
            ValueError: If the stored chunk list and index disagree in size.
            Exception: Whatever ``pickle`` / ``faiss`` raise on unreadable data.
        """
        # SECURITY: pickle.load can execute arbitrary code. This is acceptable
        # only because the file is produced by _build_index(); never point
        # INDEX_CACHE_PATH at a file from an untrusted source.
        with open(INDEX_CACHE_PATH, "rb") as f:
            data = pickle.load(f)

        chunks = data["chunks"]
        index = faiss.deserialize_index(data["index"])
        if index.ntotal != len(chunks):
            # A mismatch would turn search hits into IndexErrors later on.
            raise ValueError(
                f"cache holds {len(chunks)} chunks but {index.ntotal} vectors"
            )

        self.chunks = chunks
        self.index = index
        self._cache_fingerprint = data.get("fingerprint", "")

    def retrieve(self, query: str, top_k: int = TOP_K) -> str:
        """Return the chunks most similar to *query*, joined into one string.

        Args:
            query: Free-text query to embed and search for.
            top_k: Maximum number of chunks to return. It is capped at the
                number of indexed chunks.

        Returns:
            The matching chunks, best match first, separated by
            ``"\\n\\n---\\n\\n"``. An empty string is returned when ``top_k``
            is not positive or the index is empty.
        """
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return ""

        query_emb = self.model.encode([query], show_progress_bar=False)
        query_emb = np.array(query_emb, dtype="float32")
        faiss.normalize_L2(query_emb)

        _, indices = self.index.search(query_emb, k)
        # FAISS pads missing neighbours with -1; skip those slots.
        results = [self.chunks[i] for i in indices[0] if i >= 0]

        return "\n\n---\n\n".join(results)