from __future__ import annotations

import re
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List

import faiss
import numpy as np
import tiktoken
from openai import OpenAI


def l2_normalize(mat: np.ndarray) -> np.ndarray:
    """Row-wise L2 normalize for cosine similarity via inner product."""
    norm = np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12
    return mat / norm

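# Example: l2_normalize(np.array([[3.0, 4.0]])) -> array([[0.6, 0.8]]).
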
def batch(iterable: Iterable[Any], n: int = 128) -> Iterator[List[Any]]:
    """Yield lists of size n from an iterable (the last one may be shorter)."""
    buf: List[Any] = []
    for x in iterable:
        buf.append(x)
        if len(buf) >= n:
            yield buf
            buf = []
    if buf:
        yield buf

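# Example: list(batch(range(5), n=2)) -> [[0, 1], [2, 3], [4]].

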
@dataclass
class Chunk:
    """A single chunk of the document, with token offsets for traceability."""
    id: int
    text: str
    start_token: int
    end_token: int


class OpenAIEmbedRAG:
    """
    Retrieval module using OpenAI Embeddings + FAISS (inner product over
    L2-normalized vectors, i.e. cosine similarity).

    Design notes:
    - Single-pass tokenization for the whole document (no repeated encode/decode).
    - Chunk.text is ALWAYS a string (never None) to avoid downstream NoneType errors.
    - Graceful degradation: empty input => no index; search() returns [].
    - Optional MMR re-ranking (diversity) via mmr_search().
    """

    def __init__(self,
                 model: str = "text-embedding-3-small",
                 chunk_size_tokens: int = 800,
                 overlap_tokens: int = 100,
                 batch_size: int = 256,
                 openai_key: str | None = None):
        # If openai_key is None, the OpenAI client falls back to the
        # OPENAI_API_KEY environment variable.
        self.client = OpenAI(api_key=openai_key)
        self.model = model
        self.batch_size = batch_size
        # cl100k_base is the tokenizer used by the text-embedding-3-* models.
        self.enc = tiktoken.get_encoding("cl100k_base")
        self.chunk_size = max(1, int(chunk_size_tokens))
        self.overlap = max(0, int(overlap_tokens))
        if self.overlap >= self.chunk_size:
            # Clamp to a quarter of the chunk size so the sliding window
            # always makes forward progress (stride stays positive).
            self.overlap = max(0, self.chunk_size // 4)

        self._doc_token_ids: List[int] | None = None
        self.chunks: List[Chunk] = []
        self.index: faiss.IndexFlatIP | None = None
        self._emb_dim: int | None = None
        self._emb_matrix: np.ndarray | None = None

    def _clean_text(self, text: str) -> str:
        """
        Light normalization:
        - Collapse runs of whitespace (including newlines and tabs) to one space.
        - Drop any remaining non-printable control characters.
        - Trim leading/trailing spaces.
        """
        text = re.sub(r"\s+", " ", text or "")
        # The collapse above already removed newlines and tabs, so a plain
        # printability filter is sufficient here.
        text = "".join(ch for ch in text if ch.isprintable())
        return text.strip()

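    # Example: _clean_text("  a\n\nb\t") -> "a b".
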
    def _tokenize(self, text: str) -> List[int]:
        return self.enc.encode(text)

    def _detokenize(self, ids: List[int]) -> str:
        return self.enc.decode(ids)

    def chunk_text(self, text: str) -> List[Chunk]:
        """
        Tokenize once and create overlapping windows of token ids.
        Each Chunk stores its decoded text and token offsets.
        """
        self._doc_token_ids = self._tokenize(text)
        total = len(self._doc_token_ids)
        chunks: List[Chunk] = []
        if total == 0:
            return chunks

        print(f"[RAG] Total tokens: {total}. Chunk size: {self.chunk_size}, overlap: {self.overlap}")

        stride = self.chunk_size - self.overlap
        i, cid = 0, 0
        while i < total:
            j = min(i + self.chunk_size, total)
            ids_slice = self._doc_token_ids[i:j]
            txt = self._detokenize(ids_slice)
            chunks.append(Chunk(id=cid, text=txt, start_token=i, end_token=j))
            cid += 1
            if j == total:
                break
            i += stride
        return chunks

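    # Worked example: with chunk_size=800 and overlap=100 the stride is 700,
    # so windows cover tokens [0, 800), [700, 1500), [1400, 2200), ... with
    # the final window truncated at the document end.
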
    def _embed_texts(self, texts: List[str], max_retries: int = 3) -> np.ndarray:
        """
        Call OpenAI Embeddings with encoding_format='float'.
        Returns a float32 matrix with rows aligned to input order.
        """
        for attempt in range(max_retries):
            try:
                resp = self.client.embeddings.create(
                    model=self.model,
                    input=texts,
                    encoding_format="float",
                )
                # Re-order by item.index so rows match the input order even
                # if the API returns items out of order.
                vecs = [None] * len(resp.data)
                for item in resp.data:
                    vecs[item.index] = np.array(item.embedding, dtype=np.float32)
                arr = np.vstack(vecs)
                if self._emb_dim is None:
                    self._emb_dim = arr.shape[1]
                return arr
            except Exception:
                if attempt == max_retries - 1:
                    raise
                # Linear backoff before retrying.
                time.sleep(0.8 * (attempt + 1))

    def build(self, text: str):
        """
        Clean -> chunk -> embed -> build an IP index on normalized vectors.
        Graceful if text is empty: index remains None and chunks empty.
        """
        text = self._clean_text(text)
        self.chunks = self.chunk_text(text)
        if not self.chunks:
            self.index = None
            self._emb_matrix = None
            return

        all_vecs = []
        for chunk_batch in batch([c.text for c in self.chunks], n=self.batch_size):
            all_vecs.append(self._embed_texts(chunk_batch))

        mat = np.vstack(all_vecs).astype(np.float32)
        mat = l2_normalize(mat)
        self._emb_matrix = mat

        # Inner product over unit vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(mat.shape[1])
        self.index.add(mat)

    def search(self, query: str, topk: int = 6) -> List[Dict[str, Any]]:
        """
        Return top-k chunks by cosine similarity (via IP on normalized vectors).
        If the index hasn't been built or the doc is empty, returns [].
        """
        # Explicit None check: FAISS index objects should not be relied on
        # in a boolean context.
        if self.index is None or not self.chunks:
            return []

        q = self._clean_text(query)
        if not q:
            return []

        qv = l2_normalize(self._embed_texts([q]))
        D, I = self.index.search(qv.astype(np.float32), max(1, int(topk)))
        results = []
        for rank, idx in enumerate(I[0]):
            if idx == -1:  # FAISS pads with -1 when fewer than topk hits exist
                continue
            ch = self.chunks[int(idx)]
            results.append({
                "id": ch.id,
                "score": float(D[0][rank]),
                "text": ch.text,
                "start_token": ch.start_token,
                "end_token": ch.end_token,
            })
        return results

    def mmr_search(self, query: str, topk: int = 6, fetch_k: int | None = None,
                   lambda_mult: float = 0.5) -> List[Dict[str, Any]]:
        """
        Maximal Marginal Relevance re-ranking.
        - fetch_k: number of initial candidates to consider (defaults to 4*topk).
        - lambda_mult in [0,1]: 1 emphasizes relevance; 0 emphasizes diversity.
        """
        if self._emb_matrix is None or not self.chunks:
            return []

        q = self._clean_text(query)
        if not q:
            return []

        qv = l2_normalize(self._embed_texts([q]))[0]

        # Cosine relevance of every chunk to the query (rows are unit vectors).
        rel = self._emb_matrix @ qv

        N = len(self.chunks)
        k = max(1, int(topk))
        m = min(N, int(fetch_k)) if fetch_k else min(N, 4 * k)

        # Top-m candidates by relevance, sorted descending.
        cand_idx = np.argpartition(-rel, m - 1)[:m]
        cand_idx = cand_idx[np.argsort(-rel[cand_idx])]

        selected: List[int] = []
        selected_set = set()

        for _ in range(min(k, m)):
            if not selected:
                # Seed with the single most relevant candidate.
                best = int(cand_idx[0])
                selected.append(best)
                selected_set.add(best)
                continue

            # Max similarity of each candidate to anything already selected.
            S = self._emb_matrix[selected]
            cand_vecs = self._emb_matrix[cand_idx]
            sims = cand_vecs @ S.T
            max_sims = sims.max(axis=1)

            # MMR score: trade relevance off against redundancy.
            scores = lambda_mult * rel[cand_idx] - (1 - lambda_mult) * max_sims

            # Take the best-scoring candidate not yet selected.
            order = np.argsort(-scores)
            for j in order:
                idx_j = int(cand_idx[j])
                if idx_j not in selected_set:
                    selected.append(idx_j)
                    selected_set.add(idx_j)
                    break

        out = []
        for idx in selected:
            ch = self.chunks[idx]
            out.append({
                "id": ch.id,
                "score": float(rel[idx]),
                "text": ch.text,
                "start_token": ch.start_token,
                "end_token": ch.end_token,
            })
        return out
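

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes OPENAI_API_KEY is set in the
    # environment and the OpenAI API is reachable; the sample text, query,
    # and parameters below are illustrative only.
    rag = OpenAIEmbedRAG(chunk_size_tokens=200, overlap_tokens=40)
    rag.build("FAISS is a library for efficient similarity search. " * 200)
    for hit in rag.search("What is FAISS used for?", topk=3):
        print(hit["id"], round(hit["score"], 3), hit["text"][:60])
    for hit in rag.mmr_search("What is FAISS used for?", topk=3, lambda_mult=0.7):
        print("MMR:", hit["id"], round(hit["score"], 3))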