# baya1116's picture
# Add runtime/ scripts (SPChat + BGE RAG) for deployment
# 28178a3 verified
"""BGE-small-based retrieval helper for injecting external knowledge into SPChat.

Validated end-to-end: BGE-small picks the right chunk (sim ≈ 0.85+ for the
matched topic, > 0.4 margin over distractors), and SPChat with rp=1.0/nr=0
faithfully copies facts/numbers/names from the injected context.

Usage:
    from runtime.spchat import SPChat
    from runtime.rag import BGERetriever

    chat = SPChat()
    retriever = BGERetriever()
    retriever.add_documents([...corpus chunks...])
    hits = retriever.search("How many channels does Zephyric use?", top_k=2)
    prompt = retriever.build_rag_prompt(hits, "How many channels does Zephyric use?")
    state = chat.start_session("Use the Context to answer.")
    answer = chat.turn(state, prompt, rp=1.0, nr=0)  # ← fact-faithful decode
"""
from __future__ import annotations
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
class BGERetriever:
    """Thin wrapper around BGE-small (33M params, 384-dim) for cosine retrieval.

    Documents are embedded once when added and kept in memory as L2-normalised
    rows, so a search is a single matrix-vector product against the
    (also normalised) query embedding.

    Attributes:
        dev: torch device the encoder runs on.
        tok: tokenizer loaded from ``model_id``.
        model: encoder in eval mode with gradients disabled.
        documents: indexed texts, in insertion order.
        embeddings: float32 array with one row per entry of ``documents``.
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
    # Number of texts sent through the model per forward pass; bounds peak
    # memory when a large corpus is indexed in one add_documents() call.
    ENCODE_BATCH_SIZE = 64

    def __init__(self, model_id: str = DEFAULT_MODEL, device: str = "cpu"):
        """Load tokenizer and encoder for ``model_id`` and start an empty index.

        Args:
            model_id: Hugging Face model identifier or local path.
            device: torch device string the encoder is moved to.
        """
        self.dev = torch.device(device)
        self.tok = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.dev).eval()
        # Inference only: freeze every parameter.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.documents: List[str] = []
        self.embeddings: np.ndarray = np.zeros((0, self._dim()), dtype=np.float32)

    def _dim(self) -> int:
        """Return the encoder's hidden size (the embedding dimension)."""
        return int(self.model.config.hidden_size)

    @torch.no_grad()
    def _encode(self, texts: List[str], is_query: bool = False) -> np.ndarray:
        """Embed ``texts`` into L2-normalised float32 row vectors.

        Args:
            texts: Strings to embed.
            is_query: When true, each text is prefixed with the retrieval
                instruction before tokenisation.

        Returns:
            Array of shape ``(len(texts), hidden_size)``; empty input yields
            an array with zero rows.
        """
        if not texts:
            return np.zeros((0, self._dim()), dtype=np.float32)
        if is_query:
            texts = [f"Represent this sentence for searching relevant passages: {t}" for t in texts]

        pieces = []
        step = self.ENCODE_BATCH_SIZE
        for start in range(0, len(texts), step):
            inputs = self.tok(
                texts[start:start + step],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            ).to(self.dev)
            out = self.model(**inputs)
            emb = out.last_hidden_state[:, 0]  # CLS pooling per BGE recommendation
            emb = F.normalize(emb, p=2, dim=1)
            pieces.append(emb.float().cpu().numpy())
        return np.concatenate(pieces, axis=0)

    # -----------------------------------------------------------------------
    def add_documents(self, texts: List[str]) -> None:
        """Embed ``texts`` and append them to the index.

        A bare string is treated as a single document, and any other iterable
        is materialised first, so ``documents`` and ``embeddings`` always grow
        by the same number of entries. An empty input is a no-op.

        Args:
            texts: Documents (chunks) to index.
        """
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            return
        # Encode before touching state so a failure leaves the index intact.
        new_emb = self._encode(texts, is_query=False)
        self.documents.extend(texts)
        if len(self.embeddings) == 0:
            self.embeddings = new_emb
        else:
            self.embeddings = np.concatenate([self.embeddings, new_emb], axis=0)

    def search(
        self,
        query: str,
        top_k: int = 2,
        min_sim: float = 0.0,
    ) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(text, cosine_similarity)`` pairs, best first.

        Args:
            query: Search text; embedded with the query instruction prefix.
            top_k: Maximum number of hits. Values <= 0 return an empty list.
            min_sim: Hits scoring below this are dropped (use as a confidence
                gate for trigger logic).

        Returns:
            List sorted by descending similarity; empty when nothing is
            indexed, ``top_k`` is not positive, or no hit reaches ``min_sim``.
        """
        # A negative top_k would otherwise slice "all but the last |top_k|".
        if top_k <= 0 or len(self.documents) == 0:
            return []
        q = self._encode([query], is_query=True)[0]
        sims = self.embeddings @ q
        # Stable sort keeps insertion order among equal scores.
        order = np.argsort(-sims, kind="stable")
        out = []
        for idx in order[:top_k]:
            if sims[idx] < min_sim:
                break  # scores are descending, so nothing later can pass
            out.append((self.documents[idx], float(sims[idx])))
        return out

    @staticmethod
    def build_rag_prompt(retrieved: List[Tuple[str, float]], query: str) -> str:
        """Format retrieved chunks and the query into a Context/Question prompt.

        Args:
            retrieved: ``(text, score)`` pairs, e.g. from ``search``; the
                scores are ignored.
            query: The user question.

        Returns:
            ``query`` unchanged when ``retrieved`` is empty; otherwise the
            chunks joined by a ``---`` separator line under ``Context:``,
            followed by a blank line and ``Question: <query>``.
        """
        if not retrieved:
            return query
        ctx = "\n---\n".join(text for text, _ in retrieved)
        return f"Context:\n{ctx}\n\nQuestion: {query}"