# baya1116's picture
# Add runtime/ scripts (SPChat + BGE RAG) for deployment
# 28178a3 verified
"""SP-evict bounded-memory multi-turn chat runtime.
Loads a frozen LLM + trained AttnPoolSP from a checkpoint and provides a
multi-turn chat interface that:
- keeps a system prompt as persistent KV cache,
- per turn, feeds <|User|>...<|Assistant|> tokens then decodes the response greedily (argmax after repetition penalty / no-repeat n-gram bans),
- maintains a bounded cumulative survivor set of distant tokens
(drop bottom-64 by attention mass each chunk),
- recomputes a 32-token soft prompt (SP; the token count comes from the checkpoint's `num_soft_tokens`) each chunk from the survivors,
- keeps the LLM's per-chunk KV bounded at ~(mq + 32 + rw + chunk).
Validated decoding settings:
- conversational chat: rp=1.15, nr=4
- fact-faithful (RAG): rp=1.0, nr=0
EOS must be prepended between turns (handled automatically inside `turn()`).
CoT (`<think>...</think>`) is NOT stripped — stripping breaks the R1 distill's
chat-template expectation. Distant CoT gets compressed into the SP gist
naturally (in-distribution for the pooler, trained on dolphin-r1 CoT).
"""
from __future__ import annotations
import os
import sys
from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
# Repo root must be on sys.path for AttnPoolSP / sinusoidal imports.
# _ROOT is the parent of this file's directory. Adding it here lets the
# `train_attnpool` import below resolve regardless of the caller's working
# directory. These statements must run before that import (hence noqa: E402).
_HERE = os.path.dirname(os.path.abspath(__file__))
_ROOT = os.path.dirname(_HERE)
if _ROOT not in sys.path:
    # Prepend, so the repo's module takes precedence over same-named ones.
    sys.path.insert(0, _ROOT)
from train_attnpool import AttnPoolSP, sinusoidal  # noqa: E402
# ---------------------------------------------------------------------------
# Sampling helpers (repetition penalty + no-repeat n-gram, greedy argmax pick)
# ---------------------------------------------------------------------------
def _rep_pen(logits: torch.Tensor, gen: list, p: float) -> torch.Tensor:
if p == 1.0 or not gen:
return logits
ids = torch.tensor(gen, device=logits.device)
sel = logits.gather(-1, ids.unsqueeze(0).expand(logits.size(0), -1))
sel = torch.where(sel > 0, sel / p, sel * p)
return logits.scatter(-1, ids.unsqueeze(0).expand(logits.size(0), -1), sel)
def _banned_ngrams(gen: list, n: int) -> set:
if n <= 0 or len(gen) < n - 1:
return set()
pre = tuple(gen[-(n - 1):])
b = set()
for i in range(len(gen) - n + 1):
ng = tuple(gen[i:i + n])
if ng[:-1] == pre:
b.add(ng[-1])
return b
def _pick(logits: torch.Tensor, gen: list, rp: float, nr: int) -> int:
    """Greedily choose the next token id from batch row 0 of ``logits``.

    First the repetition penalty ``rp`` is applied to every id in ``gen``.
    Then every id that would complete an already-seen ``nr``-gram is masked to
    ``-inf``, and the argmax is taken. The caller's ``logits`` is not modified
    (a clone is penalised and masked).
    """
    scores = _rep_pen(logits.clone(), gen, rp)
    banned = _banned_ngrams(gen, nr)
    if banned:
        scores[0, sorted(banned)] = float("-inf")
    return int(torch.argmax(scores[0]).item())
# ---------------------------------------------------------------------------
# SPChat
# ---------------------------------------------------------------------------
class SPChat:
    """Bounded-memory chat with cumulative bottom-N eviction of distant tokens.

    After the persistent system-prompt KV cache (``mq`` positions), every chunk
    the LLM is fed a fresh prefix ``[SP (S soft tokens) | last rw raw tokens]``.
    It then feeds/decodes up to ``chunk_size`` tokens on top of that prefix.
    Tokens that leave the raw window become "survivors". Each chunk, the
    ``drop_per_chunk`` survivors with the lowest pooler attention mass are
    evicted and the SP is re-pooled from the rest. The KV cache is cropped back
    to ``mq`` after every chunk.
    """

    DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    DEFAULT_CKPT = "checkpoints/ap32_large.pt"

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL,
        ckpt_path: str = DEFAULT_CKPT,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        num_threads: int = 4,
        rw: int = 512,
        drop_per_chunk: int = 64,
        chunk_size: int = 64,
    ):
        """Load the tokenizer, the frozen LLM and the trained AttnPoolSP pooler.

        Args:
            model_id: Hugging Face model id of the causal LM.
            ckpt_path: pooler checkpoint. A relative path is first tried
                against the repo root, then used as given (cwd-relative).
            device: torch device string used for both LLM and pooler.
            dtype: LLM load dtype and dtype of the embeddings fed to it. The
                pooler is not cast to it; pooler inputs are built with
                ``.float()``.
            num_threads: intra-op thread count, applied only on ``"cpu"``.
            rw: raw window size (most recent tokens fed verbatim).
            drop_per_chunk: number of survivors evicted per chunk.
            chunk_size: tokens fed/decoded between prefix rebuilds.
        """
        if device == "cpu":
            torch.set_num_threads(num_threads)
        self.dev = torch.device(device)
        self.dt = dtype
        self.rw = rw
        self.drop = drop_per_chunk
        self.C = chunk_size

        self.tok = AutoTokenizer.from_pretrained(model_id)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_id, dtype=dtype, attn_implementation="sdpa",
        ).to(self.dev)
        self.llm.config.use_cache = True
        for p in self.llm.parameters():
            p.requires_grad_(False)
        self.llm.eval()
        self.embed = self.llm.get_input_embeddings()
        self.H = self.llm.config.hidden_size
        self.EOS = self.tok.eos_token_id

        # checkpoint -> pooler
        if not os.path.isabs(ckpt_path):
            # Prefer a path relative to the repo root; otherwise keep it as-is.
            cand = os.path.join(_ROOT, ckpt_path)
            if os.path.exists(cand):
                ckpt_path = cand
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.S = ck["cfg"]["num_soft_tokens"]
        self.pooler = AttnPoolSP(
            self.H, self.S,
            ck["args"]["heads"], ck["args"]["layers"], ck["args"]["ffn"],
            ck["cfg"]["target_norm"],
        ).to(self.dev).eval()
        self.pooler.load_state_dict(ck["pooler"])

    # -----------------------------------------------------------------------
    # Internal: attention-pooling SP from the kept survivors
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def _pool(self, past_pe: torch.Tensor, want_score: bool = False):
        """Run the pooler blocks by hand to get the SP and per-token attention mass.

        Args:
            past_pe: survivor embeddings + positional encodings, shape
                (1, L, H). L may be 0, in which case cross-attention is
                skipped and only the learned queries are refined.
            want_score: if True, accumulate cross-attention weights into ``mass``.

        Returns:
            ``(sp, mass)``:
            - ``sp`` has one row per pooler query (S of them; callers rely on
              this). Each row is rescaled to norm ``|out_scale|``.
            - ``mass`` is a length-L tensor on ``past_pe``'s device. It holds
              the head-averaged cross-attention weight each key received,
              summed over queries and blocks. It is all zeros unless
              ``want_score`` is set. This assumes batch-first attention
              weights of shape (1, S, L).
        """
        q = self.pooler.query.unsqueeze(0)
        L = past_pe.size(1)
        # Allocate on the input's device: accumulating CUDA attention weights
        # into a CPU tensor raises a device-mismatch error.
        mass = torch.zeros(L, device=past_pe.device)
        for blk in self.pooler.blocks:
            if L > 0:
                qn = blk.lnq1(q)
                kn = blk.lnk(past_pe)
                a, w = blk.cross(qn, kn, kn, need_weights=want_score, average_attn_weights=True)
                if want_score and w is not None:
                    mass += w[0].sum(0)
                q = q + a
            qn = blk.lnq2(q)
            s, _ = blk.selfa(qn, qn, qn, need_weights=False)
            q = q + s
            q = q + blk.ffn(blk.lnq3(q))
        sp = self.pooler.ln_out(q)
        sp = sp / sp.norm(dim=-1, keepdim=True).clamp(min=1e-6) * self.pooler.out_scale.abs()
        return sp, mass

    def _embed_ids(self, ids: list) -> torch.Tensor:
        """Embed a flat list of token ids as a (1, len(ids), H) tensor."""
        return self.embed(torch.tensor([ids], device=self.dev))

    def _survivor_embeds(self, gen: list, kept: list, PE: torch.Tensor) -> torch.Tensor:
        """Build the pooler input for the survivors.

        Each row is the float token embedding plus the ``PE`` row at that
        kept position. Shape is (1, len(kept), H).
        """
        pos = torch.tensor(kept, device=PE.device)
        tok = self._embed_ids([gen[p] for p in kept]).float()[0]
        return (tok + PE[pos]).unsqueeze(0)

    def _evict(self, gen: list, kept: list, PE: torch.Tensor) -> list:
        """Drop the ``self.drop`` survivors with the lowest pooler attention mass.

        This is a no-op when there are at most ``self.drop`` survivors. The
        returned positions are ordered by descending mass (topk order), not
        by position.
        """
        if len(kept) <= self.drop:
            return kept
        _, mass = self._pool(self._survivor_embeds(gen, kept, PE), want_score=True)
        # .tolist() brings the indices to host memory, so this also works when
        # `mass` lives on an accelerator.
        order = torch.topk(mass, len(kept) - self.drop, largest=True).indices.tolist()
        return [kept[i] for i in order]

    def _survivor_sp(self, gen: list, kept: list, PE: torch.Tensor) -> torch.Tensor:
        """Pool the SP from the current survivors (query-only pooling if there are none)."""
        if kept:
            past = self._survivor_embeds(gen, kept, PE)
        else:
            past = torch.zeros(1, 0, self.H, device=self.dev)
        sp, _ = self._pool(past)
        return sp

    def _forward(self, embeds: torch.Tensor, start: int, cache) -> torch.Tensor:
        """Feed ``embeds`` (1, n, H) at absolute positions [start, start + n).

        Assumes ``cache`` currently holds exactly ``start`` positions; the
        model appends the new keys/values to it. Returns the float logits of
        the last fed position.
        """
        n = embeds.size(1)
        pos = torch.arange(start, start + n, device=self.dev)
        out = self.llm(
            inputs_embeds=embeds,
            attention_mask=torch.ones(1, start + n, device=self.dev),
            past_key_values=cache,
            position_ids=pos.unsqueeze(0),
            use_cache=True,
            cache_position=pos,
        )
        return out.logits[:, -1, :].float()

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def start_session(self, system_prompt: str = "") -> Dict[str, Any]:
        """Prefill the persistent KV cache with a system prompt.

        Returns:
            Session state dict with keys:
            - ``cache``: the DynamicCache.
            - ``mq``: number of persistent prompt positions.
            - ``gen``: token history.
            - ``kept``: survivor positions.
            - ``prev_end``: first position not yet moved to the survivor pool.
        """
        ids = self.tok(
            system_prompt, return_tensors="pt", add_special_tokens=True,
        ).input_ids.to(self.dev)
        mq = ids.size(1)
        cache = DynamicCache()
        # A tokenizer that adds no BOS yields zero ids for an empty prompt;
        # there is nothing to prefill in that case.
        if mq > 0:
            self._forward(self.embed(ids).to(self.dt), 0, cache)
        return {"cache": cache, "mq": mq, "gen": [], "kept": [], "prev_end": 0}

    @torch.no_grad()
    def turn(
        self,
        state: Dict[str, Any],
        user_msg: str,
        max_resp: int = 600,
        rp: float = 1.15,
        nr: int = 4,
    ) -> str:
        """Run one turn and return the assistant's response text.

        Feeds ``<|User|>{user_msg}<|Assistant|>``, preceded by EOS on every
        turn after the first to close the previous reply. It then decodes up
        to ``max_resp`` tokens greedily (argmax after penalties). Decoding
        stops early at EOS, which is not added to the history.

        Args:
            state: session dict from ``start_session``. Its ``gen``, ``kept``
                and ``prev_end`` are updated even if decoding is interrupted,
                and the KV cache is always cropped back to the system prompt.
            user_msg: user message text.
            max_resp: maximum number of decoded tokens.
            rp: repetition penalty (1.0 disables it).
            nr: no-repeat n-gram size (0 disables it).

        Decoding presets:
            - conversational chat: rp=1.15, nr=4 (default)
            - RAG/fact-faithful: rp=1.0, nr=0

        Returns:
            The decoded response, with special tokens skipped.
        """
        feed_text = f"<|User|>{user_msg}<|Assistant|>"
        feed_ids = self.tok(feed_text, return_tensors="pt", add_special_tokens=False).input_ids[0].tolist()
        if state["gen"]:  # not the first turn -> close prev turn with EOS
            feed_ids = [self.EOS] + feed_ids

        cache = state["cache"]
        mq = state["mq"]
        gen = state["gen"]  # appended to in place
        kept = list(state["kept"])  # private copy; committed in `finally`
        prev_end = state["prev_end"]
        # Positional encodings for every position this turn can reach, plus 64 spare rows.
        PE = sinusoidal(len(gen) + len(feed_ids) + max_resp + 64, self.H, self.dev)

        start_idx = len(gen)
        feed_idx = 0
        sampled = 0
        done = False
        try:
            while not done:
                c0 = len(gen)
                R = min(c0, self.rw)
                end = c0 - R
                # 1. positions that just left the raw window become survivors
                if end > prev_end:
                    kept.extend(range(prev_end, end))
                    prev_end = end
                # 2. cumulative bottom-DROP eviction by attention mass
                kept = self._evict(gen, kept, PE)
                # 3. SP from survivors; 4. forward the [SP + raw window] prefix
                parts = [self._survivor_sp(gen, kept, PE).to(self.dt)]
                if R > 0:
                    parts.append(self._embed_ids(gen[c0 - R:c0]).to(self.dt))
                prefix = torch.cat(parts, 1)
                last_logits = self._forward(prefix, mq, cache)
                npos = mq + prefix.size(1)
                # 5. inner loop: feed or decode up to C tokens
                for _ in range(self.C):
                    if feed_idx < len(feed_ids):
                        t = feed_ids[feed_idx]
                        feed_idx += 1
                    elif sampled < max_resp:
                        t = _pick(last_logits, gen, rp, nr)
                        sampled += 1
                        if t == self.EOS:
                            done = True
                            break
                    else:
                        done = True
                        break
                    gen.append(t)
                    last_logits = self._forward(self._embed_ids([t]).to(self.dt), npos, cache)
                    npos += 1
                cache.crop(mq)  # KV stays bounded across chunks
        finally:
            # Restore the bounded cache and commit bookkeeping even when
            # interrupted mid-chunk, so the next turn starts from a consistent
            # state (crop is a no-op if already at mq).
            cache.crop(mq)
            state["gen"] = gen
            state["kept"] = kept
            state["prev_end"] = prev_end

        response_ids = gen[start_idx + len(feed_ids):]
        return self.tok.decode(response_ids, skip_special_tokens=True)

    def stats(self, state: Dict[str, Any]) -> str:
        """Return a one-line summary of a session's memory layout.

        The summary covers the raw window span, the distant span, the survivor
        count and the nominal LLM KV-size bound (mq + S + rw + C).
        """
        gl = len(state["gen"])
        in_raw_start = max(0, gl - self.rw)
        kv_bound = state["mq"] + self.S + self.rw + self.C
        return (
            f"gen={gl} raw=[{in_raw_start}..{gl}]({gl - in_raw_start}) "
            f"distant=[0..{in_raw_start}]({in_raw_start}) "
            f"survivors={len(state['kept'])} LLM_KV_bound≈{kv_bound}"
        )