# Custom-LLM-Chat / rag/embedder.py
# Author: Bhaskar Ram
# Last commit 9edd318 — fix: model singleton cache, dedup guard, Gradio type=messages
"""
embedder.py
Chunks raw text documents and builds an in-memory FAISS vector index.
"""
from __future__ import annotations
import re as _re
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
# ── Chunking / embedding configuration ───────────────────────────────────────
CHUNK_SIZE = 512  # characters — max chars per chunk
CHUNK_OVERLAP = 64  # characters — approx overlap between consecutive chunks
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"  # Hugging Face id of the small BGE retrieval model
# Sentence/paragraph boundary regex: split after sentence-ending punctuation
# when followed by whitespace + a capital letter, or at blank-line paragraph
# breaks. Used by _chunk_text to avoid cutting chunks mid-sentence.
_SENTENCE_SPLIT = _re.compile(r'(?<=[.!?])\s+(?=[A-Z])|(?<=\n)\s*\n+')
# ── Model singleton ───────────────────────────────────────────────────────────
# Loading a SentenceTransformer from disk takes several seconds, so the model
# is instantiated at most once per process and shared by every caller of
# build_index / add_to_index.
_MODEL: Optional[object] = None


def _get_model():
    """Return the process-wide SentenceTransformer, loading it on first use."""
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    # Imported lazily so merely importing this module stays cheap.
    from sentence_transformers import SentenceTransformer
    _MODEL = SentenceTransformer(EMBEDDING_MODEL)
    return _MODEL
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class VectorIndex:
    """In-memory vector store pairing text chunks with a FAISS index.

    Created by build_index() and extended in place by add_to_index();
    row i of `index` holds the embedding of `chunks[i]`.
    """
    # Each chunk is a {"source": str, "text": str} dict, in insertion order.
    chunks: list[dict] = field(default_factory=list)
    # faiss.IndexFlatIP over L2-normalised embeddings (inner product == cosine).
    index: object = None
    # The SentenceTransformer used to embed the chunks (process-wide singleton).
    embedder: object = None
def _chunk_text(source: str, text: str) -> list[dict]:
    """
    Split *text* into overlapping chunks that respect sentence boundaries.

    Strategy:
      1. Split the document into sentences / paragraphs via _SENTENCE_SPLIT.
      2. Greedily accumulate sentences until CHUNK_SIZE characters is reached.
      3. Start the next chunk ~CHUNK_OVERLAP characters back so consecutive
         chunks share context at their boundaries.

    A single sentence longer than CHUNK_SIZE is hard-split at fixed offsets
    (with CHUNK_OVERLAP shared between pieces), so no chunk exceeds
    CHUNK_SIZE characters.

    Returns a list of {"source": source, "text": chunk} dicts.
    """
    # Normalise runs of spaces/tabs while preserving paragraph breaks.
    text = _re.sub(r'[ \t]+', ' ', text).strip()
    sentences = [s.strip() for s in _SENTENCE_SPLIT.split(text) if s.strip()]
    chunks: list[dict] = []
    i = 0
    while i < len(sentences):
        # Accumulate sentences until we hit the size limit.
        parts: list[str] = []
        total = 0
        j = i
        while j < len(sentences):
            slen = len(sentences[j])
            if not parts and slen > CHUNK_SIZE:
                # BUGFIX: the first sentence alone exceeds the limit — leave
                # parts empty so the hard-split branch below handles it.
                # Previously it was appended unconditionally, which made the
                # hard-split branch unreachable and emitted oversized chunks.
                break
            if total + slen > CHUNK_SIZE and parts:
                break
            parts.append(sentences[j])
            total += slen + 1  # +1 for the space we'll join with
            j += 1
        chunk_text = " ".join(parts)
        if chunk_text.strip():
            chunks.append({"source": source, "text": chunk_text})
        if j == i:
            # Single sentence longer than CHUNK_SIZE — hard-split it into
            # CHUNK_SIZE pieces that overlap by CHUNK_OVERLAP characters.
            sent = sentences[i]
            step = CHUNK_SIZE - CHUNK_OVERLAP
            for k in range(0, len(sent), step):
                part = sent[k: k + CHUNK_SIZE]
                if part.strip():
                    chunks.append({"source": source, "text": part})
            i += 1
            continue
        if j >= len(sentences):
            # BUGFIX (dedup guard): the chunk just emitted ends at the final
            # sentence; backtracking for overlap would only emit a trailing
            # chunk fully contained in it.
            break
        # Slide forward, but overlap by backtracking ~CHUNK_OVERLAP chars.
        overlap_chars = 0
        next_i = j
        for k in range(j - 1, i, -1):
            overlap_chars += len(sentences[k]) + 1
            if overlap_chars >= CHUNK_OVERLAP:
                next_i = k
                break
        i = max(i + 1, next_i)  # always advance at least one sentence
    return chunks
def build_index(docs: list[dict]) -> VectorIndex:
    """
    Build a fresh VectorIndex from a list of {"source", "text"} documents.

    Chunks every document, embeds all chunks with the shared model singleton,
    L2-normalises the embeddings, and stores them in a FAISS inner-product
    index (inner product on unit vectors == cosine similarity).

    Raises ValueError when no text chunks can be extracted.
    """
    import faiss

    model = _get_model()  # cached singleton — no per-call reload cost

    # Chunk every document into sentence-respecting pieces.
    chunk_list: list[dict] = []
    for document in docs:
        chunk_list.extend(_chunk_text(document["source"], document["text"]))
    if not chunk_list:
        raise ValueError("No text chunks could be extracted from the uploaded files.")

    print(f"[Embedder] Embedding {len(chunk_list)} chunks...")
    vectors = np.array(
        model.encode(
            [c["text"] for c in chunk_list],
            show_progress_bar=False,
            batch_size=32,
        ),
        dtype="float32",
    )
    dim = vectors.shape[1]

    # Normalise in place, then index with inner product for cosine ranking.
    faiss.normalize_L2(vectors)
    ip_index = faiss.IndexFlatIP(dim)
    ip_index.add(vectors)
    print(f"[Embedder] Index built: {ip_index.ntotal} vectors, dim={dim}")
    return VectorIndex(chunks=chunk_list, index=ip_index, embedder=model)
def add_to_index(vector_index: VectorIndex, docs: list[dict]) -> VectorIndex:
    """
    Incrementally embed new {"source", "text"} docs into an existing index.

    Mutates *vector_index* in place (FAISS index and chunk list) and returns
    the same object. Embeddings are L2-normalised to stay consistent with the
    cosine / inner-product index built by build_index().
    """
    import faiss

    new_chunks: list[dict] = []
    for doc in docs:
        new_chunks.extend(_chunk_text(doc["source"], doc["text"]))
    if not new_chunks:
        # Robustness fix: with no extractable text, encode([]) yields an
        # empty/malformed array and faiss.normalize_L2 / index.add would
        # crash. An incremental add of nothing is a no-op, not an error.
        return vector_index

    texts = [c["text"] for c in new_chunks]
    embeddings = np.array(
        vector_index.embedder.encode(texts, show_progress_bar=False, batch_size=32),
        dtype="float32",
    )
    faiss.normalize_L2(embeddings)  # keep consistent with the cosine/IP index
    vector_index.index.add(embeddings)
    vector_index.chunks.extend(new_chunks)
    return vector_index