"""
Session-level RAG with graceful FAISS fallback.

- If FAISS is installed, uses a FAISS inner-product index (IndexFlatIP) over
  L2-normalized embeddings, i.e. cosine similarity.
- If FAISS is missing, falls back to pure NumPy cosine similarity.
- Designed to work with extract_text_from_files(...) outputs:
    * list[str]
    * list[dict] with keys like "text" or "content"
"""

from __future__ import annotations

import logging
import hashlib
from typing import Iterable, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

# ----- Optional FAISS -----
try:
    import faiss  # type: ignore
    _HAS_FAISS = True
except Exception:
    logging.warning(
        "FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. "
        "Install faiss-cpu or faiss-gpu for faster retrieval."
    )
    faiss = None  # type: ignore
    _HAS_FAISS = False


def _normalize_rows(x: np.ndarray) -> np.ndarray:
    """L2 normalize row vectors; avoids division by zero."""
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
    return x / norms


def _hash_text(s: str) -> str:
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


def _coerce_texts(items: Iterable) -> List[str]:
    """Accept str or dict items, pull text safely, drop empties, dedupe by hash."""
    out: List[str] = []
    seen: set = set()
    for it in items or []:
        if isinstance(it, str):
            txt = it.strip()
        elif isinstance(it, dict):
            txt = (it.get("text") or it.get("content") or "").strip()
        else:
            txt = ""
        if not txt:
            continue
        h = _hash_text(txt)
        if h in seen:
            continue
        seen.add(h)
        out.append(txt)
    return out
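

# Example (illustrative only): mixed inputs collapse to unique, non-empty
# strings; dicts contribute their "text"/"content" values and anything else
# is dropped:
#   _coerce_texts(["hello", {"text": "hello"}, {"content": "world"}, 42])
#   == ["hello", "world"]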


def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
    """Lightweight char-based chunking to improve recall on long docs."""
    if len(text) <= max_chars:
        return [text]
    # Clamp overlap so the window always advances (guards against an
    # infinite loop if overlap >= max_chars is ever passed).
    overlap = min(overlap, max_chars - 1)
    chunks: List[str] = []
    i = 0
    while i < len(text):
        chunks.append(text[i : i + max_chars])
        i += max_chars - overlap
    return chunks
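

# Example (illustrative only): with the defaults, a 2,500-char string yields
# three chunks starting at offsets 0, 1050, and 2100, since each step
# advances by max_chars - overlap = 1050 characters.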


class SessionRAG:
    """
    Ephemeral per-session retriever.

    Methods:
      - add_docs(items): add strings or dicts({"text"/"content": ...}); returns count of new chunks
      - retrieve(query, k=5): returns list[str] of top-k chunks
      - clear(): drop index & memory
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []
        self.embeddings: Optional[np.ndarray] = None  # shape: (N, D)
        self.index = None  # FAISS index if available
        self.dim: Optional[int] = None

    # ---------- Private helpers ----------
    def _fit_faiss(self) -> None:
        if not _HAS_FAISS or self.embeddings is None:
            return
        # Use inner product on normalized vectors (cosine similarity)
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        # Build IP index
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _ensure_embeddings(self) -> None:
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        # Re-encode the full corpus on every change: simple and always
        # consistent, at the cost of O(N) encoding work per add_docs() call.
        embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
        self.embeddings = np.asarray(embs, dtype="float32")
        # Build FAISS if available
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None

    # ---------- Public API ----------
    def add_docs(self, items: Iterable) -> int:
        """
        Add a batch of texts or dicts with 'text'/'content'.
        Applies basic chunking and deduplication.
        Returns the number of chunks added.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0

        # Chunk each long text into manageable pieces
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))

        # Deduplicate vs existing memory
        existing_hashes = { _hash_text(t) for t in self.texts }
        added = 0
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            self.texts.append(c)
            existing_hashes.add(h)
            added += 1

        # Recompute embeddings/index
        if added > 0:
            self._ensure_embeddings()

        return added

    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to k most similar chunks for the query."""
        if not query or not self.texts or self.embeddings is None:
            return []

        # Encode and L2-normalize the query so dot products give cosine similarity
        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))

        # FAISS path (inner product on normalized vectors)
        if _HAS_FAISS and self.index is not None:
            _scores, ids = self.index.search(q, min(k, len(self.texts)))
            # FAISS pads with -1 when fewer than k neighbors exist; filter those
            idxs = [i for i in ids[0] if 0 <= i < len(self.texts)]
            return [self.texts[i] for i in idxs]

        # NumPy fallback: cosine similarity via dot product on normalized vectors
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]  # shape: (N,)
        top_idx = np.argsort(-sims)[: min(k, len(self.texts))]
        return [self.texts[i] for i in top_idx]

    def clear(self) -> None:
        """Drop all in-memory data for this session."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None
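

if __name__ == "__main__":
    # Minimal usage sketch, assuming the "all-MiniLM-L6-v2" model can be
    # downloaded; the sample strings below are illustrative placeholders.
    rag = SessionRAG()
    added = rag.add_docs(
        [
            "FAISS is a library for efficient similarity search.",
            {"text": "Sentence-Transformers produces dense sentence embeddings."},
            {"content": "Cosine similarity compares direction, not magnitude."},
        ]
    )
    print(f"Indexed {added} chunks")
    for hit in rag.retrieve("How do I compare embedding vectors?", k=2):
        print("-", hit)
    rag.clear()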