#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
import os
# ----------------------------------------------------------------------------
# mamba-ssm dependency
# ----------------------------------------------------------------------------
try:
from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams
_HAS_MAMBA = True
except ImportError:
_HAS_MAMBA = False
InferenceParams = None
print("=" * 80)
print("[WARNING] mamba-ssm not installed. Mamba layers will not function.")
print("Install with: pip install mamba-ssm")
print("=" * 80)
class Mamba(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
print("ERROR: Mamba placeholder. mamba-ssm not installed.")
def forward(self, x, *args, **kwargs):
print("ERROR: mamba-ssm not installed. Cannot run MambaBlock.")
return x
# ----------------------------------------------------------------------------
# Model
# ----------------------------------------------------------------------------
@dataclass
class AdaptiveRiverConfig:
vocab_size: int = 50257
d_model: int = 1024
n_layers: int = 24
d_ff: int = 4096
dropout: float = 0.0
rope_theta: float = 10000.0
rotary_pct: float = 1.0
layer_norm_eps: float = 1e-5
rope_scaling_type: str | None = None
rope_scaling_factor: float = 1.0
experts_per_layer: int = 4
top_k_ffn: int = 1
moe_dropout: float = 0.0
attn_n_experts: int = 6
attn_top_k: int = 6
attn_n_orig_heads: int = 16
mamba_d_state: int = 16
mamba_d_conv: int = 4
mamba_expand: int = 2
entropy_weight: float = 1e-4
head_entropy_weight: float = 1e-4
default_budget_ratio: float = 1.0
init_std: float = 0.02
tie_word_embeddings: bool = False # untied head (matches training)
load_balance_weight: float = 0.01
router_z_weight: float = 0.001
gate_temperature: float = 0.7
checkpoint_attn_thresh: float = 0.35
checkpoint_ffn_thresh: float = 0.35
soak_dtype: str = "fp32"
def _init_weights(module: nn.Module, std: float):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
def topk_mask_ste(scores: torch.Tensor, k: int) -> torch.Tensor:
s = scores.float()
if k >= s.size(-1):
return torch.ones_like(s)
topk = torch.topk(s, k=k, dim=-1).indices
one_hot = torch.zeros_like(s)
one_hot.scatter_(dim=-1, index=topk, value=1.0)
probs = F.softmax(s, dim=-1)
return one_hot + probs - probs.detach()
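# Illustrative sketch (not called anywhere; `_demo_topk_mask_ste` is ours):
# topk_mask_ste is a straight-through estimator -- the forward value is the
# hard 0/1 top-k mask, while gradients flow through the soft softmax term.
def _demo_topk_mask_ste():
    scores = torch.randn(2, 4, requires_grad=True)
    mask = topk_mask_ste(scores, k=1)
    # Forward: exactly one 1.0 per row (the hard top-1 selection).
    assert torch.allclose(mask.detach().sum(dim=-1), torch.ones(2))
    # Backward: gradients still reach `scores` despite the hard selection.
    (mask * torch.randn(2, 4)).sum().backward()
    assert scores.grad is not None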
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000.0, scaling_type: str | None = None, scaling_factor: float = 1.0):
super().__init__()
self.dim = dim
self.base = float(base)
self.scaling_type = scaling_type
self.scaling_factor = float(scaling_factor)
base = self._effective_base()
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._cos_sin_cache = None
self._cos_sin_cache_device = None
self._cos_sin_cache_dtype = None
self._cos_sin_max_seq_len = -1
def _effective_base(self) -> float:
if not self.scaling_type or self.scaling_factor == 1.0:
return self.base
if self.scaling_type in ("ntk", "linear", "yarn"):
return self.base * self.scaling_factor
return self.base
def _get_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
if (seq_len > self._cos_sin_max_seq_len or self._cos_sin_cache is None
or self._cos_sin_cache_device != device or self._cos_sin_cache_dtype != dtype):
self._cos_sin_max_seq_len = max(seq_len, 2048)
t = torch.arange(self._cos_sin_max_seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
self._cos_sin_cache = (cos, sin)
self._cos_sin_cache_device = device
self._cos_sin_cache_dtype = dtype
return self._cos_sin_cache
    def forward(self, x, seq_len: int, offset: int | torch.Tensor = 0):
        device, dtype = x.device, x.dtype
        if isinstance(offset, torch.Tensor):
            if offset.numel() > 1:
                # Per-sample offsets: fall back to recomputing positions 0..seq_len-1.
                t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos_val = emb.cos()[None, None, :, :].to(dtype)
                sin_val = emb.sin()[None, None, :, :].to(dtype)
                return cos_val, sin_val
            offset = int(offset.item())
        # offset is a plain int from here on, so the cache lookup and slice are safe.
        cos, sin = self._get_cos_sin_cache(seq_len + offset, device, dtype)
        cos = cos[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = sin[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        return cos, sin
def apply_rotary(x, cos, sin):
x1, x2 = x[..., ::2], x[..., 1::2]
x_rot = torch.stack((-x2, x1), dim=-1).flatten(-2)
return x * cos + x_rot * sin
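# Shape-level sketch (not called anywhere; `_demo_rotary_shapes` is ours):
# cos/sin broadcast as (1, 1, T, rotary_dim) over (B, experts, T, head_dim),
# and apply_rotary is shape-preserving. The cos/sin layout used here is the
# concatenated-halves one produced by RotaryEmbedding above, kept as-is.
def _demo_rotary_shapes():
    rope = RotaryEmbedding(dim=16)
    q = torch.randn(2, 4, 10, 16)  # (B, E, T, head_dim)
    cos, sin = rope(q, seq_len=10, offset=0)
    assert cos.shape == (1, 1, 10, 16)
    q_rot = apply_rotary(q, cos, sin)
    assert q_rot.shape == q.shape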
class PTLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
def forward(self, x):
return self.ln(x)
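# Note: GlobalSDPAHead below appears unused by AdaptiveRiverLM in this file;
# it is kept as-is, presumably for reference or earlier experiments.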
class GlobalSDPAHead(nn.Module):
def __init__(self, d_model, head_dim, dropout, rope_theta, rotary_pct, cfg):
super().__init__()
self.q_proj = nn.Linear(d_model, head_dim, bias=False)
self.k_proj = nn.Linear(d_model, head_dim, bias=False)
self.v_proj = nn.Linear(d_model, head_dim, bias=False)
self.rotary_dim = int(head_dim * rotary_pct)
self.dropout_p = dropout
self.rope = None
if self.rotary_dim > 0:
self.rope = RotaryEmbedding(
self.rotary_dim, base=rope_theta,
scaling_type=cfg.rope_scaling_type,
scaling_factor=cfg.rope_scaling_factor,
)
def forward(self, x, position_offset):
if isinstance(position_offset, torch.Tensor):
position_offset = int(position_offset.view(-1)[0].item())
else:
position_offset = int(position_offset)
B, T, C = x.shape
q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
if self.rotary_dim > 0:
cos, sin = self.rope(q, seq_len=T, offset=position_offset)
cos = cos.squeeze(1); sin = sin.squeeze(1)
q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
q, k, v = [t.unsqueeze(1) for t in (q, k, v)]
dropout_p = self.dropout_p if self.training else 0.0
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
return out.squeeze(1)
class AttentionMoERouter(nn.Module):
def __init__(self, d_model, num_experts, top_k):
super().__init__()
self.top_k = top_k
self.num_experts = num_experts
self.gate_proj = nn.Linear(d_model, num_experts, bias=False)
nn.init.normal_(self.gate_proj.weight, mean=0.0, std=0.01)
def forward(self, x, budget_ratio, temperature):
seq_embed = x.mean(dim=1)
logits = self.gate_proj(seq_embed) / max(1e-6, float(temperature))
logits = logits.clamp(min=-10.0, max=10.0)
k_target = max(1, int(round(self.top_k * (0.25 + 0.75 * budget_ratio))))
k_target = min(k_target, logits.size(-1))
vals, idx = torch.topk(logits, k_target, dim=-1)
weights = F.softmax(vals.to(torch.float32), dim=-1).to(x.dtype)
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, idx, True)
with torch.no_grad():
p = F.softmax(logits, dim=-1)
entropy = -(p * (p.clamp_min(1e-12)).log()).sum(dim=-1).mean()
return mask, weights, idx, entropy, logits
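# Worked example of the budget->k mapping above, assuming the default
# attn_top_k=6: budget 1.0 -> k = round(6*1.00) = 6; budget 0.5 ->
# k = round(6*0.625) = 4; budget 0.0 -> k = round(6*0.25) = 2. Even a zero
# budget therefore keeps at least two attention experts active.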
class MoEAttention(nn.Module):
def __init__(self, cfg: AdaptiveRiverConfig):
super().__init__()
self.d_model = cfg.d_model
self.n_experts = cfg.attn_n_experts
self.cfg = cfg
self.head_dim = cfg.d_model // cfg.attn_n_orig_heads
self.rotary_dim = int(self.head_dim * cfg.rotary_pct)
self.router = AttentionMoERouter(cfg.d_model, cfg.attn_n_experts, cfg.attn_top_k)
self.q_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
self.k_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
self.v_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
self.rope = None
if self.rotary_dim > 0:
self.rope = RotaryEmbedding(
self.rotary_dim, base=cfg.rope_theta,
scaling_type=cfg.rope_scaling_type,
scaling_factor=cfg.rope_scaling_factor,
)
self.o_proj = nn.Linear(cfg.attn_n_experts * self.head_dim, cfg.d_model, bias=False)
def forward(self, x, position_offset, budget_ratio, temperature):
B, T, C = x.shape
E, H = self.n_experts, self.head_dim
sel_mask, gate_w, gate_idx, entropy, gate_logits = self.router(x, budget_ratio, temperature)
q = self.q_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
k = self.k_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
v = self.v_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
if self.rope:
if isinstance(position_offset, torch.Tensor):
position_offset = int(position_offset.view(-1)[0].item())
else:
position_offset = int(position_offset)
cos, sin = self.rope(q, seq_len=T, offset=position_offset)
cos = cos.squeeze(1); sin = sin.squeeze(1)
q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
q_b = q.reshape(B * E, T, H)
k_b = k.reshape(B * E, T, H)
v_b = v.reshape(B * E, T, H)
dropout_p = self.cfg.dropout if self.training else 0.0
out_b = F.scaled_dot_product_attention(q_b, k_b, v_b, is_causal=True, dropout_p=dropout_p)
out = out_b.view(B, E, T, H).permute(0, 2, 1, 3)
W = torch.zeros(B, E, device=x.device, dtype=out.dtype)
W.scatter_(1, gate_idx, gate_w.to(out.dtype))
weighted_out = torch.einsum('b t e h, b e -> b t e h', out, W)
y = weighted_out.reshape(B, T, E * H).to(self.o_proj.weight.dtype)
y = self.o_proj(y)
with torch.no_grad():
usage = sel_mask.float().mean(dim=0)
expected = sel_mask.float().sum(dim=-1).mean()
den = torch.clamp(expected, min=1e-6)
usage_norm = usage / den
uniform = 1.0 / self.n_experts
            attn_lb = ((usage_norm - uniform) ** 2).sum()
attn_rz = (gate_logits ** 2).mean()
head_keep = sel_mask.float().mean()
return y, {
"head_entropy": entropy,
"head_keep_frac": head_keep,
"attn_load_balance_loss": attn_lb,
"attn_router_z_loss": attn_rz,
}
class ExpertFFN(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.dropout_p = dropout
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.w1(x)
x = F.gelu(x, approximate="tanh")
x = F.dropout(x, p=self.dropout_p, training=self.training)
x = self.w2(x)
return x
class MoEFFN(nn.Module):
def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int, dropout: float, cfg: AdaptiveRiverConfig):
super().__init__()
self.n_experts = n_experts
self.base_top_k = top_k
self.cfg = cfg
self.router = nn.Linear(d_model, n_experts, bias=False)
self.w1_stacked = nn.Parameter(torch.empty(n_experts, d_ff, d_model))
self.w2_stacked = nn.Parameter(torch.empty(n_experts, d_model, d_ff))
std = cfg.init_std
nn.init.normal_(self.router.weight, mean=0.0, std=std)
nn.init.normal_(self.w1_stacked, mean=0.0, std=std)
nn.init.normal_(self.w2_stacked, mean=0.0, std=std)
def forward(self, x: torch.Tensor, budget_ratio: float):
B, T, C = x.shape
N = B * T
X = x.reshape(N, C)
k_target = max(1, int(round(self.base_top_k * (0.5 + budget_ratio / 2.0))))
k_target = min(k_target, self.n_experts)
scores = self.router(X).to(torch.float32).clamp(min=-10.0, max=10.0)
probs = F.softmax(scores, dim=-1).to(X.dtype)
mask = topk_mask_ste(scores, k=k_target).to(X.dtype)
gate = (mask * probs)
gate = gate / gate.sum(dim=-1, keepdim=True).clamp_min(1e-6)
x_ff = torch.einsum('n c, e d c -> n e d', X, self.w1_stacked)
x_act = F.gelu(x_ff, approximate="tanh")
y_experts = torch.einsum('n e d, e c d -> n e c', x_act, self.w2_stacked)
y = torch.einsum('n e, n e c -> n c', gate, y_experts).view(B, T, C).to(x.dtype)
with torch.no_grad():
entropy = (-probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
router_z = (scores ** 2).mean().clamp(max=10.0)
frac = mask.mean(dim=0)
uniform = 1.0 / self.n_experts
            lb = ((frac - uniform) ** 2).sum()
return y, {
"router_entropy": entropy,
"ffn_expert_usage": frac.detach(),
"ffn_load_balance_loss": lb,
"ffn_router_z_loss": router_z,
}
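# Sketch of the stacked-expert einsum used in MoEFFN above (not called
# anywhere; `_demo_moe_ffn_einsum` and its sizes are illustrative): one
# batched einsum over the (E, d_ff, d_model) weight stack equals a Python
# loop over per-expert matmuls.
def _demo_moe_ffn_einsum():
    N, C, E, D = 3, 8, 4, 16  # tokens, d_model, experts, d_ff
    X = torch.randn(N, C)
    w1 = torch.randn(E, D, C)
    fused = torch.einsum('n c, e d c -> n e d', X, w1)
    looped = torch.stack([X @ w1[e].T for e in range(E)], dim=1)
    assert torch.allclose(fused, looped, atol=1e-5)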
class MambaBlock(nn.Module):
def __init__(self, cfg: AdaptiveRiverConfig, enhanced: bool = False, layer_idx: int | None = None):
super().__init__()
if not _HAS_MAMBA:
print(f"MambaBlock Layer {layer_idx} disabled: mamba-ssm not installed.")
self.mamba = None
return
self.cfg = cfg
self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
self.mamba = Mamba(
d_model=cfg.d_model,
d_state=cfg.mamba_d_state,
d_conv=cfg.mamba_d_conv,
expand=cfg.mamba_expand * (2 if enhanced else 1),
layer_idx=layer_idx,
)
self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
self.ffn = nn.Sequential(
nn.Linear(cfg.d_model, cfg.d_ff * (2 if enhanced else 1), bias=False),
nn.GELU(approximate="tanh"),
nn.Linear(cfg.d_ff * (2 if enhanced else 1), cfg.d_model, bias=False),
)
def forward(
self,
x,
attn_mask=None,
position_offset: int | torch.Tensor = 0,
past_kv=None,
budget_ratio: float = 1.0,
use_cache: bool = False,
mamba_state: Optional[InferenceParams] = None,
):
if not _HAS_MAMBA or self.mamba is None:
stats = {"head_entropy": torch.tensor(0.0, device=x.device),
"head_keep_frac": torch.tensor(1.0, device=x.device),
"mamba_out_l2": torch.tensor(0.0, device=x.device)}
return x, stats, (None, None)
h = self.ln1(x)
x_m = self.mamba(h) # stateless path
m_out_l2 = x_m.float().pow(2).mean()
x = x + x_m
h2 = self.ln2(x)
x = x + self.ffn(h2)
stats = {
"head_entropy": torch.tensor(0.0, device=x.device),
"head_keep_frac": torch.tensor(1.0, device=x.device),
"mamba_out_l2": m_out_l2.detach(),
}
return x, stats, (None, None)
class RoutedBlock(nn.Module):
def __init__(self, cfg: AdaptiveRiverConfig):
super().__init__()
self.cfg = cfg
self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
self.attn = MoEAttention(cfg)
self.ffn = MoEFFN(cfg.d_model, cfg.d_ff, cfg.experts_per_layer, cfg.top_k_ffn, cfg.moe_dropout, cfg)
def _attn_forward(self, h: torch.Tensor, position_offset: int, budget_ratio: float):
if isinstance(position_offset, torch.Tensor):
position_offset = int(position_offset.view(-1)[0].item())
else:
position_offset = int(position_offset)
return self.attn(h, position_offset, budget_ratio, self.cfg.gate_temperature)
def forward(
self,
x,
attn_mask=None,
position_offset: int | torch.Tensor = 0,
past_kv=None,
budget_ratio: float = 1.0,
use_cache: bool = False,
mamba_state: Optional[InferenceParams] = None,
):
h = self.ln1(x)
attn_out, attn_stats = self._attn_forward(h, position_offset, budget_ratio)
x = x + attn_out
h2 = self.ln2(x)
ffn_out, moe_stats = self.ffn(h2, budget_ratio=budget_ratio)
x = x + ffn_out
stats = {**attn_stats, **moe_stats}
return x, stats, (None, None)
class AdaptiveRiverLM(nn.Module):
def __init__(self, cfg: AdaptiveRiverConfig):
super().__init__()
self.cfg = cfg
self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.blocks = nn.ModuleList()
mamba_layer_counter = 0
for i in range(cfg.n_layers):
if i < 2:
print(f"[model] Layer {i}: Mamba")
self.blocks.append(MambaBlock(cfg, enhanced=False, layer_idx=mamba_layer_counter)); mamba_layer_counter += 1
elif i >= (cfg.n_layers - 2):
print(f"[model] Layer {i}: Mamba (enhanced)")
self.blocks.append(MambaBlock(cfg, enhanced=True, layer_idx=mamba_layer_counter)); mamba_layer_counter += 1
else:
if i == 2:
print(f"[model] Layers {i}-{cfg.n_layers-3}: MoE Attention + MoE FFN")
self.blocks.append(RoutedBlock(cfg))
self.ln_f = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
if cfg.tie_word_embeddings:
self.lm_head.weight = self.embed.weight
        self.apply(lambda m: _init_weights(m, cfg.init_std))
def forward(
self,
input_ids: torch.Tensor,
budget_ratio: Optional[float] = None,
mamba_states: Optional[List] = None,
past_kvs: Optional[List] = None,
position_offset: int | torch.Tensor = 0,
return_expert_stats: bool = False,
use_cache: bool = False,
):
x = self.embed(input_ids)
b = float(self.cfg.default_budget_ratio if budget_ratio is None else budget_ratio)
all_stats: Dict[str, List[torch.Tensor]] = {}
for block in self.blocks:
x, stats, _ = block(
x,
position_offset=position_offset,
past_kv=None,
budget_ratio=b,
use_cache=False,
mamba_state=None,
)
for k, v in stats.items():
all_stats.setdefault(k, []).append(torch.as_tensor(v.detach() if isinstance(v, torch.Tensor) else v))
        agg_stats = {k: torch.stack(v).mean() for k, v in all_stats.items() if len(v) > 0}
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits, agg_stats
def estimate_1b_config() -> AdaptiveRiverConfig:
return AdaptiveRiverConfig(
vocab_size=50257,
d_model=1024,
n_layers=24,
d_ff=4096,
experts_per_layer=4,
top_k_ffn=1,
default_budget_ratio=1.0,
attn_n_experts=6,
attn_top_k=6,
attn_n_orig_heads=16,
mamba_d_state=16,
mamba_d_conv=4,
mamba_expand=2,
gate_temperature=0.7,
head_entropy_weight=1e-4,
checkpoint_attn_thresh=0.35,
checkpoint_ffn_thresh=0.35,
load_balance_weight=0.01,
router_z_weight=0.001,
tie_word_embeddings=False,
)
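# Small helper sketch for sanity-checking model size (ours, not part of the
# original API); the exact total depends on mamba-ssm internals and the
# tokenizer's vocab size.
def _count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())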
# ----------------------------------------------------------------------------
# Inference (stateless) with proper end-of-turn handling
# ----------------------------------------------------------------------------
class FastInferenceTester:
def __init__(self, model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.im_start_id = im_start_id
self.im_end_id = im_end_id
self.eos_id = eos_id
self.pad_id = pad_id
self.model.eval()
torch.set_grad_enabled(False)
print("Using model's native precision")
        if not hasattr(torch, "compile"):
            print("torch.compile unavailable; running without compilation.")
        elif _HAS_MAMBA:
            print("Skipping torch.compile due to mamba-ssm kernels.")
        else:
            try:
                print("Compiling model with torch.compile...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled successfully")
            except Exception as e:
                print(f"Could not compile model: {e}")
                print("Running without compilation")
def _format_to_training_chat(self, prompt: str) -> torch.Tensor:
messages = [{"role": "user", "content": prompt}]
formatted = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
input_ids = self.tokenizer.encode(
formatted, add_special_tokens=False, return_tensors="pt"
).to(self.device)
return input_ids
def _postprocess_like_training(self, text: str) -> str:
if "<|im_start|>assistant" in text:
return text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
if "assistant\n" in text:
return text.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
return text.split("<|im_end|>")[0].strip()
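    # Sketch of the transcript layout the splits above assume (ChatML-style):
    #   <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi!<|im_end|>
    # Everything between the last assistant marker and the following
    # <|im_end|> is returned as the reply.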
def _reset_mamba_states(self):
if not _HAS_MAMBA:
return
for block in self.model.blocks:
if isinstance(block, MambaBlock) and hasattr(block, "mamba"):
for attr in ("inference_params", "conv_state", "ssm_state"):
if hasattr(block.mamba, attr):
setattr(block.mamba, attr, None)
def generate_once(
self,
prompt: str,
max_tokens: int = 2000,
temperature: float = 0.8,
top_p: float = 1.0,
top_k: int = 0,
budget_ratio: float = 1.0,
show_tokens: bool = False,
min_new_tokens: int = 3,
) -> Dict:
self._reset_mamba_states()
print(f"\n{'='*80}")
print("FAST GENERATION (no cache)")
print(f"{'='*80}")
print(f"Prompt: {prompt}")
print("─" * 80)
input_ids = self._format_to_training_chat(prompt)
generated_tokens: List[int] = []
token_times: List[float] = []
stop_ids = set(t for t in [self.im_end_id, self.eos_id] if t is not None)
ban_initial_ids = set(t for t in [self.im_end_id, self.eos_id, self.im_start_id, self.pad_id] if t is not None)
start_time = time.time()
with torch.inference_mode():
# Prefill over full prompt
logits, _ = self.model(
input_ids,
budget_ratio=budget_ratio,
position_offset=0,
use_cache=False
)
next_token_logits = logits[:, -1, :] # [1, vocab]
print("Generating...", end=" ", flush=True)
is_cuda = torch.cuda.is_available()
buffer = [] # small output buffer for streaming
for _ in range(max_tokens):
if is_cuda:
torch.cuda.synchronize()
t0 = time.time()
# 1D view for sampling/masking
logits_for_sampling = next_token_logits.squeeze(0).clone() / max(1e-6, temperature)
vocab_size = logits_for_sampling.size(0)
# Ban structural tokens at the very start
if len(generated_tokens) < min_new_tokens and min_new_tokens > 0:
for tid in ban_initial_ids:
if tid is not None and 0 <= tid < vocab_size:
logits_for_sampling[tid] = float("-inf")
                # Top-k (clamp k to the vocab size so an oversized --top_k cannot crash torch.topk)
                if top_k and top_k > 0:
                    kth = torch.topk(logits_for_sampling, min(top_k, vocab_size))[0][-1]
                    logits_for_sampling[logits_for_sampling < kth] = float("-inf")
# Top-p
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits_for_sampling, descending=True)
cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
remove_idx = sorted_indices[sorted_indices_to_remove]
logits_for_sampling[remove_idx] = float("-inf")
# Sample
probs = F.softmax(logits_for_sampling, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
generated_tokens.append(next_token_id)
# Decode + buffered print
if show_tokens:
tok_text = self.tokenizer.decode([next_token_id], skip_special_tokens=False)
buffer.append(tok_text)
if len(buffer) >= 16:
print("".join(buffer), end="", flush=True)
buffer.clear()
# Stop on EOT/EOS after min_new_tokens
if (next_token_id in stop_ids) and (len(generated_tokens) >= max(1, min_new_tokens)):
if buffer:
print("".join(buffer), end="", flush=True)
buffer.clear()
if show_tokens:
print(" [EOT]", flush=True)
break
# Stateless decode: append token and re-run forward
input_ids = torch.cat(
[input_ids, torch.tensor([[next_token_id]], device=self.device)],
dim=1
)
logits, _ = self.model(
input_ids,
budget_ratio=budget_ratio,
position_offset=0,
use_cache=False
)
next_token_logits = logits[:, -1, :]
if is_cuda:
torch.cuda.synchronize()
token_times.append(time.time() - t0)
# Flush any remaining buffered tokens
if buffer:
print("".join(buffer), end="", flush=True)
buffer.clear()
total_time = time.time() - start_time
text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
text = self._postprocess_like_training(text)
if show_tokens and (not generated_tokens or (generated_tokens[-1] not in stop_ids)):
print()
num_gen = len(generated_tokens)
if num_gen == 0:
print("\nNo tokens generated.")
return {'output': '', 'tokens_per_sec': 0, 'decode_tps': 0, 'total_time': total_time, 'num_tokens': 0}
decode_time = sum(token_times)
toks_per_sec = num_gen / total_time if total_time > 0 else 0
decode_tps = num_gen / decode_time if decode_time > 0 else 0
print("\n" + "─" * 80)
print("STATISTICS")
print("─" * 80)
print(f"Tokens: {num_gen}")
print(f"Total time: {total_time:.2f}s")
print(f"Overall speed: {toks_per_sec:.1f} tok/s (includes prompt)")
print(f"Decode speed: {decode_tps:.1f} tok/s (generation only)")
print(f"Time/token: {(decode_time/num_gen)*1000:.1f}ms")
print("─" * 80)
print(f"Output: {text[:100]}{'...' if len(text) > 100 else ''}")
print("=" * 80 + "\n")
self._reset_mamba_states()
return {
'output': text,
'tokens_per_sec': toks_per_sec,
'decode_tps': decode_tps,
'total_time': total_time,
'num_tokens': num_gen,
}
def interactive_mode(self):
print("\n" + "=" * 80)
print("INTERACTIVE MODE (no cache, stateless)")
print("Type 'quit' or your prompt")
print("=" * 80 + "\n")
while True:
try:
prompt = input("\nYou: ")
except (EOFError, KeyboardInterrupt):
print("\nBye.")
break
if prompt.lower() in ["quit", "exit", "q"]:
break
if not prompt.strip():
continue
print("\nAssistant: ", end="", flush=True)
self.generate_once(prompt, max_tokens=2000, temperature=0.8, show_tokens=True)
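# Hedged, standalone sketch of the nucleus (top-p) filtering used inside
# generate_once (not called anywhere; `_demo_top_p_filter` is ours): keep the
# smallest prefix of the sorted distribution whose cumulative probability
# stays within top_p, and mask out the rest.
def _demo_top_p_filter():
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    top_p = 0.8
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum > top_p
    remove[1:] = remove[:-1].clone()  # shift so the boundary token survives
    remove[0] = False
    logits[sorted_indices[remove]] = float("-inf")
    assert torch.isfinite(logits).sum() >= 1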
def _cast_layernorm_fp32(module: nn.Module):
for m in module.modules():
if isinstance(m, nn.LayerNorm):
m.float()
def load_model_and_tokenizer(model_dir: str):
"""
Load AdaptiveRiverLM model and tokenizer from a folder layout like:
model_dir/
checkpoint.pt (or any .pt file)
tokenizer/
tokenizer.json
special_tokens_map.json
...
Automatically finds the .pt file if not explicitly named.
"""
print(f"Searching for model checkpoint in: {model_dir}")
ckpts = glob.glob(os.path.join(model_dir, "*.pt"))
if not ckpts:
raise FileNotFoundError(f"No .pt checkpoint found in {model_dir}")
if len(ckpts) > 1:
print(f"[Warning] Multiple .pt files found, using: {ckpts[0]}")
checkpoint_path = ckpts[0]
tokenizer_path = os.path.join(model_dir, "tokenizer")
if not os.path.isdir(tokenizer_path):
raise FileNotFoundError(f"Missing tokenizer directory: {tokenizer_path}")
print(f"Loading tokenizer from: {tokenizer_path}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
print("Tokenizer missing pad_token. Assigning eos_token as pad_token.")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Building model (AdaptiveRiverLM)...")
cfg = estimate_1b_config()
cfg.vocab_size = len(tokenizer)
cfg.tie_word_embeddings = False
model = AdaptiveRiverLM(cfg)
print(f"Loading checkpoint: {checkpoint_path}")
state = torch.load(checkpoint_path, map_location="cpu")
model_state_dict = model.state_dict()
converted_state = {}
for k, param in model_state_dict.items():
if k in state and state[k].shape == param.shape:
converted_state[k] = state[k]
print("Loading weights...")
load_result = model.load_state_dict(converted_state, strict=False)
if load_result.missing_keys:
print("\n--- Missing Keys ---")
for k in load_result.missing_keys:
print(" ", k)
if load_result.unexpected_keys:
print("\n--- Unexpected Keys ---")
for k in load_result.unexpected_keys:
print(" ", k)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
    if device == "cuda" and torch.cuda.is_bf16_supported():
        model = model.to(torch.bfloat16)
        # Re-cast LayerNorms to fp32 after the blanket bf16 cast so the
        # upcast is not undone by .to(torch.bfloat16).
        _cast_layernorm_fp32(model)
    else:
        model = model.to(torch.float32)
model.eval()
print(f"Model and tokenizer loaded successfully from {model_dir} on {device}")
return model, tokenizer, device
def main():
parser = argparse.ArgumentParser(description="Stateless inference for AdaptiveRiverLM (no KV cache), proper EOT handling")
parser.add_argument("--model_dir", type=str, required=True, help="Path to model folder (with checkpoint.pt and tokenizer/)")
parser.add_argument("--prompt", type=str, default="Hello, my name is")
parser.add_argument("--max_tokens", type=int, default=2000)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--min_new_tokens", type=int, default=3)
parser.add_argument("--interactive", action="store_true", help="Interactive mode (stateless)")
args = parser.parse_args()
model, tokenizer, device = load_model_and_tokenizer(args.model_dir)
# Resolve special token IDs for end-of-turn handling
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id
tester = FastInferenceTester(model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id)
if args.interactive:
tester.interactive_mode()
else:
tester.generate_once(
args.prompt,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
show_tokens=True,
min_new_tokens=args.min_new_tokens,
)
if __name__ == "__main__":
main()
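# Example invocation (paths are hypothetical):
#   python inference_tester.py --model_dir ./checkpoints/river --prompt "Hello" \
#       --temperature 0.8 --top_p 0.9 --max_tokens 256
# or, for a REPL-style session:
#   python inference_tester.py --model_dir ./checkpoints/river --interactive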