File size: 5,749 Bytes
55953aa
 
 
 
 
 
9edd318
55953aa
 
9edd318
55953aa
2623b17
 
 
 
 
 
 
55953aa
9edd318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55953aa
 
 
 
 
2623b17
55953aa
 
 
 
2623b17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55953aa
 
2623b17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55953aa
 
 
 
 
 
 
 
 
9edd318
 
55953aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a465955
 
 
55953aa
 
 
 
 
 
 
 
 
9edd318
55953aa
 
 
 
 
 
 
 
a465955
55953aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
embedder.py
Chunks raw text documents and builds an in-memory FAISS vector index.
"""

from __future__ import annotations
import re as _re
import numpy as np
from dataclasses import dataclass, field
from typing import Optional

CHUNK_SIZE = 512        # characters — max chars per chunk
CHUNK_OVERLAP = 64      # characters — approx overlap between consecutive chunks
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"  # State-of-the-art small retrieval model

# Regex: split on sentence-ending punctuation followed by whitespace + a capital
# letter, OR on paragraph breaks (a newline followed by one or more blank lines).
# Both alternatives use lookaround so the delimiter text itself is not consumed
# from the sentence that precedes it.
_SENTENCE_SPLIT = _re.compile(r'(?<=[.!?])\s+(?=[A-Z])|(?<=\n)\s*\n+')

# ── Model singleton ───────────────────────────────────────────────────────────
# SentenceTransformer takes 5–15s to load from disk. We load it exactly once
# per process (see _get_model) and reuse it across all build_index /
# add_to_index calls.
_MODEL: Optional[object] = None


def _get_model():
    """Lazily load and cache the SentenceTransformer embedding model.

    The heavy third-party import and the model load happen only on the
    first call; every later call returns the already-initialised instance.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    from sentence_transformers import SentenceTransformer
    _MODEL = SentenceTransformer(EMBEDDING_MODEL)
    return _MODEL
# ─────────────────────────────────────────────────────────────────────────────


@dataclass
class VectorIndex:
    """In-memory vector store: chunk records, their FAISS index, and the
    embedding model that produced the vectors."""
    # Chunk records of the form {"source": ..., "text": ...}
    chunks: list[dict] = field(default_factory=list)
    # faiss.IndexFlatIP holding one L2-normalised vector per chunk
    index: object = None
    # The SentenceTransformer instance used to encode the chunks
    embedder: object = None


def _chunk_text(source: str, text: str) -> list[dict]:
    """
    Split text into overlapping chunks that respect sentence boundaries.

    Instead of slicing at a fixed character offset (which cuts mid-sentence),
    we:
      1. Split the document into sentences / paragraphs.
      2. Greedily accumulate sentences until CHUNK_SIZE is reached.
      3. For the next chunk, back up by ~CHUNK_OVERLAP chars worth of sentences
         so consecutive chunks share context at their boundaries.

    A single sentence longer than CHUNK_SIZE is hard-split at character
    offsets with CHUNK_OVERLAP overlap. (BUGFIX: this used to be guarded by
    ``if j == i``, which was unreachable — the accumulation loop always
    consumed at least one sentence, so oversized sentences leaked through
    as single oversized chunks. The check now runs before accumulation.)

    Returns a list of {"source": source, "text": chunk} dicts.
    """
    # Normalise runs of spaces/tabs while preserving paragraph breaks
    text = _re.sub(r'[ \t]+', ' ', text).strip()
    sentences = [s.strip() for s in _SENTENCE_SPLIT.split(text) if s.strip()]

    chunks: list[dict] = []
    i = 0

    while i < len(sentences):
        # Oversized single sentence — hard-split it with character overlap.
        if len(sentences[i]) > CHUNK_SIZE:
            sent = sentences[i]
            step = CHUNK_SIZE - CHUNK_OVERLAP
            for k in range(0, len(sent), step):
                part = sent[k: k + CHUNK_SIZE]
                if part.strip():
                    chunks.append({"source": source, "text": part})
            i += 1
            continue

        # Accumulate sentences until we hit the size limit
        parts: list[str] = []
        total = 0
        j = i
        while j < len(sentences):
            slen = len(sentences[j])
            if parts and total + slen > CHUNK_SIZE:
                break
            parts.append(sentences[j])
            total += slen + 1   # +1 for the space we'll join with
            j += 1

        chunk_text = " ".join(parts)
        if chunk_text.strip():
            chunks.append({"source": source, "text": chunk_text})

        # Slide forward, but overlap by backtracking ~CHUNK_OVERLAP chars
        overlap_chars = 0
        next_i = j
        for k in range(j - 1, i, -1):
            overlap_chars += len(sentences[k]) + 1
            if overlap_chars >= CHUNK_OVERLAP:
                next_i = k
                break

        i = max(i + 1, next_i)   # always advance at least one sentence

    return chunks


def build_index(docs: list[dict]) -> VectorIndex:
    """
    Build a fresh VectorIndex from a list of {"source", "text"} dicts.

    Chunks every document, embeds the chunks with the cached model, and
    stores the L2-normalised vectors in a FAISS inner-product index (which
    makes inner product equivalent to cosine similarity).

    Raises ValueError when no chunks can be extracted from the input.
    """
    import faiss

    model = _get_model()  # reuse cached singleton — no reload cost

    # Chunk every document up front.
    all_chunks: list[dict] = []
    for doc in docs:
        all_chunks += _chunk_text(doc["source"], doc["text"])

    if not all_chunks:
        raise ValueError("No text chunks could be extracted from the uploaded files.")

    print(f"[Embedder] Embedding {len(all_chunks)} chunks...")
    vectors = model.encode(
        [c["text"] for c in all_chunks],
        show_progress_bar=False,
        batch_size=32,
    )
    vectors = np.array(vectors, dtype="float32")

    # Normalise so the inner-product index scores cosine similarity.
    faiss.normalize_L2(vectors)
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)

    print(f"[Embedder] Index built: {index.ntotal} vectors, dim={dim}")
    return VectorIndex(chunks=all_chunks, index=index, embedder=model)


def add_to_index(vector_index: VectorIndex, docs: list[dict]) -> VectorIndex:
    """Incrementally add new docs to an existing index.

    Embeds new chunks with the index's own embedder and applies the same
    L2 normalisation as build_index, so scores stay comparable.

    Returns the (mutated) vector_index for convenience.
    """
    import faiss
    # numpy already imported at module level — no duplicate import needed

    new_chunks = []
    for doc in docs:
        new_chunks.extend(_chunk_text(doc["source"], doc["text"]))

    # Nothing to add: skip encode/FAISS entirely. encode([]) yields a
    # zero-row array of the wrong shape, which would crash in
    # normalize_L2 / index.add. (build_index guards this case; we mirror
    # it here, but a no-op is friendlier than raising for incremental adds.)
    if not new_chunks:
        return vector_index

    texts = [c["text"] for c in new_chunks]
    embeddings = vector_index.embedder.encode(texts, show_progress_bar=False, batch_size=32)
    embeddings = np.array(embeddings, dtype="float32")
    faiss.normalize_L2(embeddings)  # Keep consistent with cosine index

    vector_index.index.add(embeddings)
    vector_index.chunks.extend(new_chunks)
    return vector_index