#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ LULUV2 optimized local inference engine. Goals: - load LULU2/LULUV2 checkpoints through the existing LULUV2 model file - no AutoModelForCausalLM.from_pretrained and no external model weights - vectorized prompt prefill into explicit KV caches - persistent session KV cache across turns when prompt tokens extend prior prompt - modes: fast(pass1/base), vwm(pass1+pass2), deep(pass1+pass2 long context) - safe fallback to slow full-prefix forward if cached path fails This is intentionally Python-first and debuggable. It is a bridge toward kernel/CUDA-graph optimization, not the final kernel path. """ from __future__ import annotations import importlib.util import json import math import os import platform import time import traceback from contextlib import nullcontext from dataclasses import dataclass, asdict from pathlib import Path from types import SimpleNamespace from typing import Any, Dict, Generator, List, Optional, Tuple import torch import torch.nn.functional as F try: import psutil except Exception: psutil = None try: import pynvml except Exception: pynvml = None STOP_STRINGS = [ "<|im_start|>", "<|im_end|>", "<|user|>", "<|system|>", "<|assistant|>", "User:", "Assistant:", "\nuser:", "\nassistant:", ] def setup_torch() -> None: if torch.cuda.is_available(): try: # Old API still works on current wheels; warnings are harmless. torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True except Exception: pass try: torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(False) except Exception: pass if hasattr(torch, "set_float32_matmul_precision"): try: torch.set_float32_matmul_precision("high") except Exception: pass def human_bytes(num: float) -> str: num = float(num) for unit in ["B", "KB", "MB", "GB", "TB"]: if abs(num) < 1024.0: return f"{num:.2f} {unit}" num /= 1024.0 return f"{num:.2f} PB" def value_to_text(value: Any) -> str: if value is None: return "" if isinstance(value, str): return value if isinstance(value, dict): for key in ("text", "content", "value"): if key in value: return value_to_text(value.get(key)) return "\n".join(value_to_text(v) for v in value.values() if value_to_text(v)) if isinstance(value, (list, tuple)): return "\n".join(value_to_text(v) for v in value if value_to_text(v)) return str(value) def clean_text(text: Any) -> str: text = value_to_text(text).replace("\\n", "\n") cut_points = [text.find(s) for s in STOP_STRINGS if s in text and text.find(s) > 0] if cut_points: text = text[: min(cut_points)] for s in STOP_STRINGS: text = text.replace(s, "") text = text.strip() for prefix in ("Assistant:", "assistant:", "Lulu:", "lulu:"): if text.startswith(prefix): text = text[len(prefix):].strip() lines = [ln.rstrip() for ln in text.splitlines()] # collapse excessive vertical whitespace without destroying code blocks too much out: List[str] = [] blank = 0 for ln in lines: if not ln.strip(): blank += 1 if blank <= 2: out.append("") else: blank = 0 out.append(ln) return "\n".join(out).strip() def normalize_history(history: Any) -> List[Dict[str, str]]: out: List[Dict[str, str]] = [] if not history: return out for item in history: if isinstance(item, dict): role = item.get("role", "") content = clean_text(item.get("content", "")) if role in {"user", "assistant"} and content: out.append({"role": role, "content": content}) elif isinstance(item, (tuple, list)) and len(item) >= 2: u = clean_text(item[0]) a = clean_text(item[1]) if u: out.append({"role": "user", "content": u}) if a: out.append({"role": "assistant", "content": a}) return out def resolve_model_py(model_py: Optional[str]) -> str: candidates: List[str] = [] if model_py: candidates.append(model_py) candidates.extend(["luluv2_inference_runtime.py"]) for c in candidates: p = Path(c) if p.exists(): return str(p.resolve()) raise FileNotFoundError("Could not find LULUV2 model file. Pass --model-py.") def import_model_py(model_py: Optional[str]): path = resolve_model_py(model_py) spec = importlib.util.spec_from_file_location("luluv2_runtime_module", path) if spec is None or spec.loader is None: raise RuntimeError(f"Could not import model file: {path}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) return mod, path @dataclass class GenerationConfig: max_new_tokens: int = 512 temperature: float = 0.65 top_k: int = 40 top_p: float = 0.90 min_p: float = 0.03 repetition_penalty: float = 1.10 frequency_penalty: float = 0.02 greedy: bool = False no_repeat_ngram: int = 4 stream_every: int = 1 max_context_tokens: int = 4096 mode: str = "vwm" # fast, vwm, deep, slow return_pass_metrics: bool = True use_cache: bool = True vectorized_prefill: bool = True persistent_cache: bool = True compile_step: bool = False @dataclass class GenerationStats: prompt_tokens: int = 0 prompt_total_tokens: int = 0 prompt_kept_tokens: int = 0 prompt_dropped_tokens: int = 0 generated_tokens: int = 0 elapsed_sec: float = 0.0 tokens_per_sec: float = 0.0 prefill_sec: float = 0.0 prefill_tps: float = 0.0 cache_hit: bool = False cache_reused_tokens: int = 0 cache_new_prefill_tokens: int = 0 mode: str = "vwm" backend: str = "none" last_token: str = "" last_token_id: int = -1 last_token_prob: float = 0.0 last_entropy: float = 0.0 finish_reason: str = "none" pass1_pass2_kl: Optional[float] = None pass1_pass2_logit_cosine: Optional[float] = None class KVLayerCache: def __init__(self): self.k: Optional[torch.Tensor] = None # [B, H, T, Dh] self.v: Optional[torch.Tensor] = None @property def length(self) -> int: if self.k is None: return 0 return int(self.k.shape[2]) def set(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None: if k.shape[2] > max_len: k = k[:, :, -max_len:, :] v = v[:, :, -max_len:, :] self.k = k.detach().contiguous() self.v = v.detach().contiguous() def append(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None: if self.k is None: self.set(k, v, max_len) return self.k = torch.cat([self.k, k.detach()], dim=2) self.v = torch.cat([self.v, v.detach()], dim=2) if self.k.shape[2] > max_len: self.k = self.k[:, :, -max_len:, :].contiguous() self.v = self.v[:, :, -max_len:, :].contiguous() class DecoderKVCache: def __init__(self, n_layers: int): self.layers = [KVLayerCache() for _ in range(int(n_layers))] def clear(self): for layer in self.layers: layer.k = None layer.v = None @property def length(self) -> int: if not self.layers: return 0 return self.layers[0].length class LULUV2OptimizedEngine: def __init__( self, ckpt_path: str, model_py: Optional[str] = None, tokenizer_dir: Optional[str] = None, device: Optional[str] = None, dtype: str = "bf16", local_files_only: bool = True, no_config_download: bool = True, force_base_only: bool = False, ): setup_torch() self.ckpt_path = str(ckpt_path) self.ckpt_dir = Path(self.ckpt_path).resolve().parent self.device = self._select_device(device) self.dtype = self._dtype_from_name(dtype) self.local_files_only = bool(local_files_only) self.no_config_download = bool(no_config_download) self.force_base_only = bool(force_base_only) self.last_stats = GenerationStats() self.recent_tokens: List[Dict[str, Any]] = [] self.last_prompt_total_tokens: int = 0 self.last_prompt_kept_tokens: int = 0 self.last_prompt_dropped_tokens: int = 0 self.cache_ids: Optional[torch.Tensor] = None self.cache_mode: str = "" self.cache_max_context: int = 0 self.pass1_cache: Optional[DecoderKVCache] = None self.pass2_cache: Optional[DecoderKVCache] = None self.cached_logits: Optional[torch.Tensor] = None self.cached_pass1_logits: Optional[torch.Tensor] = None self.cached_pass2_logits: Optional[torch.Tensor] = None self.cache_backend: str = "cold" self.goku, self.model_py_path = import_model_py(model_py) self.args = SimpleNamespace( checkpoint=self.ckpt_path, tokenizer=tokenizer_dir or "", model_id="", no_config_download=self.no_config_download, local_files_only=self.local_files_only, ) print("[guard] LULUV2 cockpit: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.") print(f"[load] checkpoint={self.ckpt_path}") self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype) self.tokenizer = self._load_tokenizer(tokenizer_dir) self.model, self.has_pass2 = self._maybe_wrap_pass2(base) self.base = self.model.base if self.has_pass2 else self.model self.n_layers = int(self.base.config.num_hidden_layers) self.model.eval() self.base.eval() self.model_info = self._build_model_info() self._compiled = False def _select_device(self, device: Optional[str]): if device: return torch.device(device) if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def _dtype_from_name(self, name: str): name = (name or "bf16").lower() if name in {"bf16", "bfloat16"}: return torch.bfloat16 if name in {"fp16", "float16", "half"}: return torch.float16 return torch.float32 def _load_tokenizer(self, tokenizer_dir: Optional[str]): if tokenizer_dir: self.args.tokenizer = tokenizer_dir else: sibling = self.ckpt_dir / "tokenizer" if sibling.is_dir(): self.args.tokenizer = str(sibling) tok = self.goku.load_tokenizer(self.args, self.base_ckpt) if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None: try: tok.pad_token = tok.eos_token except Exception: pass # Long-prompt safety: for chat/RAG prompts, the latest user turn and final # instruction are normally at the end. Right-side truncation silently drops # exactly the part the model must answer, so force left truncation where the # tokenizer supports it. encode() below also performs manual left truncation # and records how many tokens were dropped. try: tok.truncation_side = "left" except Exception: pass try: tok.model_max_length = 10**9 except Exception: pass return tok def _maybe_wrap_pass2(self, base): ckpt = self.base_ckpt if self.force_base_only or "pass2_state" not in ckpt: print("[pass2] no pass2_state loaded; running base LULUV2 forward") return base.to(self.device).eval(), False cfg_dict = dict(ckpt.get("pass2_config") or {}) Pass2Config = self.goku.Pass2Config fields = getattr(Pass2Config, "__dataclass_fields__", {}) pass2_cfg = Pass2Config(**{k: v for k, v in cfg_dict.items() if k in fields}) model = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg) missing, unexpected = model.load_state_dict(ckpt["pass2_state"], strict=False) print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}") model.to(device=self.device, dtype=self.dtype).eval() return model, True def _build_model_info(self) -> Dict[str, Any]: total_params = sum(p.numel() for p in self.model.parameters()) c_codes = [(n, p.numel()) for n, p in self.model.named_parameters() if n.endswith(".c")] gate_mean = None adapter_gate_mean = None if self.has_pass2: with torch.no_grad(): gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item()) vals = [float(torch.sigmoid(ad.gate.float()).item()) for ad in self.model.adapters] adapter_gate_mean = sum(vals) / max(1, len(vals)) ckpt_size = Path(self.ckpt_path).stat().st_size if Path(self.ckpt_path).exists() else 0 cfg = getattr(self.base, "config", None) return { "checkpoint": self.ckpt_path, "checkpoint_size": human_bytes(ckpt_size), "model_py": self.model_py_path, "device": str(self.device), "dtype": str(self.dtype).replace("torch.", ""), "has_pass2": self.has_pass2, "total_params": total_params, "vwm_c_modules": len(c_codes), "vwm_c_params": sum(n for _, n in c_codes), "pass2_layer_gate_mean": gate_mean, "pass2_adapter_gate_mean": adapter_gate_mean, "hidden_size": getattr(cfg, "hidden_size", None), "layers": getattr(cfg, "num_hidden_layers", None), "heads": getattr(cfg, "num_attention_heads", None), "kv_heads": getattr(cfg, "num_key_value_heads", None), "max_position_embeddings": getattr(cfg, "max_position_embeddings", None), } def amp_context(self): if self.device.type == "cuda" and self.dtype in (torch.bfloat16, torch.float16): return torch.autocast("cuda", dtype=self.dtype) return nullcontext() def build_chat_prompt( self, message: str, history: Any, system_prompt: str, memory_notes: str = "", history_turns: int = 4, extra_context: str = "", ) -> str: history = normalize_history(history) recent = history[-max(0, int(history_turns)) * 2:] if history_turns else [] system_chunks: List[str] = [] if system_prompt.strip(): system_chunks.append(system_prompt.strip()) if memory_notes.strip(): system_chunks.append("Useful memory notes:\n" + memory_notes.strip()) if extra_context.strip(): system_chunks.append("Relevant local context:\n" + extra_context.strip()) system = "\n\n".join(system_chunks) messages: List[Dict[str, str]] = [] if system: messages.append({"role": "system", "content": system}) messages.extend(recent) messages.append({"role": "user", "content": clean_text(message)}) try: return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) except Exception: parts: List[str] = [] if system: parts.append(f"<|im_start|>system\n{system}<|im_end|>") for item in recent: parts.append(f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>") parts.append(f"<|im_start|>user\n{clean_text(message)}<|im_end|>") parts.append("<|im_start|>assistant\n") return "\n".join(parts) def encode(self, text: str, max_context_tokens: int) -> torch.Tensor: """Encode prompt with explicit left-truncation and accounting. This avoids a common long-context failure mode: many tokenizers default to right-side truncation, which keeps the beginning of a huge prompt and drops the final user instruction. For chat, we almost always want the opposite. """ max_context = max(1, int(max_context_tokens)) try: self.tokenizer.truncation_side = "left" except Exception: pass # Tokenize without tokenizer-side truncation so we know exactly whether the # prompt was clipped. The prompt already contains chat special tokens. try: enc = self.tokenizer( text, return_tensors="pt", truncation=False, add_special_tokens=False, ) except TypeError: enc = self.tokenizer(text, return_tensors="pt", truncation=False) ids = enc.input_ids total = int(ids.shape[1]) dropped = max(0, total - max_context) if dropped > 0: ids = ids[:, -max_context:].contiguous() # Do not reuse an older conversation cache after a hard context trim; # the logical prefix changed and reuse can make long prompts feel like # they are "forgetting" pieces. self.pass1_cache = None self.pass2_cache = None self.cache_ids = None self.cached_logits = None self.cached_pass1_logits = None self.cached_pass2_logits = None self.cache_backend = "truncated-rebuild" self.last_prompt_total_tokens = total self.last_prompt_kept_tokens = int(ids.shape[1]) self.last_prompt_dropped_tokens = dropped return ids.to(self.device) def _position_ids(self, T: int, offset: int = 0) -> torch.Tensor: return torch.arange(offset, offset + T, device=self.device, dtype=torch.long).unsqueeze(0) def _attn_prefill(self, attn, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() query_states = attn.q_proj(hidden_states) key_states = attn.k_proj(hidden_states) value_states = attn.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2) cos, sin = attn.rotary_emb(value_states, position_ids) query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin) key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups) value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups) cache.set(key_states, value_states, max_context) attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True, scale=attn.scaling ) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size) return attn.o_proj(attn_output) def _attn_step(self, attn, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() assert q_len == 1 query_states = attn.q_proj(hidden_states) key_states = attn.k_proj(hidden_states) value_states = attn.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2) position_ids = self._position_ids(1, pos) cos, sin = attn.rotary_emb(value_states, position_ids) query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin) key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups) value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups) cache.append(key_states, value_states, max_context) if cache.k is None or cache.v is None: raise RuntimeError("KV cache append failed") attn_output = F.scaled_dot_product_attention( query_states, cache.k, cache.v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scaling ) attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size) return attn.o_proj(attn_output) def _layer_prefill(self, layer, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor: residual = hidden_states x = layer.input_layernorm(hidden_states) x = self._attn_prefill(layer.self_attn, x, position_ids, cache, max_context) hidden_states = residual + x residual = hidden_states x = layer.post_attention_layernorm(hidden_states) x = layer.mlp(x) return residual + x def _layer_step(self, layer, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor: residual = hidden_states x = layer.input_layernorm(hidden_states) x = self._attn_step(layer.self_attn, x, pos, cache, max_context) hidden_states = residual + x residual = hidden_states x = layer.post_attention_layernorm(hidden_states) x = layer.mlp(x) return residual + x @torch.no_grad() def _prefill_pass1(self, input_ids: torch.Tensor, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor, torch.Tensor]: T = int(input_ids.shape[1]) position_ids = self._position_ids(T, 0) cache = DecoderKVCache(self.n_layers) h = self.base.model.embed_tokens(input_ids) if use_pass_embed and self.has_pass2: h = h + self.model.pass_embed[0].to(dtype=h.dtype, device=h.device).view(1, 1, -1) layer_states: List[torch.Tensor] = [] for i, layer in enumerate(self.base.model.layers): h = self._layer_prefill(layer, h, position_ids, cache.layers[i], max_context) layer_states.append(h) normed = self.base.model.norm(h) logits = self.base.lm_head(normed) self.pass1_cache = cache return h, layer_states, position_ids, logits @torch.no_grad() def _prefill_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], position_ids: torch.Tensor, max_context: int) -> torch.Tensor: if not self.has_pass2: raise RuntimeError("pass2 requested but checkpoint has no pass2_state") cache = DecoderKVCache(self.n_layers) h2 = h1_resid + self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device).view(1, 1, -1) for i, layer in enumerate(self.base.model.layers): before = h2 layer_out = self._layer_prefill(layer, h2, position_ids, cache.layers[i], max_context) layer_delta = layer_out - before gate = torch.sigmoid(self.model.layer_gates[i]).to(dtype=h2.dtype, device=h2.device) adapter_delta = self.model.adapters[i](h2, pass1_states[i]) h2 = before + gate * layer_delta + adapter_delta normed = self.base.model.norm(h2) logits = self.base.lm_head(normed) self.pass2_cache = cache return logits @torch.no_grad() def _step_pass1(self, token_id: torch.Tensor, pos: int, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: if self.pass1_cache is None: self.pass1_cache = DecoderKVCache(self.n_layers) h = self.base.model.embed_tokens(token_id) if use_pass_embed and self.has_pass2: h = h + self.model.pass_embed[0].to(dtype=h.dtype, device=h.device).view(1, 1, -1) states: List[torch.Tensor] = [] for i, layer in enumerate(self.base.model.layers): h = self._layer_step(layer, h, pos, self.pass1_cache.layers[i], max_context) states.append(h) logits = self.base.lm_head(self.base.model.norm(h)) return h, states, logits @torch.no_grad() def _step_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], pos: int, max_context: int) -> torch.Tensor: if not self.has_pass2: raise RuntimeError("pass2 step requested but unavailable") if self.pass2_cache is None: self.pass2_cache = DecoderKVCache(self.n_layers) h2 = h1_resid + self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device).view(1, 1, -1) for i, layer in enumerate(self.base.model.layers): before = h2 layer_out = self._layer_step(layer, h2, pos, self.pass2_cache.layers[i], max_context) layer_delta = layer_out - before gate = torch.sigmoid(self.model.layer_gates[i]).to(dtype=h2.dtype, device=h2.device) adapter_delta = self.model.adapters[i](h2, pass1_states[i]) h2 = before + gate * layer_delta + adapter_delta return self.base.lm_head(self.base.model.norm(h2)) def _ids_prefix_len(self, old: torch.Tensor, new: torch.Tensor) -> int: if old is None or old.numel() == 0 or new.numel() == 0: return 0 old1 = old[0] new1 = new[0] max_n = min(int(old1.numel()), int(new1.numel())) if max_n == 0: return 0 # Fast path: old is exact prefix of new. if int(old1.numel()) <= int(new1.numel()) and torch.equal(old1, new1[: old1.numel()]): return int(old1.numel()) # Conservative fallback, scan from max down; prompts are usually exact-prefix or reset. for n in range(max_n, 0, -1): if torch.equal(old1[:n], new1[:n]): return n return 0 @torch.no_grad() def _token_prefill_context(self, input_ids: torch.Tensor, cfg: GenerationConfig, use_pass2: bool, use_pass_embed: bool, max_context: int) -> None: """ Conservative cache builder. It fills the same pass1/pass2 KV caches by walking the prompt one token at a time. This is slower than vectorized prefill but much safer across checkpoint/runtime variants, and it still gives a valid decode cache + persistent cache for the generated tokens. """ self.pass1_cache = DecoderKVCache(self.n_layers) self.pass2_cache = DecoderKVCache(self.n_layers) if use_pass2 else None self.cached_logits = None self.cached_pass1_logits = None self.cached_pass2_logits = None T = int(input_ids.shape[1]) for pos in range(T): tok = input_ids[:, pos:pos + 1] h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed) if use_pass2: logits2 = self._step_pass2(h1, states, pos, max_context) self.cached_logits = logits2 self.cached_pass1_logits = logits1 self.cached_pass2_logits = logits2 else: self.cached_logits = logits1 self.cached_pass1_logits = logits1 self.cached_pass2_logits = None @torch.no_grad() def _prepare_cached_context(self, input_ids: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, bool, int, int, str]: mode = self._effective_mode(cfg.mode) max_context = int(cfg.max_context_tokens) use_pass2 = mode in {"vwm", "deep"} and self.has_pass2 use_pass_embed = bool(use_pass2) T = int(input_ids.shape[1]) if T > max_context: input_ids = input_ids[:, -max_context:] T = max_context # If mode/context changed, persistent cache is invalid. cache_ok = ( cfg.persistent_cache and self.cache_ids is not None and self.cache_mode == mode and self.cache_max_context == max_context and self.pass1_cache is not None ) prefix = self._ids_prefix_len(self.cache_ids, input_ids) if cache_ok else 0 cache_hit = bool(cache_ok and prefix == int(self.cache_ids.shape[1]) and prefix <= T and prefix > 0) t0 = time.time() if cache_hit: # Process only suffix between prior cached prompt and new prompt. suffix = input_ids[:, prefix:] for j in range(int(suffix.shape[1])): tok = suffix[:, j : j + 1] pos = prefix + j h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed) if use_pass2: logits2 = self._step_pass2(h1, states, pos, max_context) self.cached_logits = logits2 self.cached_pass1_logits = logits1 self.cached_pass2_logits = logits2 else: self.cached_logits = logits1 self.cached_pass1_logits = logits1 self.cached_pass2_logits = None self.cache_ids = input_ids.detach().clone() self.cache_backend = "persistent-kv-suffix" if suffix.numel() else "persistent-kv-hit" return input_ids, True, prefix, int(suffix.shape[1]), self.cache_backend # Reset and prefill. Prefer vectorized prefill, but fall back to conservative # token prefill if the runtime variant does not support our vectorized cache path. self.pass1_cache = None self.pass2_cache = None backend = "vectorized-prefill" if bool(cfg.vectorized_prefill): try: h1, states, pos_ids, logits1 = self._prefill_pass1(input_ids, max_context, use_pass_embed=use_pass_embed) if use_pass2: logits2 = self._prefill_pass2(h1, states, pos_ids, max_context) self.cached_logits = logits2 self.cached_pass1_logits = logits1 self.cached_pass2_logits = logits2 else: self.cached_logits = logits1 self.cached_pass1_logits = logits1 self.cached_pass2_logits = None except Exception as exc: if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}: print("[cache] vectorized prefill failed; using token-prefill cache.") traceback.print_exc() self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context) backend = "token-prefill-cache" else: self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context) backend = "token-prefill-cache" self.cache_ids = input_ids.detach().clone() self.cache_mode = mode self.cache_max_context = max_context self.cache_backend = backend return input_ids, False, 0, T, self.cache_backend def _effective_mode(self, mode: str) -> str: mode = (mode or "vwm").lower() if mode in {"fast", "base", "pass1"}: return "fast" if mode in {"deep", "32k", "long"}: return "deep" if mode in {"slow", "full"}: return "slow" return "vwm" @torch.no_grad() def pass_metrics_from_logits(self, logits1: Optional[torch.Tensor], logits2: Optional[torch.Tensor]) -> Tuple[Optional[float], Optional[float]]: if logits1 is None or logits2 is None: return None, None try: l1 = logits1[:, -1, :].float() l2 = logits2[:, -1, :].float() kl = F.kl_div(F.log_softmax(l2, dim=-1), F.softmax(l1, dim=-1), reduction="batchmean") cos = F.cosine_similarity(l1, l2, dim=-1).mean() return float(kl.item()), float(cos.item()) except Exception: return None, None def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor: if generated.numel() == 0: return logits out = logits.clone() uniq, counts = torch.unique(generated.view(-1), return_counts=True) if cfg.repetition_penalty != 1.0: selected = out[:, uniq] selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty)) out[:, uniq] = selected if cfg.frequency_penalty: out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0) n = int(cfg.no_repeat_ngram) if n > 1 and generated.size(1) >= n - 1: seq = generated[0].tolist() prefix = tuple(seq[-(n - 1):]) banned = [] for i in range(len(seq) - n + 1): if tuple(seq[i:i + n - 1]) == prefix: banned.append(seq[i + n - 1]) if banned: out[:, list(set(banned))] = -float("inf") return out @torch.no_grad() def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, float]]: work = self._apply_penalties(logits.float(), generated, cfg) if cfg.greedy or cfg.temperature <= 0: probs = torch.softmax(work, dim=-1) next_id = torch.argmax(work, dim=-1, keepdim=True) else: work = work / max(float(cfg.temperature), 1e-6) if cfg.top_k > 0: k = min(int(cfg.top_k), work.size(-1)) thresh = torch.topk(work, k, dim=-1).values[..., -1, None] work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf"))) if 0.0 < cfg.top_p < 1.0: sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumprobs = torch.cumsum(sorted_probs, dim=-1) remove = cumprobs > float(cfg.top_p) shifted = remove.clone() shifted[..., 1:] = remove[..., :-1] shifted[..., 0] = False sorted_logits = sorted_logits.masked_fill(shifted, -float("inf")) work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits) if 0.0 < cfg.min_p < 1.0: probs_for_minp = torch.softmax(work, dim=-1) max_prob = probs_for_minp.max(dim=-1, keepdim=True).values keep = probs_for_minp >= float(cfg.min_p) * max_prob work = work.masked_fill(~keep, -float("inf")) probs = torch.softmax(work, dim=-1) if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0: next_id = torch.argmax(logits, dim=-1, keepdim=True) probs = torch.softmax(logits.float(), dim=-1) else: next_id = torch.multinomial(probs, 1) prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0 entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0 return next_id, {"prob": prob, "entropy": entropy} @torch.no_grad() def _slow_generate(self, ids: torch.Tensor, prompt_len: int, cfg: GenerationConfig) -> Generator[str, None, None]: # Compatibility path: full prefix recompute every token. eos_id = getattr(self.tokenizer, "eos_token_id", None) last_text = "" t0 = time.time() for step in range(int(cfg.max_new_tokens)): ctx = ids[:, -int(cfg.max_context_tokens):] with self.amp_context(): out = self.model(ctx) if self._effective_mode(cfg.mode) != "fast" else self.base(ctx) logits = out.logits[:, -1, :].float() generated = ids[:, prompt_len:] next_id, tok_stats = self._sample_next(logits, generated, cfg) ids = torch.cat([ids, next_id.to(ids.device)], dim=-1) token_id = int(next_id.item()) token_text = self.tokenizer.decode([token_id], skip_special_tokens=False) self._record_token(step + 1, token_id, token_text, tok_stats) if eos_id is not None and token_id == int(eos_id): break if (step + 1) % int(cfg.stream_every) == 0 or step == 0: raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True) if any(s in raw for s in STOP_STRINGS): break text = clean_text(raw) if text and text != last_text: elapsed = time.time() - t0 gen = int(ids.shape[1]) - prompt_len self.last_stats = GenerationStats(prompt_tokens=prompt_len, prompt_total_tokens=self.last_prompt_total_tokens, prompt_kept_tokens=self.last_prompt_kept_tokens, prompt_dropped_tokens=self.last_prompt_dropped_tokens, generated_tokens=gen, elapsed_sec=elapsed, tokens_per_sec=gen / max(elapsed, 1e-9), mode=cfg.mode, backend="slow-full-prefix", last_token=token_text, last_token_id=token_id, last_token_prob=tok_stats["prob"], last_entropy=tok_stats["entropy"], finish_reason="streaming") last_text = text yield text final = clean_text(self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)) if final: yield final def _record_token(self, i: int, token_id: int, token_text: str, tok_stats: Dict[str, float]) -> None: self.recent_tokens.append({"i": i, "id": token_id, "text": token_text, "prob": tok_stats.get("prob", 0.0), "entropy": tok_stats.get("entropy", 0.0)}) self.recent_tokens = self.recent_tokens[-64:] @torch.no_grad() def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]: self.model.eval() self.base.eval() self.recent_tokens = [] mode = self._effective_mode(cfg.mode) if mode == "deep": cfg.max_context_tokens = max(int(cfg.max_context_tokens), 16384) ids = self.encode(prompt, max_context_tokens=int(cfg.max_context_tokens)) prompt_len = int(ids.shape[1]) if self.last_prompt_dropped_tokens > 0: print(f"[context] prompt clipped: kept={self.last_prompt_kept_tokens} total={self.last_prompt_total_tokens} dropped={self.last_prompt_dropped_tokens}") t_start = time.time() prefill_sec = 0.0 cache_hit = False reused = 0 new_prefill = prompt_len backend = "" pass_kl = None pass_cos = None if (not cfg.use_cache) or mode == "slow": yield from self._slow_generate(ids, prompt_len, cfg) return try: with self.amp_context(): t_pref = time.time() ids, cache_hit, reused, new_prefill, backend = self._prepare_cached_context(ids, cfg) prefill_sec = time.time() - t_pref pass_kl, pass_cos = self.pass_metrics_from_logits(self.cached_pass1_logits, self.cached_pass2_logits) if cfg.return_pass_metrics else (None, None) except Exception as exc: print(f"[cache] cached path failed; falling back to slow full-prefix: {type(exc).__name__}: {exc}") if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}: traceback.print_exc() self.pass1_cache = None self.pass2_cache = None self.cache_ids = None yield from self._slow_generate(ids, prompt_len, cfg) return eos_id = getattr(self.tokenizer, "eos_token_id", None) last_text = "" finish_reason = "length" use_pass2 = mode in {"vwm", "deep"} and self.has_pass2 use_pass_embed = bool(use_pass2) for step in range(int(cfg.max_new_tokens)): logits = self.cached_logits[:, -1, :].float() if self.cached_logits is not None and self.cached_logits.dim() == 3 else self.cached_logits.float() generated = ids[:, prompt_len:] next_id, tok_stats = self._sample_next(logits, generated, cfg) token_id = int(next_id.item()) token_text = self.tokenizer.decode([token_id], skip_special_tokens=False) self._record_token(step + 1, token_id, token_text, tok_stats) ids = torch.cat([ids, next_id.to(ids.device)], dim=-1) if eos_id is not None and token_id == int(eos_id): finish_reason = "eos" break pos = int(ids.shape[1]) - 1 try: with self.amp_context(): h1, states, logits1 = self._step_pass1(next_id.to(self.device), pos, int(cfg.max_context_tokens), use_pass_embed=use_pass_embed) if use_pass2: logits2 = self._step_pass2(h1, states, pos, int(cfg.max_context_tokens)) self.cached_logits = logits2 self.cached_pass1_logits = logits1 self.cached_pass2_logits = logits2 else: self.cached_logits = logits1 self.cached_pass1_logits = logits1 self.cached_pass2_logits = None if self.cache_ids is not None: self.cache_ids = torch.cat([self.cache_ids, next_id.detach().to(self.cache_ids.device)], dim=-1) if self.cache_ids.shape[1] > int(cfg.max_context_tokens): self.cache_ids = self.cache_ids[:, -int(cfg.max_context_tokens):] except Exception as exc: print(f"[decode-cache] step failed; falling back for this request: {type(exc).__name__}: {exc}") # Finish with slow path from current ids; do not pretend cache is valid. self.cache_ids = None yield from self._slow_generate(ids, prompt_len, cfg) return if (step + 1) % int(cfg.stream_every) == 0 or step == 0: raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True) if any(s in raw for s in STOP_STRINGS): finish_reason = "stop_string" break text = clean_text(raw) if text and text != last_text: elapsed = time.time() - t_start gen = int(ids.shape[1]) - prompt_len self.last_stats = GenerationStats( prompt_tokens=prompt_len, prompt_total_tokens=self.last_prompt_total_tokens, prompt_kept_tokens=self.last_prompt_kept_tokens, prompt_dropped_tokens=self.last_prompt_dropped_tokens, generated_tokens=gen, elapsed_sec=elapsed, tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9), prefill_sec=prefill_sec, prefill_tps=(new_prefill / max(prefill_sec, 1e-9)), cache_hit=cache_hit, cache_reused_tokens=reused, cache_new_prefill_tokens=new_prefill, mode=mode, backend=backend, last_token=token_text, last_token_id=token_id, last_token_prob=tok_stats["prob"], last_entropy=tok_stats["entropy"], finish_reason="streaming", pass1_pass2_kl=pass_kl, pass1_pass2_logit_cosine=pass_cos, ) last_text = text yield text raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True) final = clean_text(raw) elapsed = time.time() - t_start gen = int(ids.shape[1]) - prompt_len self.last_stats = GenerationStats( prompt_tokens=prompt_len, prompt_total_tokens=self.last_prompt_total_tokens, prompt_kept_tokens=self.last_prompt_kept_tokens, prompt_dropped_tokens=self.last_prompt_dropped_tokens, generated_tokens=gen, elapsed_sec=elapsed, tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9), prefill_sec=prefill_sec, prefill_tps=(new_prefill / max(prefill_sec, 1e-9)), cache_hit=cache_hit, cache_reused_tokens=reused, cache_new_prefill_tokens=new_prefill, mode=mode, backend=backend, last_token=self.recent_tokens[-1]["text"] if self.recent_tokens else "", last_token_id=self.recent_tokens[-1]["id"] if self.recent_tokens else -1, last_token_prob=self.recent_tokens[-1]["prob"] if self.recent_tokens else 0.0, last_entropy=self.recent_tokens[-1]["entropy"] if self.recent_tokens else 0.0, finish_reason=finish_reason, pass1_pass2_kl=pass_kl, pass1_pass2_logit_cosine=pass_cos, ) if final: yield final def clear_session_cache(self) -> None: self.pass1_cache = None self.pass2_cache = None self.cache_ids = None self.cached_logits = None self.cached_pass1_logits = None self.cached_pass2_logits = None self.cache_backend = "cleared" def stats_dict(self) -> Dict[str, Any]: return {"generation": asdict(self.last_stats), "model": self.model_info, "system": system_snapshot(self)} def stats_text(self) -> str: s = self.last_stats lines = [ f"Mode: {s.mode} | backend={s.backend}", f"Prompt tokens: {s.prompt_tokens} kept / {getattr(s, 'prompt_total_tokens', s.prompt_tokens)} total / {getattr(s, 'prompt_dropped_tokens', 0)} dropped", f"Generated tokens: {s.generated_tokens}", f"Elapsed: {s.elapsed_sec:.2f}s | prefill={s.prefill_sec:.2f}s ({s.prefill_tps:.1f} tok/s)", f"Decode speed: {s.tokens_per_sec:.2f} tok/s", f"Cache: hit={s.cache_hit} reused={s.cache_reused_tokens} new_prefill={s.cache_new_prefill_tokens}", f"Finish reason: {s.finish_reason}", f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f} H={s.last_entropy:.2f}", ] if s.pass1_pass2_kl is not None: lines.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}") if s.pass1_pass2_logit_cosine is not None: lines.append(f"Pass1/Pass2 cosine: {s.pass1_pass2_logit_cosine:.6f}") lines.extend([ "", f"Checkpoint: {self.model_info['checkpoint']}", f"Checkpoint size: {self.model_info['checkpoint_size']}", f"Device: {self.model_info['device']} dtype={self.model_info['dtype']}", f"Pass2 active: {self.model_info['has_pass2']}", f"Params: {self.model_info['total_params']:,}", f"VWM c modules: {self.model_info['vwm_c_modules']} ({self.model_info['vwm_c_params']:,} c params)", ]) return "\n".join(lines) def token_trace_text(self) -> str: if not self.recent_tokens: return "No tokens generated yet." rows = [] for t in self.recent_tokens[-48:]: safe = repr(t["text"])[1:-1] rows.append(f"{t['i']:04d} id={t['id']:<7} p={t['prob']:.4f} H={t['entropy']:.2f} {safe}") return "\n".join(rows) def system_snapshot(engine: Optional[LULUV2OptimizedEngine] = None) -> Dict[str, Any]: snap: Dict[str, Any] = { "python_ram": "n/a", "system_ram": "n/a", "system_ram_percent": 0.0, "cpu_percent": 0.0, "gpu_name": "CUDA unavailable", "vram_allocated": "n/a", "vram_reserved": "n/a", "vram_used": "n/a", "vram_total": "n/a", "vram_percent": 0.0, "gpu_util_percent": None, "gpu_temp_c": None, } if psutil is not None: try: proc = psutil.Process(os.getpid()) vm = psutil.virtual_memory() snap.update({ "python_ram": human_bytes(proc.memory_info().rss), "system_ram": f"{human_bytes(vm.used)} / {human_bytes(vm.total)}", "system_ram_percent": float(vm.percent), "cpu_percent": float(psutil.cpu_percent(interval=0.0)), }) except Exception: pass if torch.cuda.is_available(): try: idx = torch.cuda.current_device() props = torch.cuda.get_device_properties(idx) allocated = int(torch.cuda.memory_allocated(idx)) reserved = int(torch.cuda.memory_reserved(idx)) total = int(props.total_memory) snap.update({ "gpu_name": props.name, "vram_allocated": human_bytes(allocated), "vram_reserved": human_bytes(reserved), "vram_used": human_bytes(allocated), "vram_total": human_bytes(total), "vram_percent": 100.0 * allocated / max(total, 1), }) if pynvml is not None: try: pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(idx) util = pynvml.nvmlDeviceGetUtilizationRates(handle) mem = pynvml.nvmlDeviceGetMemoryInfo(handle) temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU) snap.update({ "gpu_util_percent": int(util.gpu), "vram_used": human_bytes(int(mem.used)), "vram_total": human_bytes(int(mem.total)), "vram_percent": 100.0 * float(mem.used) / max(float(mem.total), 1.0), "gpu_temp_c": int(temp), }) except Exception: pass except Exception: pass return snap def system_usage(engine: Optional[LULUV2OptimizedEngine] = None) -> str: snap = system_snapshot(engine) lines = [ f"OS: {platform.system()} {platform.release()}", f"Python RAM: {snap['python_ram']}", f"System RAM: {snap['system_ram']} ({snap['system_ram_percent']:.1f}%)", f"CPU: {snap['cpu_percent']:.1f}%", "", f"GPU: {snap['gpu_name']}", f"VRAM used: {snap['vram_used']} / {snap['vram_total']} ({snap['vram_percent']:.1f}%)", f"VRAM allocated: {snap['vram_allocated']}", f"VRAM reserved: {snap['vram_reserved']}", ] if snap.get("gpu_util_percent") is not None: lines.append(f"GPU util: {snap['gpu_util_percent']}%") if snap.get("gpu_temp_c") is not None: lines.append(f"GPU temp: {snap['gpu_temp_c']} C") if engine is not None: lines.extend(["", engine.stats_text()]) return "\n".join(lines)