"""BGE-small-based retrieval helper for injecting external knowledge into SPChat.

Validated end-to-end: BGE-small picks the right chunk (sim ≈ 0.85+ for matched
topic, > 0.4 margin over distractors) and SPChat with rp=1.0/nr=0 faithfully
copies facts/numbers/names from the injected context.

Usage:
    from runtime.spchat import SPChat
    from runtime.rag import BGERetriever

    chat = SPChat()
    retriever = BGERetriever()
    retriever.add_documents([...corpus chunks...])

    hits = retriever.search("How many channels does Zephyric use?", top_k=2)
    prompt = retriever.build_rag_prompt(hits, "How many channels does Zephyric use?")
    state = chat.start_session("Use the Context to answer.")
    answer = chat.turn(state, prompt, rp=1.0, nr=0)  # ← fact-faithful decode
"""
|
|
| from __future__ import annotations |
|
|
| from typing import List, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModel, AutoTokenizer |
|
|
|
|
class BGERetriever:
    """Thin wrapper around BGE-small (33M params, 384-dim) for cosine retrieval.

    Documents are embedded when they are added and kept in memory as a matrix
    of L2-normalised row vectors (``self.embeddings``), aligned index-for-index
    with ``self.documents``. A query is scored against every stored document
    with a single matrix-vector product.
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"

    # Instruction prepended to queries (never to documents) before encoding.
    QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

    def __init__(self, model_id: str = DEFAULT_MODEL, device: str = "cpu"):
        """Load the tokenizer and encoder and start with an empty index.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            device: Torch device string the encoder is moved to.
        """
        self.dev = torch.device(device)
        self.tok = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.dev).eval()
        # Inference only: freeze the weights so they never require gradients.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.documents: List[str] = []
        self.embeddings: np.ndarray = np.zeros((0, self._dim()), dtype=np.float32)

    def _dim(self) -> int:
        """Return the embedding width, read from the model's ``hidden_size``."""
        return int(self.model.config.hidden_size)

    @torch.no_grad()
    def _encode(
        self,
        texts: List[str],
        is_query: bool = False,
        batch_size: int = 64,
    ) -> np.ndarray:
        """Embed ``texts`` into L2-normalised vectors.

        Args:
            texts: Strings to embed.
            is_query: When true, ``QUERY_PREFIX`` is prepended to every text.
            batch_size: Maximum number of texts per forward pass; values below
                1 are treated as 1. Keeps peak memory independent of corpus
                size.

        Returns:
            Array with one row per input text; shape ``(0, dim)`` when
            ``texts`` is empty.
        """
        if is_query:
            texts = [f"{self.QUERY_PREFIX}{t}" for t in texts]
        if not texts:
            # Nothing to embed; skip the tokenizer/model call entirely.
            return np.zeros((0, self._dim()), dtype=np.float32)

        step = max(1, int(batch_size))
        chunks = []
        for start in range(0, len(texts), step):
            inputs = self.tok(
                texts[start:start + step],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            ).to(self.dev)
            out = self.model(**inputs)
            # Sentence embedding = last hidden state at sequence position 0
            # (presumably the [CLS] token for this tokenizer).
            emb = out.last_hidden_state[:, 0]
            # Unit-length rows make the dot product in search() a cosine.
            emb = F.normalize(emb, p=2, dim=1)
            chunks.append(emb.cpu().numpy())
        return np.concatenate(chunks, axis=0)

    def add_documents(self, texts: List[str]) -> None:
        """Embed ``texts`` and append them to the index.

        A single string is treated as one document rather than iterated
        character by character, any other iterable is materialised once, and
        an empty input is a no-op.

        Args:
            texts: Document chunks to add.
        """
        if isinstance(texts, str):
            # Otherwise documents.extend() would add one entry per character
            # while only a single embedding row is produced.
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            return

        # Encode first so a failure leaves documents/embeddings untouched.
        new_emb = self._encode(texts, is_query=False)
        self.documents.extend(texts)
        if len(self.embeddings) == 0:
            self.embeddings = new_emb
        else:
            self.embeddings = np.concatenate([self.embeddings, new_emb], axis=0)

    def search(
        self,
        query: str,
        top_k: int = 2,
        min_sim: float = 0.0,
    ) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` best-matching documents for ``query``.

        Args:
            query: Search text; encoded with the query instruction prefix.
            top_k: Maximum number of hits. Zero or negative yields no hits.
            min_sim: Hits scoring below this similarity are dropped (use as a
                confidence gate for trigger logic).

        Returns:
            List of ``(text, cosine_similarity)`` sorted by descending
            similarity; empty when the index is empty.
        """
        if top_k <= 0 or not self.documents:
            # Avoid a pointless encoder call (and a negative-slice surprise).
            return []

        q = self._encode([query], is_query=True)[0]
        sims = self.embeddings @ q
        order = np.argsort(-sims)
        out = []
        for idx in order[:top_k]:
            # Scores are descending, so everything after the first miss also
            # falls below the threshold.
            if sims[idx] < min_sim:
                break
            out.append((self.documents[idx], float(sims[idx])))
        return out

    @staticmethod
    def build_rag_prompt(retrieved: List[Tuple[str, float]], query: str) -> str:
        """Format retrieved chunks and the query into a Context/Question prompt.

        Args:
            retrieved: ``(text, score)`` pairs as returned by ``search``; the
                scores are ignored.
            query: The user question.

        Returns:
            ``query`` unchanged when nothing was retrieved; otherwise a prompt
            with a ``Context:`` section (chunks separated by ``---`` lines)
            followed by a ``Question:`` line.
        """
        if not retrieved:
            return query
        ctx = "\n---\n".join(text for text, _ in retrieved)
        return f"Context:\n{ctx}\n\nQuestion: {query}"
|
|