| """SP-evict bounded-memory multi-turn chat runtime. |
| |
| Loads a frozen LLM + trained AttnPoolSP from a checkpoint and provides a |
| multi-turn chat interface that: |
| - keeps a system prompt as persistent KV cache, |
| - per turn, feeds <User>...<Assistant> tokens then samples response, |
| - maintains a bounded cumulative survivor set of distant tokens |
| (drop bottom-64 by attention mass each chunk), |
| - recomputes a 32-token SP each chunk from the survivors, |
| - keeps the LLM's per-chunk KV bounded at ~(mq + 32 + rw + chunk). |
| |
| Validated decoding settings: |
| - conversational chat: rp=1.15, nr=4 |
| - fact-faithful (RAG): rp=1.0, nr=0 |
| EOS must be prepended between turns (handled automatically inside `turn()`). |
| |
| CoT (`<think>...</think>`) is NOT stripped — stripping breaks the R1 distill's |
| chat-template expectation. Distant CoT gets compressed into the SP gist |
| naturally (in-distribution for the pooler, trained on dolphin-r1 CoT). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| from typing import Dict, Any |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from transformers.cache_utils import DynamicCache |
|
|
| |
# Put the parent of this file's directory on sys.path (at most once) so the
# project-local `train_attnpool` import resolves regardless of the current
# working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
_ROOT = os.path.dirname(_HERE)
if not any(entry == _ROOT for entry in sys.path):
    sys.path.insert(0, _ROOT)
|
|
| from train_attnpool import AttnPoolSP, sinusoidal |
|
|
|
|
| |
| |
| |
def _rep_pen(logits: torch.Tensor, gen: list, p: float) -> torch.Tensor:
    """Apply a repetition penalty to the logits of already-generated tokens.

    Args:
        logits: 2-D tensor of scores; the last dimension is indexed by token id.
        gen: token ids generated so far (duplicates are allowed).
        p: penalty factor. Positive scores are divided by ``p``, all other
            scores are multiplied by ``p``.

    Returns:
        ``logits`` itself when there is nothing to do (``p == 1.0`` or ``gen``
        is empty); otherwise a new tensor with the penalized entries replaced.
    """
    if not gen or p == 1.0:
        return logits
    # One row of token ids per logits row, so gather/scatter hit every row.
    index = torch.tensor(gen, device=logits.device)[None, :].expand(logits.size(0), -1)
    picked = torch.gather(logits, -1, index)
    penalized = torch.where(picked > 0, picked / p, picked * p)
    # Out-of-place scatter: the caller's tensor is left untouched.
    return torch.scatter(logits, -1, index, penalized)
|
|
|
|
def _banned_ngrams(gen: list, n: int) -> set:
    """Return the tokens that would complete an n-gram already present in ``gen``.

    The last ``n - 1`` tokens of ``gen`` form the current prefix. Every earlier
    occurrence of that prefix contributes the token that followed it, so
    choosing any returned token next would repeat an existing n-gram.

    Args:
        gen: token ids generated so far.
        n: n-gram size. ``n <= 0`` disables the check. ``n == 1`` bans every
            token that has already been generated.

    Returns:
        Set of banned next-token ids (empty when the check is disabled or
        ``gen`` is shorter than the prefix).
    """
    if n <= 0 or len(gen) < n - 1:
        return set()
    k = n - 1
    # Slice from an explicit start index: ``gen[-k:]`` with k == 0 is
    # ``gen[-0:]``, i.e. the whole list, which made the n == 1 case ban nothing.
    prefix = tuple(gen[len(gen) - k:])
    return {
        gen[i + k]
        for i in range(len(gen) - k)
        if tuple(gen[i:i + k]) == prefix
    }
|
|
|
|
def _pick(logits: torch.Tensor, gen: list, rp: float, nr: int) -> int:
    """Greedily choose the next token id from row 0 of ``logits``.

    A copy of ``logits`` is adjusted with the repetition penalty ``rp`` and
    every token that would complete a repeated ``nr``-gram is masked to
    ``-inf``; the argmax of the adjusted row is returned. The caller's tensor
    is not modified.
    """
    scores = _rep_pen(logits.clone(), gen, rp)
    blocked = list(_banned_ngrams(gen, nr))
    if blocked:
        scores[0, blocked] = float("-inf")
    return int(torch.argmax(scores[0]).item())
|
|
|
|
| |
| |
| |
class SPChat:
    """Bounded-memory chat with cumulative bottom-N eviction of distant tokens.

    Attributes set by ``__init__``:
        dev, dt: torch device and dtype used for the LLM forward passes.
        rw: maximum number of most-recent tokens replayed verbatim per chunk.
        drop: number of survivors evicted per chunk once more than ``drop``
            are held.
        C: number of tokens fed or sampled between two re-prefills.
        tok, llm, embed: tokenizer, frozen causal LM and its input embedding.
        H: LLM hidden size.  EOS: tokenizer EOS id.
        S: number of soft-prompt tokens produced by the pooler.
        pooler: ``AttnPoolSP`` restored from the checkpoint.
    """

    DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    DEFAULT_CKPT = "checkpoints/ap32_large.pt"

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL,
        ckpt_path: str = DEFAULT_CKPT,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        num_threads: int = 4,
        rw: int = 512,
        drop_per_chunk: int = 64,
        chunk_size: int = 64,
    ):
        """Load the tokenizer, the frozen LLM and the pooler checkpoint.

        Args:
            model_id: Hugging Face model id or path for tokenizer and LLM.
            ckpt_path: pooler checkpoint. A relative path is tried under the
                repository root (``_ROOT``) first and otherwise used as given.
            device: torch device string for the LLM and the pooler.
            dtype: dtype the LLM is loaded in and fed with.
            num_threads: torch thread count; only applied when ``device``
                is ``"cpu"``.
            rw: size of the raw recent-token window.
            drop_per_chunk: survivors evicted per chunk.
            chunk_size: tokens processed per chunk.
        """
        if device == "cpu":
            torch.set_num_threads(num_threads)
        self.dev = torch.device(device)
        self.dt = dtype
        self.rw = rw
        self.drop = drop_per_chunk
        self.C = chunk_size

        self.tok = AutoTokenizer.from_pretrained(model_id)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

        self.llm = AutoModelForCausalLM.from_pretrained(
            model_id, dtype=dtype, attn_implementation="sdpa",
        ).to(self.dev)
        self.llm.config.use_cache = True
        # The LLM is frozen: no gradients, inference mode only.
        for p in self.llm.parameters():
            p.requires_grad_(False)
        self.llm.eval()
        self.embed = self.llm.get_input_embeddings()
        self.H = self.llm.config.hidden_size
        self.EOS = self.tok.eos_token_id

        if not os.path.isabs(ckpt_path):
            # Prefer the copy under the repository root; fall back to the
            # path as given (relative to the working directory).
            cand = os.path.join(_ROOT, ckpt_path)
            if os.path.exists(cand):
                ckpt_path = cand
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a trusted source.
        ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.S = ck["cfg"]["num_soft_tokens"]
        self.pooler = AttnPoolSP(
            self.H, self.S,
            ck["args"]["heads"], ck["args"]["layers"], ck["args"]["ffn"],
            ck["cfg"]["target_norm"],
        ).to(self.dev).eval()
        self.pooler.load_state_dict(ck["pooler"])

    @torch.no_grad()
    def _pool(self, past_pe: torch.Tensor, want_score: bool = False):
        """Run the pooler over ``past_pe`` and optionally score each position.

        Args:
            past_pe: ``(1, L, H)`` tensor of survivor embeddings with their
                positional encodings already added. ``L`` may be 0, in which
                case the cross-attention step is skipped.
            want_score: when true, accumulate attention mass per position.

        Returns:
            ``(sp, mass)``. ``sp`` is the pooler output with every token
            rescaled to norm ``|out_scale|``. ``mass`` has shape ``(L,)`` and
            lives on ``past_pe``'s device; it is the head-averaged
            cross-attention weight of batch item 0, summed over queries and
            over blocks (all zeros when ``want_score`` is false).
        """
        q = self.pooler.query.unsqueeze(0)
        L = past_pe.size(1)
        # Allocate on the input's device: a CPU accumulator cannot be added
        # to attention weights that live on an accelerator.
        mass = torch.zeros(L, device=past_pe.device)
        for blk in self.pooler.blocks:
            if L > 0:
                kn = blk.lnk(past_pe)
                a, w = blk.cross(
                    blk.lnq1(q), kn, kn,
                    need_weights=want_score, average_attn_weights=True,
                )
                if want_score and w is not None:
                    mass += w[0].sum(0)
                q = q + a
            qn = blk.lnq2(q)
            s, _ = blk.selfa(qn, qn, qn, need_weights=False)
            q = q + s
            q = q + blk.ffn(blk.lnq3(q))
        sp = self.pooler.ln_out(q)
        sp = sp / sp.norm(dim=-1, keepdim=True).clamp(min=1e-6) * self.pooler.out_scale.abs()
        return sp, mass

    def _survivor_inputs(self, gen: list, kept: list, pe: torch.Tensor) -> torch.Tensor:
        """Embed the tokens at positions ``kept`` and add their positional rows.

        Returns a ``(1, len(kept), H)`` float tensor suitable for ``_pool``.
        """
        token_ids = torch.tensor([[gen[p] for p in kept]], device=self.dev)
        # Build the index on pe's own device so the lookup never mixes devices.
        rows = torch.tensor(kept, dtype=torch.long, device=pe.device)
        return (self.embed(token_ids).float()[0] + pe[rows]).unsqueeze(0)

    @torch.no_grad()
    def start_session(self, system_prompt: str = "") -> Dict[str, Any]:
        """Prefill the persistent KV cache with a system prompt.

        Returns the session state dict used by ``turn``/``stats``:
        ``cache`` (KV cache holding the prompt), ``mq`` (prompt length in
        tokens), ``gen`` (all token ids fed or sampled so far), ``kept``
        (surviving distant positions into ``gen``) and ``prev_end`` (first
        position not yet considered for the survivor set).
        """
        ids = self.tok(
            system_prompt, return_tensors="pt", add_special_tokens=True,
        ).input_ids.to(self.dev)
        mq = ids.size(1)
        cache = DynamicCache()
        # Skip the forward pass when the tokenizer produced no tokens at all
        # (empty prompt and no BOS); the cache then simply starts empty.
        if mq > 0:
            self.llm(
                inputs_embeds=self.embed(ids).to(self.dt),
                attention_mask=torch.ones(1, mq, device=self.dev),
                past_key_values=cache, use_cache=True,
                cache_position=torch.arange(mq, device=self.dev),
            )
        return {"cache": cache, "mq": mq, "gen": [], "kept": [], "prev_end": 0}

    @torch.no_grad()
    def turn(
        self,
        state: Dict[str, Any],
        user_msg: str,
        max_resp: int = 600,
        rp: float = 1.15,
        nr: int = 4,
    ) -> str:
        """Run one turn and return the assistant's response text.

        Feeds ``<|User|>{user_msg}<|Assistant|>`` (preceded by EOS when the
        session already has history), then greedily decodes up to ``max_resp``
        tokens, stopping early on EOS. The EOS itself is not stored.

        Work proceeds in chunks of ``self.C`` tokens. Each chunk:
          1. moves positions that fell out of the raw window into ``kept``,
          2. evicts the ``self.drop`` lowest-attention-mass survivors once
             more than ``self.drop`` are held,
          3. pools the survivors into ``self.S`` soft tokens,
          4. prefills soft tokens + raw window after the cached prompt,
          5. feeds/samples up to ``self.C`` tokens one at a time,
          6. crops the cache back to the prompt.

        ``state`` is updated in place. If the turn is interrupted, the cache
        is still cropped and the survivor bookkeeping stays consistent with
        ``state["gen"]``, so the session remains usable.

        Decoding presets:
          - conversational chat: rp=1.15, nr=4 (default)
          - RAG/fact-faithful:   rp=1.0,  nr=0
        """
        # NOTE(review): these markers use ASCII bars. Confirm they tokenize to
        # the chat-template special tokens of the configured model.
        feed_text = f"<|User|>{user_msg}<|Assistant|>"
        feed_ids = self.tok(
            feed_text, return_tensors="pt", add_special_tokens=False,
        ).input_ids[0].tolist()
        if state["gen"]:
            # Close the previous assistant message before the new user turn.
            feed_ids = [self.EOS] + feed_ids

        cache = state["cache"]
        mq = state["mq"]
        gen = state["gen"]
        kept = state["kept"]
        prev_end = state["prev_end"]
        # Positional table covering every position this turn can reach.
        pe = sinusoidal(len(gen) + len(feed_ids) + max_resp + 64, self.H, self.dev)

        feed_idx = 0
        sampled = 0
        done = False
        start_idx = len(gen)

        while not done:
            c0 = len(gen)
            R = min(c0, self.rw)
            end = c0 - R  # positions [0, end) are "distant"

            # 1. Newly distant positions join the survivor set.
            if end > prev_end:
                kept.extend(range(prev_end, end))
                prev_end = end

            # 2. Evict the lowest-mass survivors.
            if len(kept) > self.drop:
                _, mass = self._pool(self._survivor_inputs(gen, kept, pe), want_score=True)
                top = torch.topk(mass, len(kept) - self.drop, largest=True).indices
                # Select with Python ints so no cross-device indexing occurs.
                kept = [kept[i] for i in top.tolist()]

            # Persist bookkeeping now: `kept.extend` above may already have
            # mutated state["kept"] in place, so prev_end must not lag behind
            # if the forward passes below raise or are interrupted.
            state["kept"] = kept
            state["prev_end"] = prev_end

            # 3. Pool survivors into the soft prompt.
            if kept:
                sp, _ = self._pool(self._survivor_inputs(gen, kept, pe))
            else:
                sp, _ = self._pool(torch.zeros(1, 0, self.H, device=self.dev))

            try:
                # 4. Prefill [soft prompt | raw window] right after the prompt.
                parts = [sp.to(self.dt)]
                if R > 0:
                    parts.append(
                        self.embed(torch.tensor([gen[c0 - R:c0]], device=self.dev)).to(self.dt)
                    )
                seq = torch.cat(parts, 1)
                n = self.S + R
                pos = torch.arange(mq, mq + n, device=self.dev)
                o = self.llm(
                    inputs_embeds=seq,
                    attention_mask=torch.ones(1, mq + n, device=self.dev),
                    past_key_values=cache,
                    position_ids=pos.unsqueeze(0),
                    use_cache=True,
                    cache_position=pos,
                )
                last_logits = o.logits[:, -1, :].float()
                npos = mq + n

                # 5. Feed pending prompt tokens first, then sample.
                for _ in range(self.C):
                    if feed_idx < len(feed_ids):
                        t = feed_ids[feed_idx]
                        feed_idx += 1
                    elif sampled < max_resp:
                        t = _pick(last_logits, gen, rp, nr)
                        sampled += 1
                        if t == self.EOS:
                            done = True
                            break
                    else:
                        done = True
                        break
                    gen.append(t)
                    e = self.embed(torch.tensor([[t]], device=self.dev)).to(self.dt)
                    o = self.llm(
                        inputs_embeds=e,
                        attention_mask=torch.ones(1, npos + 1, device=self.dev),
                        past_key_values=cache,
                        position_ids=torch.tensor([[npos]], device=self.dev),
                        use_cache=True,
                        cache_position=torch.tensor([npos], device=self.dev),
                    )
                    last_logits = o.logits[:, -1, :].float()
                    npos += 1
            finally:
                # 6. Always drop the per-chunk entries so only the system
                # prompt remains cached, even when the chunk was aborted.
                cache.crop(mq)

        state["gen"] = gen
        state["kept"] = kept
        state["prev_end"] = prev_end
        response_ids = gen[start_idx + len(feed_ids):]
        return self.tok.decode(response_ids, skip_special_tokens=True)

    def stats(self, state: Dict[str, Any]) -> str:
        """Return a one-line summary of the session's token bookkeeping.

        Reports the total stored tokens, the raw-window and distant ranges
        into ``state["gen"]``, the survivor count, and the approximate KV
        bound ``mq + S + rw + C``.
        """
        gl = len(state["gen"])
        in_raw_start = max(0, gl - self.rw)
        kv_bound = state["mq"] + self.S + self.rw + self.C
        return (
            f"gen={gl} raw=[{in_raw_start}..{gl}]({gl - in_raw_start}) "
            f"distant=[0..{in_raw_start}]({in_raw_start}) "
            f"survivors={len(state['kept'])} LLM_KV_bound≈{kv_bound}"
        )
|
|