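# NOTE: the following lines are Hugging Face file-viewer residue captured with
# the source; they are kept as comments so the module remains parseable.
# Lulu750M-Instruct-LOCALNOQANT / luluv2_optimized_engine.py
# TheOpenMachine's picture
# Upload 13 files
# edf2c2e verified
# Raw
# History Blame Contribute Delete
# 51.9 kB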
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LULUV2 optimized local inference engine.
Goals:
- load LULU2/LULUV2 checkpoints through the existing LULUV2 model file
- no AutoModelForCausalLM.from_pretrained and no external model weights
- vectorized prompt prefill into explicit KV caches
- persistent session KV cache across turns when prompt tokens extend prior prompt
- modes: fast(pass1/base), vwm(pass1+pass2), deep(pass1+pass2 long context)
- safe fallback to slow full-prefix forward if cached path fails
This is intentionally Python-first and debuggable. It is a bridge toward
kernel/CUDA-graph optimization, not the final kernel path.
"""
from __future__ import annotations
import importlib.util
import json
import math
import os
import platform
import time
import traceback
from contextlib import nullcontext
from dataclasses import dataclass, asdict
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, Generator, List, Optional, Tuple
import torch
import torch.nn.functional as F
try:
import psutil
except Exception:
psutil = None
try:
import pynvml
except Exception:
pynvml = None
# Chat-template markers and role prefixes treated as end-of-answer signals.
# clean_text() cuts text at the earliest marker found past index 0 and removes
# any remaining occurrences; _slow_generate() stops decoding once one appears
# in the decoded continuation.
STOP_STRINGS = [
    "<|im_start|>", "<|im_end|>", "<|user|>", "<|system|>", "<|assistant|>",
    "User:", "Assistant:", "\nuser:", "\nassistant:",
]
def setup_torch() -> None:
    """Switch on global PyTorch speed-ups on a best-effort basis.

    On CUDA machines TF32 matmuls are allowed and the flash / memory-efficient
    scaled-dot-product kernels are enabled while the math kernel is disabled.
    Independently of CUDA, float32 matmul precision is set to ``"high"`` when
    the running torch build offers that setter. Every step is optional: a
    failure is swallowed so that older torch builds still work.
    """
    cuda_backend_toggles = (
        ("enable_flash_sdp", True),
        ("enable_mem_efficient_sdp", True),
        ("enable_math_sdp", False),
    )
    if torch.cuda.is_available():
        try:
            # Old API still works on current wheels; warnings are harmless.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass
        try:
            # Resolve and call one toggle at a time so a missing later toggle
            # does not undo the earlier ones.
            for toggle_name, enabled in cuda_backend_toggles:
                getattr(torch.backends.cuda, toggle_name)(enabled)
        except Exception:
            pass
    if hasattr(torch, "set_float32_matmul_precision"):
        try:
            torch.set_float32_matmul_precision("high")
        except Exception:
            pass
def human_bytes(num: float) -> str:
    """Format a byte count with two decimals and a binary-scaled unit.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    units B..TB are exhausted, in which case the remainder is reported in PB.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    size = float(num)
    idx = 0
    # "not <" (rather than ">=") keeps NaN flowing through to the PB branch.
    while idx < len(units) and not abs(size) < 1024.0:
        size /= 1024.0
        idx += 1
    label = units[idx] if idx < len(units) else "PB"
    return f"{size:.2f} {label}"
def value_to_text(value: Any) -> str:
    """Flatten a chat payload (string, dict, list/tuple, or scalar) into text.

    Rules, in order:
      * ``None`` becomes ``""`` and strings are returned unchanged.
      * A dict that has a ``"text"``, ``"content"`` or ``"value"`` key (checked
        in that order) is represented by that entry alone.
      * Any other dict, and any list/tuple, is rendered element by element;
        empty renderings are dropped and the rest joined with newlines.
      * Everything else goes through ``str()``.

    Each child is rendered exactly once. The earlier form rendered every child
    twice (once for the emptiness filter, once for the value), which doubled
    the work at every nesting level.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("text", "content", "value"):
            if key in value:
                return value_to_text(value[key])
        children = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        return str(value)
    rendered = (value_to_text(child) for child in children)
    return "\n".join(piece for piece in rendered if piece)
def clean_text(text: Any) -> str:
    """Normalize raw model/user text into a clean display string.

    Steps: flatten the value with ``value_to_text``; turn literal ``\\n``
    sequences into real newlines; cut at the earliest stop marker that occurs
    after index 0; delete any remaining stop markers; strip a leading speaker
    prefix; right-strip every line and keep at most two consecutive blank
    lines.
    """
    cleaned = value_to_text(text).replace("\\n", "\n")

    # A marker at index 0 (or absent, -1) never truncates; it is only removed below.
    marker_positions = [
        where for where in (cleaned.find(marker) for marker in STOP_STRINGS) if where > 0
    ]
    if marker_positions:
        cleaned = cleaned[: min(marker_positions)]
    for marker in STOP_STRINGS:
        cleaned = cleaned.replace(marker, "")
    cleaned = cleaned.strip()

    for speaker in ("Assistant:", "assistant:", "Lulu:", "lulu:"):
        if cleaned.startswith(speaker):
            cleaned = cleaned[len(speaker):].strip()

    # collapse excessive vertical whitespace without destroying code blocks too much
    kept: List[str] = []
    blank_run = 0
    for raw_line in cleaned.splitlines():
        line = raw_line.rstrip()
        if line.strip():
            blank_run = 0
            kept.append(line)
            continue
        blank_run += 1
        if blank_run <= 2:
            kept.append("")
    return "\n".join(kept).strip()
def normalize_history(history: Any) -> List[Dict[str, str]]:
    """Convert chat history into a flat list of ``{"role", "content"}`` dicts.

    Two entry shapes are accepted:
      * dicts with ``role``/``content`` keys, kept only for the roles ``user``
        and ``assistant`` and only when the cleaned content is non-empty;
      * ``(user, assistant)`` pairs (tuple or list, first two items used),
        expanded into up to two messages.
    Anything else is ignored. A falsy ``history`` yields an empty list.
    """
    turns: List[Dict[str, str]] = []
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role", "")
            content = clean_text(entry.get("content", ""))
            if role in {"user", "assistant"} and content:
                turns.append({"role": role, "content": content})
        elif isinstance(entry, (tuple, list)) and len(entry) >= 2:
            user_text = clean_text(entry[0])
            assistant_text = clean_text(entry[1])
            for role, content in (("user", user_text), ("assistant", assistant_text)):
                if content:
                    turns.append({"role": role, "content": content})
    return turns
def resolve_model_py(model_py: Optional[str]) -> str:
    """Return the absolute path of the LULUV2 runtime model file.

    The explicit ``model_py`` argument (when given) is tried first, then
    ``luluv2_inference_runtime.py`` relative to the current working directory.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    search_order: List[str] = ([model_py] if model_py else []) + ["luluv2_inference_runtime.py"]
    for candidate in map(Path, search_order):
        if candidate.exists():
            return str(candidate.resolve())
    raise FileNotFoundError("Could not find LULUV2 model file. Pass --model-py.")
def import_model_py(model_py: Optional[str]):
    """Import the LULUV2 runtime model file as a module.

    The file located by ``resolve_model_py`` is loaded under the fixed module
    name ``luluv2_runtime_module``.

    The module is registered in ``sys.modules`` before it is executed, as in
    the importlib "importing a source file directly" recipe. Code that looks
    a class up through ``sys.modules[cls.__module__]`` (dataclasses with
    string annotations, pickling, ``typing.get_type_hints``) then finds the
    module. If execution fails, the previous ``sys.modules`` entry (or its
    absence) is restored so that a half-initialized module is not left behind.

    Args:
        model_py: Optional explicit path of the runtime file.

    Returns:
        ``(module, resolved_path)``.

    Raises:
        FileNotFoundError: propagated from ``resolve_model_py``.
        RuntimeError: if no import spec/loader can be built for the file.
    """
    import sys  # local import: the file's import block does not provide sys

    module_name = "luluv2_runtime_module"
    path = resolve_model_py(model_py)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not import model file: {path}")
    mod = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return mod, path
@dataclass
class GenerationConfig:
    """Sampling, context and cache settings for one generation request."""

    # Upper bound on the number of decode steps.
    max_new_tokens: int = 512
    # Logits are divided by this before sampling; <= 0 selects greedy decoding.
    temperature: float = 0.65
    # Keep only the k highest logits (0 disables the filter).
    top_k: int = 40
    # Nucleus threshold; only applied when strictly between 0 and 1.
    top_p: float = 0.90
    # Drop tokens whose probability is below min_p * max probability;
    # only applied when strictly between 0 and 1.
    min_p: float = 0.03
    # Already generated tokens: positive logits are divided by this value,
    # non-positive ones multiplied by it (1.0 disables).
    repetition_penalty: float = 1.10
    # Subtracted from a token's logit once per prior occurrence (0 disables).
    frequency_penalty: float = 0.02
    # Force argmax decoding regardless of temperature.
    greedy: bool = False
    # Ban tokens that would repeat an n-gram of this size (values <= 1 disable).
    no_repeat_ngram: int = 4
    # Emit streamed text every N generated tokens (the first token always streams).
    stream_every: int = 1
    # Maximum number of prompt/context tokens kept; longer prompts are left-truncated.
    max_context_tokens: int = 4096
    # One of fast, vwm, deep, slow (aliases are resolved by _effective_mode).
    mode: str = "vwm" # fast, vwm, deep, slow
    # Compute pass1-vs-pass2 KL / cosine metrics after prefill.
    return_pass_metrics: bool = True
    # When false, generation uses the slow full-prefix recompute path.
    use_cache: bool = True
    # Try whole-prompt prefill first; otherwise prefill token by token.
    vectorized_prefill: bool = True
    # Allow reusing the KV cache of the previous prompt when it is a prefix of the new one.
    persistent_cache: bool = True
    # NOTE(review): not referenced in the visible part of this file — confirm its use.
    compile_step: bool = False
@dataclass
class GenerationStats:
    """Snapshot of counters and timings for the most recent generation.

    The engine stores an instance in ``last_stats``; ``_slow_generate``
    refreshes it every time it streams text.
    """

    # --- prompt accounting (token counts before/after left-truncation) ---
    prompt_tokens: int = 0
    prompt_total_tokens: int = 0
    prompt_kept_tokens: int = 0
    prompt_dropped_tokens: int = 0
    # --- decode throughput ---
    generated_tokens: int = 0
    elapsed_sec: float = 0.0
    tokens_per_sec: float = 0.0
    # --- prefill / cache reuse (presumably filled by the cached generate path — TODO confirm) ---
    prefill_sec: float = 0.0
    prefill_tps: float = 0.0
    cache_hit: bool = False
    cache_reused_tokens: int = 0
    cache_new_prefill_tokens: int = 0
    # --- how the text was produced ---
    mode: str = "vwm"
    backend: str = "none"
    # --- most recently sampled token ---
    last_token: str = ""
    last_token_id: int = -1
    last_token_prob: float = 0.0
    last_entropy: float = 0.0
    # "streaming" while text is being yielded; other values are set outside the visible code.
    finish_reason: str = "none"
    # --- pass1 vs pass2 agreement metrics (None when unavailable) ---
    pass1_pass2_kl: Optional[float] = None
    pass1_pass2_logit_cosine: Optional[float] = None
class KVLayerCache:
    """Key/value tensors of one decoder layer, bounded to a sliding window.

    Tensors are stored detached; the sequence axis is dimension 2 and only the
    most recent ``max_len`` positions are retained.
    """

    def __init__(self):
        self.k: Optional[torch.Tensor] = None  # [B, H, T, Dh]
        self.v: Optional[torch.Tensor] = None

    @property
    def length(self) -> int:
        """Number of cached positions (0 while the cache is empty)."""
        return 0 if self.k is None else int(self.k.shape[2])

    def set(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
        """Replace the cache content with ``k``/``v``, keeping the last ``max_len`` positions."""
        if k.shape[2] > max_len:
            window = slice(-max_len, None)
            k, v = k[:, :, window, :], v[:, :, window, :]
        self.k = k.detach().contiguous()
        self.v = v.detach().contiguous()

    def append(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
        """Concatenate new positions at the end, then trim to the last ``max_len``."""
        if self.k is None:
            # First write: identical to a plain set().
            self.set(k, v, max_len)
            return
        self.k = torch.cat((self.k, k.detach()), dim=2)
        self.v = torch.cat((self.v, v.detach()), dim=2)
        if self.length > max_len:
            window = slice(-max_len, None)
            self.k = self.k[:, :, window, :].contiguous()
            self.v = self.v[:, :, window, :].contiguous()
class DecoderKVCache:
    """One ``KVLayerCache`` per decoder layer."""

    def __init__(self, n_layers: int):
        layer_count = int(n_layers)
        self.layers = [KVLayerCache() for _ in range(layer_count)]

    def clear(self):
        """Drop the stored keys and values of every layer."""
        for entry in self.layers:
            entry.k = None
            entry.v = None

    @property
    def length(self) -> int:
        """Cached sequence length, read from the first layer (0 without layers)."""
        return self.layers[0].length if self.layers else 0
class LULUV2OptimizedEngine:
def __init__(
    self,
    ckpt_path: str,
    model_py: Optional[str] = None,
    tokenizer_dir: Optional[str] = None,
    device: Optional[str] = None,
    dtype: str = "bf16",
    local_files_only: bool = True,
    no_config_download: bool = True,
    force_base_only: bool = False,
):
    """Load the LULUV2 runtime module, checkpoint and tokenizer, and reset cache state.

    Args:
        ckpt_path: Checkpoint path handed to the runtime's ``load_lulu2_base``.
        model_py: Optional path of the runtime model file (see ``import_model_py``).
        tokenizer_dir: Optional tokenizer directory; when omitted, a
            ``tokenizer`` folder next to the checkpoint is used if present.
        device: Torch device string; defaults to CUDA when available, else CPU.
        dtype: Precision name resolved by ``_dtype_from_name``.
        local_files_only: Forwarded to the runtime loaders through ``self.args``.
        no_config_download: Forwarded to the runtime loaders through ``self.args``.
        force_base_only: When true, the pass-2 wrapper is never built.
    """
    setup_torch()
    self.ckpt_path = str(ckpt_path)
    self.ckpt_dir = Path(self.ckpt_path).resolve().parent
    self.device = self._select_device(device)
    self.dtype = self._dtype_from_name(dtype)
    self.local_files_only = bool(local_files_only)
    self.no_config_download = bool(no_config_download)
    self.force_base_only = bool(force_base_only)
    # Telemetry of the most recent generation.
    self.last_stats = GenerationStats()
    self.recent_tokens: List[Dict[str, Any]] = []
    # Prompt accounting written by encode().
    self.last_prompt_total_tokens: int = 0
    self.last_prompt_kept_tokens: int = 0
    self.last_prompt_dropped_tokens: int = 0
    # Persistent KV-cache state: the token ids the caches were built for, the
    # mode/context they are valid for, and the logits of the last position.
    self.cache_ids: Optional[torch.Tensor] = None
    self.cache_mode: str = ""
    self.cache_max_context: int = 0
    self.pass1_cache: Optional[DecoderKVCache] = None
    self.pass2_cache: Optional[DecoderKVCache] = None
    self.cached_logits: Optional[torch.Tensor] = None
    self.cached_pass1_logits: Optional[torch.Tensor] = None
    self.cached_pass2_logits: Optional[torch.Tensor] = None
    self.cache_backend: str = "cold"
    # The runtime module supplies the loaders, model classes and rotary helpers.
    self.goku, self.model_py_path = import_model_py(model_py)
    # Argument namespace consumed by the runtime module's loader functions.
    self.args = SimpleNamespace(
        checkpoint=self.ckpt_path,
        tokenizer=tokenizer_dir or "",
        model_id="",
        no_config_download=self.no_config_download,
        local_files_only=self.local_files_only,
    )
    print("[guard] LULUV2 cockpit: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.")
    print(f"[load] checkpoint={self.ckpt_path}")
    # Order matters: the tokenizer loader and the pass-2 wrapper both read self.base_ckpt.
    self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype)
    self.tokenizer = self._load_tokenizer(tokenizer_dir)
    self.model, self.has_pass2 = self._maybe_wrap_pass2(base)
    # With pass 2 the wrapper exposes the underlying model as .base; otherwise they are the same object.
    self.base = self.model.base if self.has_pass2 else self.model
    self.n_layers = int(self.base.config.num_hidden_layers)
    self.model.eval()
    self.base.eval()
    self.model_info = self._build_model_info()
    self._compiled = False
def _select_device(self, device: Optional[str]):
if device:
return torch.device(device)
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def _dtype_from_name(self, name: str):
name = (name or "bf16").lower()
if name in {"bf16", "bfloat16"}:
return torch.bfloat16
if name in {"fp16", "float16", "half"}:
return torch.float16
return torch.float32
def _load_tokenizer(self, tokenizer_dir: Optional[str]):
    """Load the tokenizer through the runtime module and make it long-prompt safe.

    The tokenizer location is ``tokenizer_dir`` when given, otherwise a
    ``tokenizer`` directory next to the checkpoint when one exists; if neither
    applies, ``self.args.tokenizer`` is left as it was. After loading, the EOS
    token is reused as pad token when no pad token id is set, and truncation
    is pushed to the left with an effectively unlimited model length. All of
    these adjustments are best-effort.
    """
    chosen_dir = tokenizer_dir
    if not chosen_dir:
        candidate = self.ckpt_dir / "tokenizer"
        if candidate.is_dir():
            chosen_dir = str(candidate)
    if chosen_dir:
        self.args.tokenizer = chosen_dir

    tok = self.goku.load_tokenizer(self.args, self.base_ckpt)

    if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None:
        try:
            tok.pad_token = tok.eos_token
        except Exception:
            pass

    # Long-prompt safety: for chat/RAG prompts, the latest user turn and final
    # instruction are normally at the end. Right-side truncation silently drops
    # exactly the part the model must answer, so force left truncation where the
    # tokenizer supports it. encode() below also performs manual left truncation
    # and records how many tokens were dropped.
    for attribute, setting in (("truncation_side", "left"), ("model_max_length", 10**9)):
        try:
            setattr(tok, attribute, setting)
        except Exception:
            pass
    return tok
def _maybe_wrap_pass2(self, base):
    """Wrap ``base`` in the two-pass model when the checkpoint carries pass-2 weights.

    Returns:
        ``(model, has_pass2)``. Without a ``pass2_state`` entry in the
        checkpoint, or when ``force_base_only`` is set, the base model itself
        is returned (moved to the device, eval mode) with ``False``.
    """
    checkpoint = self.base_ckpt
    if self.force_base_only or "pass2_state" not in checkpoint:
        print("[pass2] no pass2_state loaded; running base LULUV2 forward")
        return base.to(self.device).eval(), False

    raw_cfg = dict(checkpoint.get("pass2_config") or {})
    config_cls = self.goku.Pass2Config
    # Only pass keys the config dataclass declares, so extra checkpoint keys are ignored.
    known_fields = getattr(config_cls, "__dataclass_fields__", {})
    accepted = {name: val for name, val in raw_cfg.items() if name in known_fields}
    pass2_cfg = config_cls(**accepted)

    wrapped = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg)
    missing, unexpected = wrapped.load_state_dict(checkpoint["pass2_state"], strict=False)
    print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}")
    wrapped.to(device=self.device, dtype=self.dtype).eval()
    return wrapped, True
def _build_model_info(self) -> Dict[str, Any]:
    """Collect a summary dict describing the loaded model and checkpoint.

    Includes the parameter count, the number and size of parameters whose
    name ends in ``.c``, the mean sigmoid of the pass-2 layer and adapter
    gates (``None`` without pass 2), the checkpoint size, and a few
    architecture values read from the base config.
    """
    total_params = sum(param.numel() for param in self.model.parameters())
    c_sizes = [
        param.numel()
        for param_name, param in self.model.named_parameters()
        if param_name.endswith(".c")
    ]

    gate_mean = None
    adapter_gate_mean = None
    if self.has_pass2:
        with torch.no_grad():
            gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item())
            adapter_gates = [
                float(torch.sigmoid(adapter.gate.float()).item())
                for adapter in self.model.adapters
            ]
            adapter_gate_mean = sum(adapter_gates) / max(1, len(adapter_gates))

    ckpt_file = Path(self.ckpt_path)
    ckpt_size = ckpt_file.stat().st_size if ckpt_file.exists() else 0
    cfg = getattr(self.base, "config", None)

    info: Dict[str, Any] = {
        "checkpoint": self.ckpt_path,
        "checkpoint_size": human_bytes(ckpt_size),
        "model_py": self.model_py_path,
        "device": str(self.device),
        "dtype": str(self.dtype).replace("torch.", ""),
        "has_pass2": self.has_pass2,
        "total_params": total_params,
        "vwm_c_modules": len(c_sizes),
        "vwm_c_params": sum(c_sizes),
        "pass2_layer_gate_mean": gate_mean,
        "pass2_adapter_gate_mean": adapter_gate_mean,
    }
    # Architecture values; each is None when the config lacks the attribute.
    for info_key, cfg_attr in (
        ("hidden_size", "hidden_size"),
        ("layers", "num_hidden_layers"),
        ("heads", "num_attention_heads"),
        ("kv_heads", "num_key_value_heads"),
        ("max_position_embeddings", "max_position_embeddings"),
    ):
        info[info_key] = getattr(cfg, cfg_attr, None)
    return info
def amp_context(self):
    """Return a CUDA autocast context for bf16/fp16 on CUDA, otherwise a no-op context."""
    if self.device.type != "cuda" or self.dtype not in (torch.bfloat16, torch.float16):
        return nullcontext()
    return torch.autocast("cuda", dtype=self.dtype)
def build_chat_prompt(
    self,
    message: str,
    history: Any,
    system_prompt: str,
    memory_notes: str = "",
    history_turns: int = 4,
    extra_context: str = "",
) -> str:
    """Assemble the chat prompt string for the next assistant turn.

    The system message is built from ``system_prompt``, ``memory_notes`` and
    ``extra_context`` (blank parts are skipped). It is followed by the most
    recent ``history_turns`` exchanges (two messages per turn) and the new
    user message. The tokenizer's chat template is used when available;
    otherwise a ``<|im_start|>``/``<|im_end|>`` prompt is built by hand.

    Args:
        message: New user message (cleaned with ``clean_text``).
        history: Prior conversation in any shape ``normalize_history`` accepts.
        system_prompt: Base system instruction; ``None`` is treated as empty.
        memory_notes: Optional notes appended to the system message.
        history_turns: Number of recent turns to keep. Zero, negative or
            fractional-below-one values keep no history.
        extra_context: Optional local context appended to the system message.

    Returns:
        The prompt text, ending with the assistant generation header.
    """
    history = normalize_history(history)
    # Clamp explicitly: slicing with ``-0`` would return the *entire* history,
    # which is the opposite of "keep no turns".
    turns = max(0, int(history_turns or 0))
    recent = history[-2 * turns:] if turns else []

    system_chunks: List[str] = []
    for label, raw in (
        ("", system_prompt),
        ("Useful memory notes:\n", memory_notes),
        ("Relevant local context:\n", extra_context),
    ):
        body = (raw or "").strip()
        if body:
            system_chunks.append(label + body)
    system = "\n\n".join(system_chunks)

    user_text = clean_text(message)
    messages: List[Dict[str, str]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.extend(recent)
    messages.append({"role": "user", "content": user_text})

    try:
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Deliberate best-effort: tokenizers without a usable chat template
        # fall back to a hand-built ChatML-style prompt.
        parts: List[str] = []
        if system:
            parts.append(f"<|im_start|>system\n{system}<|im_end|>")
        parts.extend(
            f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>" for item in recent
        )
        parts.append(f"<|im_start|>user\n{user_text}<|im_end|>")
        parts.append("<|im_start|>assistant\n")
        return "\n".join(parts)
def encode(self, text: str, max_context_tokens: int) -> torch.Tensor:
    """Tokenize a prompt, keep only its last ``max_context_tokens`` ids, and record the clipping.

    Tokenizers commonly truncate on the right, which would keep the start of
    an oversized prompt and lose the final user instruction. Here the prompt
    is tokenized without truncation and clipped from the left manually, so
    the number of dropped tokens is known exactly.

    Side effects: updates ``last_prompt_total_tokens``,
    ``last_prompt_kept_tokens`` and ``last_prompt_dropped_tokens``. When
    tokens are dropped, every persistent-cache attribute is reset and
    ``cache_backend`` becomes ``"truncated-rebuild"``.

    Returns:
        The (possibly clipped) ids moved to ``self.device``.
    """
    budget = max(1, int(max_context_tokens))
    try:
        self.tokenizer.truncation_side = "left"
    except Exception:
        pass

    # Tokenize without tokenizer-side truncation so we know exactly whether the
    # prompt was clipped. The prompt already contains chat special tokens.
    common_kwargs = {"return_tensors": "pt", "truncation": False}
    try:
        enc = self.tokenizer(text, add_special_tokens=False, **common_kwargs)
    except TypeError:
        # Tokenizer call signature without add_special_tokens.
        enc = self.tokenizer(text, **common_kwargs)

    ids = enc.input_ids
    total = int(ids.shape[1])
    dropped = max(0, total - budget)
    if dropped > 0:
        ids = ids[:, -budget:].contiguous()
        # Do not reuse an older conversation cache after a hard context trim;
        # the logical prefix changed and reuse can make long prompts feel like
        # they are "forgetting" pieces.
        for stale_attr in (
            "pass1_cache",
            "pass2_cache",
            "cache_ids",
            "cached_logits",
            "cached_pass1_logits",
            "cached_pass2_logits",
        ):
            setattr(self, stale_attr, None)
        self.cache_backend = "truncated-rebuild"

    self.last_prompt_total_tokens = total
    self.last_prompt_kept_tokens = int(ids.shape[1])
    self.last_prompt_dropped_tokens = dropped
    return ids.to(self.device)
def _position_ids(self, T: int, offset: int = 0) -> torch.Tensor:
return torch.arange(offset, offset + T, device=self.device, dtype=torch.long).unsqueeze(0)
def _attn_prefill(self, attn, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
    """Run causal self-attention over a whole prompt and store its keys/values in ``cache``.

    Args:
        attn: Attention module of one decoder layer; must expose the q/k/v/o
            projections, ``rotary_emb``, the head counts, ``head_dim``,
            ``hidden_size`` and ``scaling``.
        hidden_states: Layer input, unpacked as ``(batch, seq_len, hidden)``.
        position_ids: Positions passed to ``attn.rotary_emb``.
        cache: Per-layer cache; its previous content is replaced.
        max_context: Maximum number of positions the cache keeps.

    Returns:
        The output of ``attn.o_proj`` for every prompt position.
    """
    bsz, q_len, _ = hidden_states.size()
    query_states = attn.q_proj(hidden_states)
    key_states = attn.k_proj(hidden_states)
    value_states = attn.v_proj(hidden_states)
    # Split heads: [B, T, heads * Dh] -> [B, heads, T, Dh].
    query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    # Rotary embedding is applied to queries and keys only.
    cos, sin = attn.rotary_emb(value_states, position_ids)
    query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # KV heads are expanded before caching, so the cache holds the repeated tensors.
    key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
    value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
    cache.set(key_states, value_states, max_context)
    # Attention uses the full (untrimmed) keys/values of this call with a causal mask.
    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True, scale=attn.scaling
    )
    # Merge heads back: [B, heads, T, Dh] -> [B, T, hidden].
    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
    return attn.o_proj(attn_output)
def _attn_step(self, attn, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
    """Run self-attention for a single new token against the cached keys/values.

    Args:
        attn: Attention module of one decoder layer (same requirements as in
            ``_attn_prefill``).
        hidden_states: Layer input for exactly one position, ``(batch, 1, hidden)``.
        pos: Absolute position of the token, used for the rotary embedding.
        cache: Per-layer cache; the new key/value is appended to it.
        max_context: Maximum number of positions the cache keeps.

    Returns:
        The output of ``attn.o_proj`` for the new position.

    Raises:
        AssertionError: if more than one position is passed.
        RuntimeError: if the cache is still empty after the append.
    """
    bsz, q_len, _ = hidden_states.size()
    assert q_len == 1
    query_states = attn.q_proj(hidden_states)
    key_states = attn.k_proj(hidden_states)
    value_states = attn.v_proj(hidden_states)
    # Split heads: [B, 1, heads * Dh] -> [B, heads, 1, Dh].
    query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    position_ids = self._position_ids(1, pos)
    cos, sin = attn.rotary_emb(value_states, position_ids)
    query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
    key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
    value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
    cache.append(key_states, value_states, max_context)
    if cache.k is None or cache.v is None:
        raise RuntimeError("KV cache append failed")
    # No causal mask: the single query may attend to every cached position,
    # all of which are at or before the current one.
    attn_output = F.scaled_dot_product_attention(
        query_states, cache.k, cache.v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scaling
    )
    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
    return attn.o_proj(attn_output)
def _layer_prefill(self, layer, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
residual = hidden_states
x = layer.input_layernorm(hidden_states)
x = self._attn_prefill(layer.self_attn, x, position_ids, cache, max_context)
hidden_states = residual + x
residual = hidden_states
x = layer.post_attention_layernorm(hidden_states)
x = layer.mlp(x)
return residual + x
def _layer_step(self, layer, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
residual = hidden_states
x = layer.input_layernorm(hidden_states)
x = self._attn_step(layer.self_attn, x, pos, cache, max_context)
hidden_states = residual + x
residual = hidden_states
x = layer.post_attention_layernorm(hidden_states)
x = layer.mlp(x)
return residual + x
@torch.no_grad()
def _prefill_pass1(self, input_ids: torch.Tensor, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Run the base decoder over the whole prompt at once and build the pass-1 KV cache.

    When ``use_pass_embed`` is set and pass 2 is available, the pass-0
    embedding is added to the token embeddings first.

    Returns:
        ``(residual, layer_states, position_ids, logits)``: the un-normed
        output of the last layer, the output of every layer, the position ids
        that were used, and the logits of the normed final state.

    Side effects: ``self.pass1_cache`` is replaced by the newly filled cache.
    """
    seq_len = int(input_ids.shape[1])
    position_ids = self._position_ids(seq_len, 0)
    fresh_cache = DecoderKVCache(self.n_layers)
    backbone = self.base.model

    hidden = backbone.embed_tokens(input_ids)
    if use_pass_embed and self.has_pass2:
        pass_tag = self.model.pass_embed[0].to(dtype=hidden.dtype, device=hidden.device)
        hidden = hidden + pass_tag.view(1, 1, -1)

    layer_states: List[torch.Tensor] = []
    for idx, layer in enumerate(backbone.layers):
        hidden = self._layer_prefill(layer, hidden, position_ids, fresh_cache.layers[idx], max_context)
        layer_states.append(hidden)

    logits = self.base.lm_head(backbone.norm(hidden))
    # Publish the cache only after the whole prompt went through.
    self.pass1_cache = fresh_cache
    return hidden, layer_states, position_ids, logits
@torch.no_grad()
def _prefill_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], position_ids: torch.Tensor, max_context: int) -> torch.Tensor:
    """Run the second pass over the whole prompt and build the pass-2 KV cache.

    The pass-1 residual plus the pass-1 embedding is sent through the same
    decoder layers again. Per layer, the layer's own update is scaled by the
    sigmoid of its gate, and an adapter update (computed from the layer input
    and the matching pass-1 layer state) is added.

    Returns:
        Logits of the normed final pass-2 state.

    Raises:
        RuntimeError: if the checkpoint has no pass-2 weights.

    Side effects: ``self.pass2_cache`` is replaced by the newly filled cache.
    """
    if not self.has_pass2:
        raise RuntimeError("pass2 requested but checkpoint has no pass2_state")

    fresh_cache = DecoderKVCache(self.n_layers)
    pass_tag = self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device)
    stream = h1_resid + pass_tag.view(1, 1, -1)

    for idx, layer in enumerate(self.base.model.layers):
        layer_out = self._layer_prefill(layer, stream, position_ids, fresh_cache.layers[idx], max_context)
        layer_delta = layer_out - stream
        gate = torch.sigmoid(self.model.layer_gates[idx]).to(dtype=stream.dtype, device=stream.device)
        adapter_delta = self.model.adapters[idx](stream, pass1_states[idx])
        stream = stream + gate * layer_delta + adapter_delta

    logits = self.base.lm_head(self.base.model.norm(stream))
    self.pass2_cache = fresh_cache
    return logits
@torch.no_grad()
def _step_pass1(self, token_id: torch.Tensor, pos: int, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """Advance the base decoder by one token, appending to the pass-1 KV cache.

    A new empty cache is created when none exists yet.

    Returns:
        ``(residual, layer_states, logits)`` for the new position.
    """
    if self.pass1_cache is None:
        self.pass1_cache = DecoderKVCache(self.n_layers)

    backbone = self.base.model
    hidden = backbone.embed_tokens(token_id)
    if use_pass_embed and self.has_pass2:
        pass_tag = self.model.pass_embed[0].to(dtype=hidden.dtype, device=hidden.device)
        hidden = hidden + pass_tag.view(1, 1, -1)

    layer_states: List[torch.Tensor] = []
    for idx, layer in enumerate(backbone.layers):
        hidden = self._layer_step(layer, hidden, pos, self.pass1_cache.layers[idx], max_context)
        layer_states.append(hidden)

    logits = self.base.lm_head(backbone.norm(hidden))
    return hidden, layer_states, logits
@torch.no_grad()
def _step_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], pos: int, max_context: int) -> torch.Tensor:
    """Advance the second pass by one token, appending to the pass-2 KV cache.

    Mirrors ``_prefill_pass2`` for a single position: gated layer update plus
    adapter update per layer. A new empty cache is created when none exists.

    Returns:
        Logits of the normed final pass-2 state for the new position.

    Raises:
        RuntimeError: if the checkpoint has no pass-2 weights.
    """
    if not self.has_pass2:
        raise RuntimeError("pass2 step requested but unavailable")
    if self.pass2_cache is None:
        self.pass2_cache = DecoderKVCache(self.n_layers)

    pass_tag = self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device)
    stream = h1_resid + pass_tag.view(1, 1, -1)

    for idx, layer in enumerate(self.base.model.layers):
        layer_out = self._layer_step(layer, stream, pos, self.pass2_cache.layers[idx], max_context)
        layer_delta = layer_out - stream
        gate = torch.sigmoid(self.model.layer_gates[idx]).to(dtype=stream.dtype, device=stream.device)
        adapter_delta = self.model.adapters[idx](stream, pass1_states[idx])
        stream = stream + gate * layer_delta + adapter_delta

    return self.base.lm_head(self.base.model.norm(stream))
def _ids_prefix_len(self, old: torch.Tensor, new: torch.Tensor) -> int:
if old is None or old.numel() == 0 or new.numel() == 0:
return 0
old1 = old[0]
new1 = new[0]
max_n = min(int(old1.numel()), int(new1.numel()))
if max_n == 0:
return 0
# Fast path: old is exact prefix of new.
if int(old1.numel()) <= int(new1.numel()) and torch.equal(old1, new1[: old1.numel()]):
return int(old1.numel())
# Conservative fallback, scan from max down; prompts are usually exact-prefix or reset.
for n in range(max_n, 0, -1):
if torch.equal(old1[:n], new1[:n]):
return n
return 0
@torch.no_grad()
def _token_prefill_context(self, input_ids: torch.Tensor, cfg: GenerationConfig, use_pass2: bool, use_pass_embed: bool, max_context: int) -> None:
    """
    Conservative cache builder.
    Fills the pass1 (and, when requested, pass2) KV caches by feeding the prompt
    one token at a time through the single-step path. Slower than vectorized
    prefill but much safer across checkpoint/runtime variants, and it still
    leaves a valid decode cache + persistent cache for the generated tokens.
    After every token the cached logits are updated, so on return they belong
    to the last prompt position. ``cfg`` is accepted for signature symmetry
    and is not read here.
    """
    # Start from empty caches and forget any logits of a previous prompt.
    self.pass1_cache = DecoderKVCache(self.n_layers)
    self.pass2_cache = DecoderKVCache(self.n_layers) if use_pass2 else None
    self.cached_logits = None
    self.cached_pass1_logits = None
    self.cached_pass2_logits = None

    prompt_len = int(input_ids.shape[1])
    for pos in range(prompt_len):
        token = input_ids[:, pos:pos + 1]
        h1, states, logits1 = self._step_pass1(token, pos, max_context, use_pass_embed=use_pass_embed)
        logits2 = self._step_pass2(h1, states, pos, max_context) if use_pass2 else None
        # The pass-2 logits drive sampling when pass 2 is active.
        self.cached_logits = logits2 if use_pass2 else logits1
        self.cached_pass1_logits = logits1
        self.cached_pass2_logits = logits2
@torch.no_grad()
def _prepare_cached_context(self, input_ids: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, bool, int, int, str]:
    """Bring the KV caches and cached logits in line with ``input_ids``.

    If the prompt cached by a previous call is a complete prefix of the new
    prompt (same mode and context size, persistent cache enabled), only the
    new suffix is pushed through the single-token path. Otherwise the caches
    are rebuilt from scratch, by vectorized prefill when enabled and by
    token-by-token prefill when that is disabled or fails.

    Returns:
        ``(input_ids, cache_hit, reused_tokens, new_prefill_tokens, backend)``
        where ``input_ids`` may have been clipped to the last
        ``cfg.max_context_tokens`` ids and ``backend`` names the path taken.
    """
    mode = self._effective_mode(cfg.mode)
    max_context = int(cfg.max_context_tokens)
    use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
    use_pass_embed = bool(use_pass2)
    T = int(input_ids.shape[1])
    if T > max_context:
        input_ids = input_ids[:, -max_context:]
        T = max_context
    # If mode/context changed, persistent cache is invalid.
    cache_ok = (
        cfg.persistent_cache
        and self.cache_ids is not None
        and self.cache_mode == mode
        and self.cache_max_context == max_context
        and self.pass1_cache is not None
    )
    prefix = self._ids_prefix_len(self.cache_ids, input_ids) if cache_ok else 0
    # A hit requires the *entire* cached prompt to be a prefix of the new one;
    # a partial overlap is treated as a miss.
    cache_hit = bool(cache_ok and prefix == int(self.cache_ids.shape[1]) and prefix <= T and prefix > 0)
    # NOTE(review): t0 is never read in this method.
    t0 = time.time()
    if cache_hit:
        # Process only suffix between prior cached prompt and new prompt.
        suffix = input_ids[:, prefix:]
        for j in range(int(suffix.shape[1])):
            tok = suffix[:, j : j + 1]
            pos = prefix + j
            h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed)
            if use_pass2:
                logits2 = self._step_pass2(h1, states, pos, max_context)
                self.cached_logits = logits2
                self.cached_pass1_logits = logits1
                self.cached_pass2_logits = logits2
            else:
                self.cached_logits = logits1
                self.cached_pass1_logits = logits1
                self.cached_pass2_logits = None
        self.cache_ids = input_ids.detach().clone()
        # An empty suffix means the prompt is unchanged; the previously cached logits are kept.
        self.cache_backend = "persistent-kv-suffix" if suffix.numel() else "persistent-kv-hit"
        return input_ids, True, prefix, int(suffix.shape[1]), self.cache_backend
    # Reset and prefill. Prefer vectorized prefill, but fall back to conservative
    # token prefill if the runtime variant does not support our vectorized cache path.
    self.pass1_cache = None
    self.pass2_cache = None
    backend = "vectorized-prefill"
    if bool(cfg.vectorized_prefill):
        try:
            h1, states, pos_ids, logits1 = self._prefill_pass1(input_ids, max_context, use_pass_embed=use_pass_embed)
            if use_pass2:
                logits2 = self._prefill_pass2(h1, states, pos_ids, max_context)
                self.cached_logits = logits2
                self.cached_pass1_logits = logits1
                self.cached_pass2_logits = logits2
            else:
                self.cached_logits = logits1
                self.cached_pass1_logits = logits1
                self.cached_pass2_logits = None
        except Exception as exc:
            # Diagnostics are printed only when LULUV2_CACHE_DEBUG is enabled;
            # the fallback below runs either way and rebuilds both caches.
            if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
                print("[cache] vectorized prefill failed; using token-prefill cache.")
                traceback.print_exc()
            self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
            backend = "token-prefill-cache"
    else:
        self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
        backend = "token-prefill-cache"
    # Remember what the caches now represent so the next call can test for a prefix hit.
    self.cache_ids = input_ids.detach().clone()
    self.cache_mode = mode
    self.cache_max_context = max_context
    self.cache_backend = backend
    return input_ids, False, 0, T, self.cache_backend
def _effective_mode(self, mode: str) -> str:
mode = (mode or "vwm").lower()
if mode in {"fast", "base", "pass1"}:
return "fast"
if mode in {"deep", "32k", "long"}:
return "deep"
if mode in {"slow", "full"}:
return "slow"
return "vwm"
@torch.no_grad()
def pass_metrics_from_logits(self, logits1: Optional[torch.Tensor], logits2: Optional[torch.Tensor]) -> Tuple[Optional[float], Optional[float]]:
    """Compare the last-position distributions of pass 1 and pass 2.

    Returns:
        ``(kl, cosine)``: the batch-mean KL divergence with the pass-1
        softmax as target and the pass-2 log-softmax as input, and the mean
        cosine similarity of the raw logits. ``(None, None)`` is returned
        when either input is missing or the computation fails.
    """
    if logits1 is None or logits2 is None:
        return None, None
    try:
        last1 = logits1[:, -1, :].float()
        last2 = logits2[:, -1, :].float()
        log_probs2 = F.log_softmax(last2, dim=-1)
        probs1 = F.softmax(last1, dim=-1)
        divergence = F.kl_div(log_probs2, probs1, reduction="batchmean")
        similarity = F.cosine_similarity(last1, last2, dim=-1).mean()
        return float(divergence.item()), float(similarity.item())
    except Exception:
        # Metrics are diagnostic only; never let them break generation.
        return None, None
def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
if generated.numel() == 0:
return logits
out = logits.clone()
uniq, counts = torch.unique(generated.view(-1), return_counts=True)
if cfg.repetition_penalty != 1.0:
selected = out[:, uniq]
selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty))
out[:, uniq] = selected
if cfg.frequency_penalty:
out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0)
n = int(cfg.no_repeat_ngram)
if n > 1 and generated.size(1) >= n - 1:
seq = generated[0].tolist()
prefix = tuple(seq[-(n - 1):])
banned = []
for i in range(len(seq) - n + 1):
if tuple(seq[i:i + n - 1]) == prefix:
banned.append(seq[i + n - 1])
if banned:
out[:, list(set(banned))] = -float("inf")
return out
@torch.no_grad()
def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Choose the next token id from last-position logits.

    Penalties are applied first. Greedy decoding (``cfg.greedy`` or a
    non-positive temperature) takes the argmax. Otherwise the logits are
    temperature-scaled and filtered by top-k, then top-p, then min-p, and a
    token is drawn from the resulting distribution. If that distribution is
    degenerate (NaN, non-finite or non-positive mass), the argmax of the
    unpenalized input logits is used instead.

    Returns:
        ``(next_id, stats)`` where ``next_id`` has shape ``[batch, 1]`` and
        ``stats`` holds the chosen token's probability (``"prob"``) and the
        entropy of the distribution it was taken from (``"entropy"``).
    """
    work = self._apply_penalties(logits.float(), generated, cfg)
    if cfg.greedy or cfg.temperature <= 0:
        probs = torch.softmax(work, dim=-1)
        next_id = torch.argmax(work, dim=-1, keepdim=True)
    else:
        work = work / max(float(cfg.temperature), 1e-6)
        if cfg.top_k > 0:
            # Keep every logit that ties with or exceeds the k-th largest one.
            k = min(int(cfg.top_k), work.size(-1))
            thresh = torch.topk(work, k, dim=-1).values[..., -1, None]
            work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf")))
        if 0.0 < cfg.top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(sorted_probs, dim=-1)
            remove = cumprobs > float(cfg.top_p)
            # Shift the mask right by one so the token that crosses the
            # threshold is kept and the top token is never removed.
            shifted = remove.clone()
            shifted[..., 1:] = remove[..., :-1]
            shifted[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(shifted, -float("inf"))
            # Scatter the filtered values back to vocabulary order.
            work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits)
        if 0.0 < cfg.min_p < 1.0:
            # Drop tokens whose probability is below min_p times the best token's probability.
            probs_for_minp = torch.softmax(work, dim=-1)
            max_prob = probs_for_minp.max(dim=-1, keepdim=True).values
            keep = probs_for_minp >= float(cfg.min_p) * max_prob
            work = work.masked_fill(~keep, -float("inf"))
        probs = torch.softmax(work, dim=-1)
        if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0:
            # Degenerate distribution (e.g. every candidate was banned): fall
            # back to the argmax of the original, unpenalized logits.
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
            probs = torch.softmax(logits.float(), dim=-1)
        else:
            next_id = torch.multinomial(probs, 1)
    # .item() below requires a single sampled id, i.e. batch size 1.
    prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0
    entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0
    return next_id, {"prob": prob, "entropy": entropy}
@torch.no_grad()
def _slow_generate(self, ids: torch.Tensor, prompt_len: int, cfg: GenerationConfig) -> Generator[str, None, None]:
    """Generate text by re-running the model on the full prefix for every token.

    Compatibility path used when caching is disabled, the mode is ``slow``,
    or the cached path failed. Cleaned text is yielded whenever it changes,
    on the first token and then every ``cfg.stream_every`` tokens; the final
    cleaned text is yielded once more after decoding stops.

    Decoding stops at ``cfg.max_new_tokens``, at the EOS token, or when a
    stop string shows up in the decoded continuation (checked on streaming
    steps only).

    Args:
        ids: Prompt ids, ``[1, prompt_len]``.
        prompt_len: Number of prompt tokens at the start of ``ids``.
        cfg: Generation settings. A ``stream_every`` below 1 is treated as 1;
            previously ``0`` raised ``ZeroDivisionError`` in the modulo.

    Side effects: appends to ``recent_tokens`` through ``_record_token`` and
    refreshes ``last_stats`` each time text is yielded from the loop.
    """
    eos_id = getattr(self.tokenizer, "eos_token_id", None)
    # Loop invariants, resolved once.
    stream_every = max(1, int(cfg.stream_every))
    max_context = int(cfg.max_context_tokens)
    use_base_only = self._effective_mode(cfg.mode) == "fast"
    last_text = ""
    t0 = time.time()
    for step in range(int(cfg.max_new_tokens)):
        # Compatibility path: full prefix recompute every token.
        ctx = ids[:, -max_context:]
        with self.amp_context():
            out = self.base(ctx) if use_base_only else self.model(ctx)
        logits = out.logits[:, -1, :].float()
        generated = ids[:, prompt_len:]
        next_id, tok_stats = self._sample_next(logits, generated, cfg)
        ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
        token_id = int(next_id.item())
        token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
        self._record_token(step + 1, token_id, token_text, tok_stats)
        if eos_id is not None and token_id == int(eos_id):
            break
        if (step + 1) % stream_every == 0 or step == 0:
            raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
            if any(s in raw for s in STOP_STRINGS):
                break
            text = clean_text(raw)
            if text and text != last_text:
                elapsed = time.time() - t0
                gen = int(ids.shape[1]) - prompt_len
                self.last_stats = GenerationStats(
                    prompt_tokens=prompt_len,
                    prompt_total_tokens=self.last_prompt_total_tokens,
                    prompt_kept_tokens=self.last_prompt_kept_tokens,
                    prompt_dropped_tokens=self.last_prompt_dropped_tokens,
                    generated_tokens=gen,
                    elapsed_sec=elapsed,
                    tokens_per_sec=gen / max(elapsed, 1e-9),
                    mode=cfg.mode,
                    backend="slow-full-prefix",
                    last_token=token_text,
                    last_token_id=token_id,
                    last_token_prob=tok_stats["prob"],
                    last_entropy=tok_stats["entropy"],
                    finish_reason="streaming",
                )
                last_text = text
                yield text
    # Always emit the final cleaned text so consumers see the complete answer
    # even when the last tokens fell between streaming steps.
    final = clean_text(self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True))
    if final:
        yield final
def _record_token(self, i: int, token_id: int, token_text: str, tok_stats: Dict[str, float]) -> None:
self.recent_tokens.append({"i": i, "id": token_id, "text": token_text, "prob": tok_stats.get("prob", 0.0), "entropy": tok_stats.get("entropy", 0.0)})
self.recent_tokens = self.recent_tokens[-64:]
@torch.no_grad()
def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]:
    """Stream generated text for ``prompt``, using the cached decode path when possible.

    The prompt is encoded and prefilled into the session caches, then tokens
    are sampled one at a time from the cached logits. The request is handed
    to ``_slow_generate`` when caching is disabled, when the effective mode
    is ``"slow"``, or when prefill or a decode step fails.

    Args:
        prompt: Text to encode and continue.
        cfg: Generation settings. In ``"deep"`` mode ``max_context_tokens``
            is raised to at least 16384 on this object.

    Yields:
        The cleaned text generated so far, each time it changed at a
        streaming checkpoint, and once more at the end if it is non-empty.

    Side effects:
        Resets ``recent_tokens``, updates ``last_stats`` and the cached
        logits / ``cache_ids`` attributes, and prints diagnostics when the
        prompt is clipped or a fallback occurs.
    """
    self.model.eval()
    self.base.eval()
    self.recent_tokens = []
    mode = self._effective_mode(cfg.mode)
    if mode == "deep":
        # NOTE(review): this writes the raised limit back into the caller's
        # cfg; the helpers below read cfg themselves and must see the same value.
        cfg.max_context_tokens = max(int(cfg.max_context_tokens), 16384)
    max_ctx = int(cfg.max_context_tokens)
    ids = self.encode(prompt, max_context_tokens=max_ctx)
    prompt_len = int(ids.shape[1])
    if self.last_prompt_dropped_tokens > 0:
        print(f"[context] prompt clipped: kept={self.last_prompt_kept_tokens} total={self.last_prompt_total_tokens} dropped={self.last_prompt_dropped_tokens}")
    t_start = time.time()

    if (not cfg.use_cache) or mode == "slow":
        yield from self._slow_generate(ids, prompt_len, cfg)
        return

    try:
        with self.amp_context():
            t_pref = time.time()
            ids, cache_hit, reused, new_prefill, backend = self._prepare_cached_context(ids, cfg)
            prefill_sec = time.time() - t_pref
            if cfg.return_pass_metrics:
                pass_kl, pass_cos = self.pass_metrics_from_logits(self.cached_pass1_logits, self.cached_pass2_logits)
            else:
                pass_kl, pass_cos = None, None
    except Exception as exc:
        print(f"[cache] cached path failed; falling back to slow full-prefix: {type(exc).__name__}: {exc}")
        if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
            traceback.print_exc()
        # Drop the session caches so a later request does not reuse them.
        self.pass1_cache = None
        self.pass2_cache = None
        self.cache_ids = None
        yield from self._slow_generate(ids, prompt_len, cfg)
        return

    def build_stats(finish: str, tok_text: str, tok_id: int, tok_prob: float, tok_entropy: float) -> GenerationStats:
        # One place builds the stats record for both streaming and final updates.
        elapsed = time.time() - t_start
        gen = int(ids.shape[1]) - prompt_len
        return GenerationStats(
            prompt_tokens=prompt_len,
            prompt_total_tokens=self.last_prompt_total_tokens,
            prompt_kept_tokens=self.last_prompt_kept_tokens,
            prompt_dropped_tokens=self.last_prompt_dropped_tokens,
            generated_tokens=gen,
            elapsed_sec=elapsed,
            # Decode speed excludes the time spent on prefill.
            tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9),
            prefill_sec=prefill_sec,
            prefill_tps=(new_prefill / max(prefill_sec, 1e-9)),
            cache_hit=cache_hit,
            cache_reused_tokens=reused,
            cache_new_prefill_tokens=new_prefill,
            mode=mode,
            backend=backend,
            last_token=tok_text,
            last_token_id=tok_id,
            last_token_prob=tok_prob,
            last_entropy=tok_entropy,
            finish_reason=finish,
            pass1_pass2_kl=pass_kl,
            pass1_pass2_logit_cosine=pass_cos,
        )

    eos_id = getattr(self.tokenizer, "eos_token_id", None)
    last_text = ""
    finish_reason = "length"
    use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
    use_pass_embed = bool(use_pass2)
    max_new = int(cfg.max_new_tokens)
    # A non-positive interval would make the modulo below raise.
    stream_every = max(1, int(cfg.stream_every))

    for step in range(max_new):
        cached = self.cached_logits
        if cached is None:
            # Without logits the cached path cannot sample; fall back rather than crash.
            print("[decode-cache] no cached logits available; falling back for this request")
            self.cache_ids = None
            yield from self._slow_generate(ids, prompt_len, cfg)
            return
        # Cached logits may carry a sequence axis; only the last position is sampled.
        logits = cached[:, -1, :].float() if cached.dim() == 3 else cached.float()
        generated = ids[:, prompt_len:]
        next_id, tok_stats = self._sample_next(logits, generated, cfg)
        token_id = int(next_id.item())
        token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
        self._record_token(step + 1, token_id, token_text, tok_stats)
        ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
        if eos_id is not None and token_id == int(eos_id):
            finish_reason = "eos"
            break
        pos = int(ids.shape[1]) - 1
        try:
            with self.amp_context():
                h1, states, logits1 = self._step_pass1(next_id.to(self.device), pos, max_ctx, use_pass_embed=use_pass_embed)
                if use_pass2:
                    logits2 = self._step_pass2(h1, states, pos, max_ctx)
                    self.cached_logits = logits2
                    self.cached_pass1_logits = logits1
                    self.cached_pass2_logits = logits2
                else:
                    self.cached_logits = logits1
                    self.cached_pass1_logits = logits1
                    self.cached_pass2_logits = None
                if self.cache_ids is not None:
                    self.cache_ids = torch.cat([self.cache_ids, next_id.detach().to(self.cache_ids.device)], dim=-1)
                    if self.cache_ids.shape[1] > max_ctx:
                        self.cache_ids = self.cache_ids[:, -max_ctx:]
        except Exception as exc:
            print(f"[decode-cache] step failed; falling back for this request: {type(exc).__name__}: {exc}")
            # Finish with slow path from current ids; do not pretend cache is valid.
            self.cache_ids = None
            yield from self._slow_generate(ids, prompt_len, cfg)
            return
        # Always stream after the first token, then on the configured interval.
        if (step + 1) % stream_every == 0 or step == 0:
            raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
            if any(s in raw for s in STOP_STRINGS):
                finish_reason = "stop_string"
                break
            text = clean_text(raw)
            if text and text != last_text:
                self.last_stats = build_stats("streaming", token_text, token_id, tok_stats["prob"], tok_stats["entropy"])
                last_text = text
                yield text

    raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
    final = clean_text(raw)
    last = self.recent_tokens[-1] if self.recent_tokens else None
    self.last_stats = build_stats(
        finish_reason,
        last["text"] if last is not None else "",
        last["id"] if last is not None else -1,
        last["prob"] if last is not None else 0.0,
        last["entropy"] if last is not None else 0.0,
    )
    if final:
        yield final
def clear_session_cache(self) -> None:
    """Forget all session cache state and mark the cache backend as cleared."""
    cache_attrs = (
        "pass1_cache",
        "pass2_cache",
        "cache_ids",
        "cached_logits",
        "cached_pass1_logits",
        "cached_pass2_logits",
    )
    for attr in cache_attrs:
        setattr(self, attr, None)
    self.cache_backend = "cleared"
def stats_dict(self) -> Dict[str, Any]:
    """Return the last generation stats, model info and a system snapshot as one dict."""
    generation = asdict(self.last_stats)
    model = self.model_info
    system = system_snapshot(self)
    return {"generation": generation, "model": model, "system": system}
def stats_text(self) -> str:
    """Render the last generation stats and model info as a multi-line report.

    The pass1/pass2 KL and cosine lines are included only when those values
    are not ``None``. A blank line separates the generation section from the
    model section.
    """
    s = self.last_stats

    generation_section = [
        f"Mode: {s.mode} | backend={s.backend}",
        f"Prompt tokens: {s.prompt_tokens} kept / {getattr(s, 'prompt_total_tokens', s.prompt_tokens)} total / {getattr(s, 'prompt_dropped_tokens', 0)} dropped",
        f"Generated tokens: {s.generated_tokens}",
        f"Elapsed: {s.elapsed_sec:.2f}s | prefill={s.prefill_sec:.2f}s ({s.prefill_tps:.1f} tok/s)",
        f"Decode speed: {s.tokens_per_sec:.2f} tok/s",
        f"Cache: hit={s.cache_hit} reused={s.cache_reused_tokens} new_prefill={s.cache_new_prefill_tokens}",
        f"Finish reason: {s.finish_reason}",
        f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f} H={s.last_entropy:.2f}",
    ]

    pass_section = []
    if s.pass1_pass2_kl is not None:
        pass_section.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}")
    if s.pass1_pass2_logit_cosine is not None:
        pass_section.append(f"Pass1/Pass2 cosine: {s.pass1_pass2_logit_cosine:.6f}")

    model_section = [
        "",
        f"Checkpoint: {self.model_info['checkpoint']}",
        f"Checkpoint size: {self.model_info['checkpoint_size']}",
        f"Device: {self.model_info['device']} dtype={self.model_info['dtype']}",
        f"Pass2 active: {self.model_info['has_pass2']}",
        f"Params: {self.model_info['total_params']:,}",
        f"VWM c modules: {self.model_info['vwm_c_modules']} ({self.model_info['vwm_c_params']:,} c params)",
    ]

    return "\n".join(generation_section + pass_section + model_section)
def token_trace_text(self) -> str:
    """Format up to the 48 most recent token records, one per line.

    Returns a fixed placeholder message when no tokens have been recorded.
    """
    if not self.recent_tokens:
        return "No tokens generated yet."

    def render(t: Dict[str, Any]) -> str:
        # repr() escapes control characters; the slice drops the surrounding quotes.
        safe = repr(t["text"])[1:-1]
        return f"{t['i']:04d} id={t['id']:<7} p={t['prob']:.4f} H={t['entropy']:.2f} {safe}"

    return "\n".join(map(render, self.recent_tokens[-48:]))
def system_snapshot(engine: Optional[LULUV2OptimizedEngine] = None) -> Dict[str, Any]:
    """Collect a best-effort snapshot of process, host and GPU resource usage.

    Every probe is optional: a missing library (``psutil``, ``pynvml``), no
    CUDA device, or any probe error leaves the corresponding placeholder
    values in place. ``engine`` is accepted but not read here.

    Returns:
        A dict with RAM/CPU fields, torch-reported VRAM fields and, when NVML
        is usable, GPU utilisation, temperature and NVML-reported VRAM usage
        (which replace the torch-derived ``vram_used``/``vram_total``/
        ``vram_percent``).
    """
    snap: Dict[str, Any] = {
        "python_ram": "n/a",
        "system_ram": "n/a",
        "system_ram_percent": 0.0,
        "cpu_percent": 0.0,
        "gpu_name": "CUDA unavailable",
        "vram_allocated": "n/a",
        "vram_reserved": "n/a",
        "vram_used": "n/a",
        "vram_total": "n/a",
        "vram_percent": 0.0,
        "gpu_util_percent": None,
        "gpu_temp_c": None,
    }

    if psutil is not None:
        try:
            rss = psutil.Process(os.getpid()).memory_info().rss
            vm = psutil.virtual_memory()
            host_fields = {
                "python_ram": human_bytes(rss),
                "system_ram": f"{human_bytes(vm.used)} / {human_bytes(vm.total)}",
                "system_ram_percent": float(vm.percent),
                "cpu_percent": float(psutil.cpu_percent(interval=0.0)),
            }
        except Exception:
            pass
        else:
            snap.update(host_fields)

    if not torch.cuda.is_available():
        return snap

    try:
        idx = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(idx)
        allocated = int(torch.cuda.memory_allocated(idx))
        reserved = int(torch.cuda.memory_reserved(idx))
        total = int(props.total_memory)
        torch_fields = {
            "gpu_name": props.name,
            "vram_allocated": human_bytes(allocated),
            "vram_reserved": human_bytes(reserved),
            "vram_used": human_bytes(allocated),
            "vram_total": human_bytes(total),
            "vram_percent": 100.0 * allocated / max(total, 1),
        }
        snap.update(torch_fields)
        if pynvml is not None:
            try:
                pynvml.nvmlInit()
                handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
                nvml_fields = {
                    "gpu_util_percent": int(util.gpu),
                    "vram_used": human_bytes(int(mem.used)),
                    "vram_total": human_bytes(int(mem.total)),
                    "vram_percent": 100.0 * float(mem.used) / max(float(mem.total), 1.0),
                    "gpu_temp_c": int(temp),
                }
                snap.update(nvml_fields)
            except Exception:
                # NVML is optional; keep the torch-derived numbers on failure.
                pass
    except Exception:
        pass
    return snap
def system_usage(engine: Optional[LULUV2OptimizedEngine] = None) -> str:
    """Render a human-readable resource report, optionally followed by engine stats.

    GPU utilisation and temperature lines appear only when the snapshot holds
    a value for them. When ``engine`` is given, a blank line and
    ``engine.stats_text()`` are appended.
    """
    snap = system_snapshot(engine)

    host_section = [
        f"OS: {platform.system()} {platform.release()}",
        f"Python RAM: {snap['python_ram']}",
        f"System RAM: {snap['system_ram']} ({snap['system_ram_percent']:.1f}%)",
        f"CPU: {snap['cpu_percent']:.1f}%",
    ]
    gpu_section = [
        f"GPU: {snap['gpu_name']}",
        f"VRAM used: {snap['vram_used']} / {snap['vram_total']} ({snap['vram_percent']:.1f}%)",
        f"VRAM allocated: {snap['vram_allocated']}",
        f"VRAM reserved: {snap['vram_reserved']}",
    ]
    lines = host_section + [""] + gpu_section

    if snap.get("gpu_util_percent") is not None:
        lines.append(f"GPU util: {snap['gpu_util_percent']}%")
    if snap.get("gpu_temp_c") is not None:
        lines.append(f"GPU temp: {snap['gpu_temp_c']} C")
    if engine is not None:
        lines += ["", engine.stats_text()]
    return "\n".join(lines)