"""SP-evict bounded-memory multi-turn chat runtime. Loads a frozen LLM + trained AttnPoolSP from a checkpoint and provides a multi-turn chat interface that: - keeps a system prompt as persistent KV cache, - per turn, feeds ... tokens then samples response, - maintains a bounded cumulative survivor set of distant tokens (drop bottom-64 by attention mass each chunk), - recomputes a 32-token SP each chunk from the survivors, - keeps the LLM's per-chunk KV bounded at ~(mq + 32 + rw + chunk). Validated decoding settings: - conversational chat: rp=1.15, nr=4 - fact-faithful (RAG): rp=1.0, nr=0 EOS must be prepended between turns (handled automatically inside `turn()`). CoT (`...`) is NOT stripped — stripping breaks the R1 distill's chat-template expectation. Distant CoT gets compressed into the SP gist naturally (in-distribution for the pooler, trained on dolphin-r1 CoT). """ from __future__ import annotations import os import sys from typing import Dict, Any import torch from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.cache_utils import DynamicCache # Repo root must be on sys.path for AttnPoolSP / sinusoidal imports. 
_HERE = os.path.dirname(os.path.abspath(__file__))
_ROOT = os.path.dirname(_HERE)
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from train_attnpool import AttnPoolSP, sinusoidal  # noqa: E402


# ---------------------------------------------------------------------------
# Sampling helpers (repetition penalty + no-repeat n-gram, greedy argmax pick)
# ---------------------------------------------------------------------------
def _rep_pen(logits: torch.Tensor, gen: list, p: float) -> torch.Tensor:
    """Apply a repetition penalty to every token id that appears in ``gen``.

    Positive logits of those ids are divided by ``p``; non-positive ones are
    multiplied by ``p``. Returns ``logits`` itself when ``p == 1.0`` or
    ``gen`` is empty; otherwise returns a new tensor (input not modified).
    """
    if p == 1.0 or not gen:
        return logits
    # Each id only needs penalising once; gather/scatter with repeated
    # indices would just rewrite the same value, so de-duplicate first.
    ids = torch.tensor(sorted(set(gen)), device=logits.device)
    idx = ids.unsqueeze(0).expand(logits.size(0), -1)
    sel = logits.gather(-1, idx)
    sel = torch.where(sel > 0, sel / p, sel * p)
    return logits.scatter(-1, idx, sel)


def _banned_ngrams(gen: list, n: int) -> set:
    """Return token ids that would complete an n-gram already present in ``gen``.

    The last ``n - 1`` tokens of ``gen`` form the current prefix; every
    n-gram in ``gen`` with that prefix contributes its final token.
    ``n <= 0`` disables the check; ``n == 1`` bans every token in ``gen``.
    """
    if n <= 0 or len(gen) < n - 1:
        return set()
    # Slice from an absolute start: ``gen[-(n - 1):]`` is ``gen[-0:]`` (the
    # WHOLE list) when n == 1, which made 1-gram blocking a silent no-op.
    prefix = tuple(gen[len(gen) - (n - 1):])
    banned = set()
    for i in range(len(gen) - n + 1):
        ngram = tuple(gen[i:i + n])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned


def _pick(logits: torch.Tensor, gen: list, rp: float, nr: int) -> int:
    """Greedy next-token id after repetition penalty and n-gram blocking.

    Only row 0 of ``logits`` is read/modified, so a batch of one is assumed
    (presumably shape ``(1, vocab)``). The caller's tensor is not modified.
    """
    # clone(): _rep_pen may hand back its input unchanged, and the -inf
    # writes below are in place.
    logits = _rep_pen(logits.clone(), gen, rp)
    for x in _banned_ngrams(gen, nr):
        logits[0, x] = -float("inf")
    return int(logits[0].argmax().item())


# ---------------------------------------------------------------------------
# SPChat
# ---------------------------------------------------------------------------
class SPChat:
    """Bounded-memory chat with cumulative bottom-N eviction of distant tokens.

    Each chunk of ``chunk_size`` fed/sampled tokens runs on top of: the
    persistent system-prompt KV (``mq`` tokens), ``S`` soft-prompt (SP)
    vectors pooled from the surviving distant tokens, and up to ``rw`` raw
    recent tokens. The KV cache is cropped back to ``mq`` after every chunk.
    """

    DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    DEFAULT_CKPT = "checkpoints/ap32_large.pt"

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL,
        ckpt_path: str = DEFAULT_CKPT,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        num_threads: int = 4,
        rw: int = 512,
        drop_per_chunk: int = 64,
        chunk_size: int = 64,
    ):
        """Load the frozen LLM, its tokenizer and the trained pooler.

        Args:
            model_id: Hugging Face model id or path of the causal LM.
            ckpt_path: pooler checkpoint. A relative path is first resolved
                against the repo root; if that file does not exist it is
                used as given.
            device: torch device string; ``num_threads`` only applies to CPU.
            dtype: dtype for the LLM weights and the embeddings fed to it.
            num_threads: torch intra-op threads when ``device == "cpu"``.
            rw: number of most recent raw tokens re-fed each chunk.
            drop_per_chunk: survivors evicted per chunk by attention mass.
            chunk_size: tokens fed/sampled between SP recomputations.

        Raises:
            ValueError: if ``chunk_size < 1`` (``turn`` would never
                terminate) or ``rw`` / ``drop_per_chunk`` is negative.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if rw < 0:
            raise ValueError(f"rw must be >= 0, got {rw}")
        if drop_per_chunk < 0:
            raise ValueError(f"drop_per_chunk must be >= 0, got {drop_per_chunk}")

        if device == "cpu":
            torch.set_num_threads(num_threads)
        self.dev = torch.device(device)
        self.dt = dtype
        self.rw = rw
        self.drop = drop_per_chunk
        self.C = chunk_size

        self.tok = AutoTokenizer.from_pretrained(model_id)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_id, dtype=dtype, attn_implementation="sdpa",
        ).to(self.dev)
        self.llm.config.use_cache = True
        for p in self.llm.parameters():
            p.requires_grad_(False)
        self.llm.eval()
        self.embed = self.llm.get_input_embeddings()
        self.H = self.llm.config.hidden_size
        self.EOS = self.tok.eos_token_id

        # checkpoint -> pooler
        if not os.path.isabs(ckpt_path):
            # Try relative to repo root first
            cand = os.path.join(_ROOT, ckpt_path)
            if os.path.exists(cand):
                ckpt_path = cand
        # SECURITY: weights_only=False unpickles arbitrary objects (code
        # execution); only load checkpoints from a trusted source.
        ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.S = ck["cfg"]["num_soft_tokens"]
        self.pooler = AttnPoolSP(
            self.H, self.S, ck["args"]["heads"], ck["args"]["layers"],
            ck["args"]["ffn"], ck["cfg"]["target_norm"],
        ).to(self.dev).eval()
        self.pooler.load_state_dict(ck["pooler"])

    # -----------------------------------------------------------------------
    # Internal: attention-pooling SP from the kept survivors
    # -----------------------------------------------------------------------
    def _survivor_embeds(self, gen: list, positions: list, pe: torch.Tensor) -> torch.Tensor:
        """Token embeddings of ``gen`` at ``positions`` plus ``pe[positions]``.

        Token embeddings are upcast with ``.float()`` before the positional
        encoding is added; a leading batch dim of 1 is prepended (presumably
        giving ``(1, len(positions), H)`` — assumes ``pe`` is ``(len, H)``).
        ``positions`` must be non-empty.
        """
        tok_ids = torch.tensor([[gen[p] for p in positions]], device=self.dev)
        emb = self.embed(tok_ids).float()[0] + pe[torch.tensor(positions)]
        return emb.unsqueeze(0)

    @torch.no_grad()
    def _pool(self, past_pe: torch.Tensor, want_score: bool = False):
        """Run the pooler's query slots over ``past_pe`` and return the SP.

        Args:
            past_pe: survivor embeddings, indexed as ``(batch, L, ...)`` with
                ``L = past_pe.size(1)``. ``L`` may be 0, in which case every
                block skips its cross-attention step.
            want_score: also accumulate, per key position, the head-averaged
                cross-attention weight summed over query slots and blocks.

        Returns:
            ``(sp, mass)``. ``sp`` is the pooler output with each vector
            divided by its L2 norm and multiplied by ``|out_scale|``.
            ``mass`` is a CPU tensor of length ``L`` (all zeros unless
            ``want_score``).
        """
        q = self.pooler.query.unsqueeze(0)
        L = past_pe.size(1)
        # Accumulate on the input's device: adding CUDA attention weights
        # into a CPU tensor in place raises a device-mismatch error.
        mass = torch.zeros(L, device=past_pe.device)
        for blk in self.pooler.blocks:
            if L > 0:
                qn = blk.lnq1(q)
                kn = blk.lnk(past_pe)
                a, w = blk.cross(qn, kn, kn, need_weights=want_score,
                                 average_attn_weights=True)
                if want_score and w is not None:
                    mass += w[0].sum(0)
                q = q + a
            qn = blk.lnq2(q)
            s, _ = blk.selfa(qn, qn, qn, need_weights=False)
            q = q + s
            q = q + blk.ffn(blk.lnq3(q))
        sp = self.pooler.ln_out(q)
        sp = sp / sp.norm(dim=-1, keepdim=True).clamp(min=1e-6) * self.pooler.out_scale.abs()
        # Return mass on CPU (the original contract) so callers can use its
        # topk indices against CPU position tensors.
        return sp, mass.cpu()

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def start_session(self, system_prompt: str = "") -> Dict[str, Any]:
        """Prefill the persistent KV cache with a system prompt. Returns session state.

        State keys: ``cache`` (the ``DynamicCache``), ``mq`` (number of
        persistent prompt tokens), ``gen`` (cumulative conversation token
        ids), ``kept`` (surviving distant positions into ``gen``) and
        ``prev_end`` (window boundary already folded into ``kept``).
        """
        ids = self.tok(
            system_prompt, return_tensors="pt", add_special_tokens=True,
        ).input_ids.to(self.dev)
        mq = ids.size(1)
        cache = DynamicCache()
        # Skip the forward if the tokenizer produced no tokens (e.g. an
        # empty prompt with a tokenizer that adds no BOS): nothing to cache.
        if mq > 0:
            self.llm(
                inputs_embeds=self.embed(ids).to(self.dt),
                attention_mask=torch.ones(1, mq, device=self.dev),
                past_key_values=cache,
                use_cache=True,
                cache_position=torch.arange(mq, device=self.dev),
            )
        return {"cache": cache, "mq": mq, "gen": [], "kept": [], "prev_end": 0}

    @torch.no_grad()
    def turn(
        self,
        state: Dict[str, Any],
        user_msg: str,
        max_resp: int = 600,
        rp: float = 1.15,
        nr: int = 4,
    ) -> str:
        """Run one turn. Feeds ``<|User|>{user_msg}<|Assistant|>`` then samples up to `max_resp` tokens.

        On every turn after the first, an EOS token is prepended to close the
        previous assistant turn. ``state`` is updated in place. Returns the
        assistant's response text decoded with special tokens skipped (a
        sampled EOS ends the turn and is not appended to ``gen``).

        Decoding presets:
            - conversational chat: rp=1.15, nr=4 (default)
            - RAG/fact-faithful: rp=1.0, nr=0
        """
        feed_text = f"<|User|>{user_msg}<|Assistant|>"
        feed_ids = self.tok(feed_text, return_tensors="pt",
                            add_special_tokens=False).input_ids[0].tolist()
        if state["gen"]:  # not the first turn -> close prev turn with EOS
            feed_ids = [self.EOS] + feed_ids

        cache = state["cache"]
        mq = state["mq"]
        gen = state["gen"]
        kept = state["kept"]
        prev_end = state["prev_end"]

        # Positional encodings long enough for every position this turn can
        # reach (+64 slack).
        PE = sinusoidal(len(gen) + len(feed_ids) + max_resp + 64, self.H, self.dev)
        feed_idx = 0
        sampled = 0
        done = False
        start_idx = len(gen)

        while not done:
            c0 = len(gen)
            R = min(c0, self.rw)
            end = c0 - R

            # 1. extend kept with newly evicted-from-window positions
            if end > prev_end:
                kept.extend(range(prev_end, end))
                prev_end = end

            # 2. cumulative bottom-DROP eviction by attention mass.
            # NOTE(review): a fixed ``self.drop`` is evicted whenever the set
            # exceeds ``self.drop``, regardless of how many positions step 1
            # just added. After a turn ends mid-chunk only k < chunk_size
            # positions arrive next, so with drop_per_chunk == chunk_size
            # (the defaults) the survivor count drops to k and stays there for
            # the rest of the session. Confirm this is intended.
            if len(kept) > self.drop:
                ksp = torch.tensor(kept)
                _, mass = self._pool(self._survivor_embeds(gen, kept, PE), want_score=True)
                keep_local = torch.topk(mass, len(kept) - self.drop, largest=True).indices
                kept = ksp[keep_local].tolist()

            # 3. compute SP from survivors
            if kept:
                sp, _ = self._pool(self._survivor_embeds(gen, kept, PE))
            else:
                sp, _ = self._pool(torch.zeros(1, 0, self.H, device=self.dev))

            # 4. assemble [SP + raw window] prefix after the persistent KV, forward
            parts = [sp.to(self.dt)]
            if R > 0:
                parts.append(self.embed(torch.tensor([gen[c0 - R:c0]], device=self.dev)).to(self.dt))
            seq = torch.cat(parts, 1)
            n = self.S + R
            pos = torch.arange(mq, mq + n, device=self.dev)
            o = self.llm(
                inputs_embeds=seq,
                attention_mask=torch.ones(1, mq + n, device=self.dev),
                past_key_values=cache,
                position_ids=pos.unsqueeze(0),
                use_cache=True,
                cache_position=pos,
            )
            last_logits = o.logits[:, -1, :].float()
            npos = mq + n

            # 5. inner loop: feed or sample C tokens
            for _ in range(self.C):
                if feed_idx < len(feed_ids):
                    t = feed_ids[feed_idx]
                    feed_idx += 1
                elif sampled < max_resp:
                    t = _pick(last_logits, gen, rp, nr)
                    sampled += 1
                    if t == self.EOS:
                        done = True
                        break
                else:
                    done = True
                    break
                gen.append(t)
                e = self.embed(torch.tensor([[t]], device=self.dev)).to(self.dt)
                o = self.llm(
                    inputs_embeds=e,
                    attention_mask=torch.ones(1, npos + 1, device=self.dev),
                    past_key_values=cache,
                    position_ids=torch.tensor([[npos]], device=self.dev),
                    use_cache=True,
                    cache_position=torch.tensor([npos], device=self.dev),
                )
                last_logits = o.logits[:, -1, :].float()
                npos += 1
            cache.crop(mq)  # KV stays bounded across chunks
            if done:
                break

        state["gen"] = gen
        state["kept"] = kept
        state["prev_end"] = prev_end
        response_ids = gen[start_idx + len(feed_ids):]
        return self.tok.decode(response_ids, skip_special_tokens=True)

    def stats(self, state: Dict[str, Any]) -> str:
        """One-line summary of the session's raw window, distant span and KV bound."""
        gl = len(state["gen"])
        in_raw_start = max(0, gl - self.rw)
        kv_bound = state["mq"] + self.S + self.rw + self.C
        return (
            f"gen={gl} raw=[{in_raw_start}..{gl}]({gl - in_raw_start}) "
            f"distant=[0..{in_raw_start}]({in_raw_start}) "
            f"survivors={len(state['kept'])} LLM_KV_bound≈{kv_bound}"
        )