#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Stateless (no KV cache) inference for AdaptiveRiverLM with proper end-of-turn handling.

The model is a hybrid stack: Mamba blocks at the two outermost layers on each
end, MoE-attention + MoE-FFN ("routed") blocks in the middle.  Generation is
fully stateless: every decode step re-runs the forward pass over the whole
sequence, which is slow but requires no KV/Mamba cache plumbing.
"""

import argparse
import glob
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import AutoTokenizer

# ----------------------------------------------------------------------------
# mamba-ssm dependency
# ----------------------------------------------------------------------------
try:
    from mamba_ssm import Mamba
    from mamba_ssm.utils.generation import InferenceParams

    _HAS_MAMBA = True
except ImportError:
    _HAS_MAMBA = False
    InferenceParams = None
    print("=" * 80)
    print("[WARNING] mamba-ssm not installed. Mamba layers will not function.")
    print("Install with: pip install mamba-ssm")
    print("=" * 80)

    class Mamba(nn.Module):
        """Import-time placeholder so the module still loads without mamba-ssm."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            print("ERROR: Mamba placeholder. mamba-ssm not installed.")

        def forward(self, x, *args, **kwargs):
            # FIX: the original string literal was broken in half; it is one message.
            print("ERROR: mamba-ssm not installed. Cannot run MambaBlock.")
            return x


# ----------------------------------------------------------------------------
# Model
# ----------------------------------------------------------------------------
@dataclass
class AdaptiveRiverConfig:
    """Hyperparameters for AdaptiveRiverLM (defaults match the ~1B training run)."""

    vocab_size: int = 50257
    d_model: int = 1024
    n_layers: int = 24
    d_ff: int = 4096
    dropout: float = 0.0
    rope_theta: float = 10000.0
    rotary_pct: float = 1.0           # fraction of head_dim that receives RoPE
    layer_norm_eps: float = 1e-5
    rope_scaling_type: str | None = None
    rope_scaling_factor: float = 1.0
    experts_per_layer: int = 4        # FFN experts in each RoutedBlock
    top_k_ffn: int = 1
    moe_dropout: float = 0.0
    attn_n_experts: int = 6           # attention "expert heads" per RoutedBlock
    attn_top_k: int = 6
    attn_n_orig_heads: int = 16       # head_dim = d_model // attn_n_orig_heads
    mamba_d_state: int = 16
    mamba_d_conv: int = 4
    mamba_expand: int = 2
    entropy_weight: float = 1e-4
    head_entropy_weight: float = 1e-4
    default_budget_ratio: float = 1.0  # compute budget in [0, 1]; scales top-k
    init_std: float = 0.02
    tie_word_embeddings: bool = False  # untied head (matches training)
    load_balance_weight: float = 0.01
    router_z_weight: float = 0.001
    gate_temperature: float = 0.7
    checkpoint_attn_thresh: float = 0.35
    checkpoint_ffn_thresh: float = 0.35
    soak_dtype: str = "fp32"


def _init_weights(module: nn.Module, std: float):
    """Normal(0, std) init for nn.Linear weights; zero bias."""
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def topk_mask_ste(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Hard top-k one-hot mask with a straight-through softmax gradient.

    Forward value is the 0/1 top-k mask; the `+ probs - probs.detach()` term
    contributes zero to the value but routes gradients through the softmax.
    """
    s = scores.float()
    if k >= s.size(-1):
        return torch.ones_like(s)
    topk = torch.topk(s, k=k, dim=-1).indices
    one_hot = torch.zeros_like(s)
    one_hot.scatter_(dim=-1, index=topk, value=1.0)
    probs = F.softmax(s, dim=-1)
    return one_hot + probs - probs.detach()


class RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily grown cos/sin cache.

    NOTE(review): "ntk"/"linear"/"yarn" scaling are all approximated by simply
    multiplying the base frequency — this matches what training used, but is
    not the canonical "linear" position-interpolation formula; confirm before
    changing.
    """

    def __init__(self, dim, base=10000.0, scaling_type: str | None = None,
                 scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.scaling_type = scaling_type
        self.scaling_factor = float(scaling_factor)
        base = self._effective_base()
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Cache of (cos, sin) tables, keyed by device/dtype/max length.
        self._cos_sin_cache = None
        self._cos_sin_cache_device = None
        self._cos_sin_cache_dtype = None
        self._cos_sin_max_seq_len = -1

    def _effective_base(self) -> float:
        """Base frequency after applying the (approximate) RoPE scaling."""
        if not self.scaling_type or self.scaling_factor == 1.0:
            return self.base
        if self.scaling_type in ("ntk", "linear", "yarn"):
            return self.base * self.scaling_factor
        return self.base

    def _get_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Return (cos, sin) tables of at least `seq_len` rows, rebuilding on miss."""
        if (seq_len > self._cos_sin_max_seq_len
                or self._cos_sin_cache is None
                or self._cos_sin_cache_device != device
                or self._cos_sin_cache_dtype != dtype):
            self._cos_sin_max_seq_len = max(seq_len, 2048)
            t = torch.arange(self._cos_sin_max_seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().to(dtype)
            sin = emb.sin().to(dtype)
            self._cos_sin_cache = (cos, sin)
            self._cos_sin_cache_device = device
            self._cos_sin_cache_dtype = dtype
        return self._cos_sin_cache

    def forward(self, x, seq_len: int, offset: int | torch.Tensor = 0):
        """Return (cos, sin) of shape [1, 1, seq_len, dim] starting at `offset`.

        FIX: the original called int(offset) (inside the cache lookup) before
        checking for a multi-element tensor offset, so that branch could never
        run — int() on a multi-element tensor raises.  The tensor case is now
        handled first; scalar behavior is unchanged.
        """
        device, dtype = x.device, x.dtype
        if isinstance(offset, torch.Tensor):
            if offset.numel() > 1:
                # Per-batch offsets: fall back to an offset-0 table (as before).
                t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype).float()
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos_val = emb.cos()[None, None, :, :].to(dtype)
                sin_val = emb.sin()[None, None, :, :].to(dtype)
                return cos_val, sin_val
            offset = int(offset.item())
        else:
            offset = int(offset)
        cos, sin = self._get_cos_sin_cache(seq_len + offset, device, dtype)
        cos = cos[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = sin[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        return cos, sin


def apply_rotary(x, cos, sin):
    """Apply rotary embedding using the interleaved-pair convention."""
    x1, x2 = x[..., ::2], x[..., 1::2]
    x_rot = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + x_rot * sin


class PTLayerNorm(nn.Module):
    """Thin wrapper around nn.LayerNorm (kept as a separate module so it can be
    selectively cast to fp32 under bf16 inference)."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, x):
        return self.ln(x)


class GlobalSDPAHead(nn.Module):
    """Single-head causal SDPA with optional RoPE.

    NOTE(review): not referenced by AdaptiveRiverLM in this file — presumably
    kept for checkpoint compatibility; verify before removing.
    """

    def __init__(self, d_model, head_dim, dropout, rope_theta, rotary_pct, cfg):
        super().__init__()
        self.q_proj = nn.Linear(d_model, head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, head_dim, bias=False)
        self.rotary_dim = int(head_dim * rotary_pct)
        self.dropout_p = dropout
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim,
                base=rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )

    def forward(self, x, position_offset):
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        B, T, C = x.shape
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        if self.rotary_dim > 0:
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1); sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        q, k, v = [t.unsqueeze(1) for t in (q, k, v)]  # add a head dim for SDPA
        dropout_p = self.dropout_p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
        return out.squeeze(1)


class AttentionMoERouter(nn.Module):
    """Sequence-level router: one gating decision per sequence (mean-pooled)."""

    def __init__(self, d_model, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate_proj = nn.Linear(d_model, num_experts, bias=False)
        nn.init.normal_(self.gate_proj.weight, mean=0.0, std=0.01)

    def forward(self, x, budget_ratio, temperature):
        """Returns (mask[B,E] bool, weights[B,k], idx[B,k], entropy, logits[B,E])."""
        seq_embed = x.mean(dim=1)
        logits = self.gate_proj(seq_embed) / max(1e-6, float(temperature))
        logits = logits.clamp(min=-10.0, max=10.0)
        # Budget scales the number of selected experts down to 25% at ratio 0.
        k_target = max(1, int(round(self.top_k * (0.25 + 0.75 * budget_ratio))))
        k_target = min(k_target, logits.size(-1))
        vals, idx = torch.topk(logits, k_target, dim=-1)
        weights = F.softmax(vals.to(torch.float32), dim=-1).to(x.dtype)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(1, idx, True)
        with torch.no_grad():
            p = F.softmax(logits, dim=-1)
            entropy = -(p * (p.clamp_min(1e-12)).log()).sum(dim=-1).mean()
        return mask, weights, idx, entropy, logits


class MoEAttention(nn.Module):
    """Attention where each "expert" is an independent single head; a
    sequence-level router picks and weights a budget-dependent subset."""

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.d_model = cfg.d_model
        self.n_experts = cfg.attn_n_experts
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.attn_n_orig_heads
        self.rotary_dim = int(self.head_dim * cfg.rotary_pct)
        self.router = AttentionMoERouter(cfg.d_model, cfg.attn_n_experts, cfg.attn_top_k)
        self.q_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim,
                base=cfg.rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )
        self.o_proj = nn.Linear(cfg.attn_n_experts * self.head_dim, cfg.d_model, bias=False)

    def forward(self, x, position_offset, budget_ratio, temperature):
        B, T, C = x.shape
        E, H = self.n_experts, self.head_dim
        sel_mask, gate_w, gate_idx, entropy, gate_logits = self.router(x, budget_ratio, temperature)
        q = self.q_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        k = self.k_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        v = self.v_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        if self.rope:
            if isinstance(position_offset, torch.Tensor):
                position_offset = int(position_offset.view(-1)[0].item())
            else:
                position_offset = int(position_offset)
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1); sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        # Run all experts as a batched single-head SDPA; unselected experts are
        # zeroed out afterwards by the gate weights.
        q_b = q.reshape(B * E, T, H)
        k_b = k.reshape(B * E, T, H)
        v_b = v.reshape(B * E, T, H)
        dropout_p = self.cfg.dropout if self.training else 0.0
        out_b = F.scaled_dot_product_attention(q_b, k_b, v_b, is_causal=True, dropout_p=dropout_p)
        out = out_b.view(B, E, T, H).permute(0, 2, 1, 3)
        # Scatter the softmaxed top-k weights into a dense [B, E] gate.
        W = torch.zeros(B, E, device=x.device, dtype=out.dtype)
        W.scatter_(1, gate_idx, gate_w.to(out.dtype))
        weighted_out = torch.einsum('b t e h, b e -> b t e h', out, W)
        y = weighted_out.reshape(B, T, E * H).to(self.o_proj.weight.dtype)
        y = self.o_proj(y)
        with torch.no_grad():
            usage = sel_mask.float().mean(dim=0)
            expected = sel_mask.float().sum(dim=-1).mean()
            den = torch.clamp(expected, min=1e-6)
            usage_norm = usage / den
            uniform = 1.0 / self.n_experts
            # FIX: original multiplied and divided by n_experts (a no-op);
            # the value is unchanged.
            attn_lb = ((usage_norm - uniform) ** 2).sum()
            attn_rz = (gate_logits ** 2).mean()
            head_keep = sel_mask.float().mean()
        return y, {
            "head_entropy": entropy,
            "head_keep_frac": head_keep,
            "attn_load_balance_loss": attn_lb,
            "attn_router_z_loss": attn_rz,
        }


class ExpertFFN(nn.Module):
    """Plain GELU MLP expert.

    NOTE(review): not referenced elsewhere in this file (MoEFFN uses stacked
    parameter tensors instead) — presumably kept for checkpoint compatibility.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.w1(x)
        x = F.gelu(x, approximate="tanh")
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.w2(x)
        return x


class MoEFFN(nn.Module):
    """Token-level MoE FFN with stacked expert weights and an STE top-k gate."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int,
                 dropout: float, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.n_experts = n_experts
        self.base_top_k = top_k
        self.cfg = cfg
        self.router = nn.Linear(d_model, n_experts, bias=False)
        # All experts stored as one [E, d_ff, d_model] / [E, d_model, d_ff] stack.
        self.w1_stacked = nn.Parameter(torch.empty(n_experts, d_ff, d_model))
        self.w2_stacked = nn.Parameter(torch.empty(n_experts, d_model, d_ff))
        std = cfg.init_std
        nn.init.normal_(self.router.weight, mean=0.0, std=std)
        nn.init.normal_(self.w1_stacked, mean=0.0, std=std)
        nn.init.normal_(self.w2_stacked, mean=0.0, std=std)

    def forward(self, x: torch.Tensor, budget_ratio: float):
        B, T, C = x.shape
        N = B * T
        X = x.reshape(N, C)
        # Budget scales top-k between 50% and 100% of the configured value.
        k_target = max(1, int(round(self.base_top_k * (0.5 + budget_ratio / 2.0))))
        k_target = min(k_target, self.n_experts)
        scores = self.router(X).to(torch.float32).clamp(min=-10.0, max=10.0)
        probs = F.softmax(scores, dim=-1).to(X.dtype)
        mask = topk_mask_ste(scores, k=k_target).to(X.dtype)
        gate = (mask * probs)
        gate = gate / gate.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        # Dense compute over every expert, then gate-combine (no dispatch).
        x_ff = torch.einsum('n c, e d c -> n e d', X, self.w1_stacked)
        x_act = F.gelu(x_ff, approximate="tanh")
        y_experts = torch.einsum('n e d, e c d -> n e c', x_act, self.w2_stacked)
        y = torch.einsum('n e, n e c -> n c', gate, y_experts).view(B, T, C).to(x.dtype)
        with torch.no_grad():
            entropy = (-probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
            router_z = (scores ** 2).mean().clamp(max=10.0)
            frac = mask.mean(dim=0)
            uniform = 1.0 / self.n_experts
            # FIX: original multiplied and divided by n_experts (a no-op).
            lb = ((frac - uniform) ** 2).sum()
        return y, {
            "router_entropy": entropy,
            "ffn_expert_usage": frac.detach(),
            "ffn_load_balance_loss": lb,
            "ffn_router_z_loss": router_z,
        }


class MambaBlock(nn.Module):
    """Mamba mixer + dense GELU FFN, both pre-norm residual.

    When mamba-ssm is missing the block degrades to an identity pass-through.
    """

    def __init__(self, cfg: AdaptiveRiverConfig, enhanced: bool = False,
                 layer_idx: int | None = None):
        super().__init__()
        if not _HAS_MAMBA:
            print(f"MambaBlock Layer {layer_idx} disabled: mamba-ssm not installed.")
            self.mamba = None
            return
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.mamba = Mamba(
            d_model=cfg.d_model,
            d_state=cfg.mamba_d_state,
            d_conv=cfg.mamba_d_conv,
            expand=cfg.mamba_expand * (2 if enhanced else 1),
            layer_idx=layer_idx,
        )
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_ff * (2 if enhanced else 1), bias=False),
            nn.GELU(approximate="tanh"),
            nn.Linear(cfg.d_ff * (2 if enhanced else 1), cfg.d_model, bias=False),
        )

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        """Returns (hidden, stats, (None, None)); cache args are accepted but unused."""
        if not _HAS_MAMBA or self.mamba is None:
            stats = {"head_entropy": torch.tensor(0.0, device=x.device),
                     "head_keep_frac": torch.tensor(1.0, device=x.device),
                     "mamba_out_l2": torch.tensor(0.0, device=x.device)}
            return x, stats, (None, None)
        h = self.ln1(x)
        x_m = self.mamba(h)  # stateless path
        m_out_l2 = x_m.float().pow(2).mean()
        x = x + x_m
        h2 = self.ln2(x)
        x = x + self.ffn(h2)
        stats = {
            "head_entropy": torch.tensor(0.0, device=x.device),
            "head_keep_frac": torch.tensor(1.0, device=x.device),
            "mamba_out_l2": m_out_l2.detach(),
        }
        return x, stats, (None, None)


class RoutedBlock(nn.Module):
    """Pre-norm transformer block with MoE attention and MoE FFN."""

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MoEAttention(cfg)
        self.ffn = MoEFFN(cfg.d_model, cfg.d_ff, cfg.experts_per_layer,
                          cfg.top_k_ffn, cfg.moe_dropout, cfg)

    def _attn_forward(self, h: torch.Tensor, position_offset: int, budget_ratio: float):
        """Normalize the offset to a plain int, then run MoE attention."""
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        return self.attn(h, position_offset, budget_ratio, self.cfg.gate_temperature)

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        """Returns (hidden, merged attn+ffn stats, (None, None))."""
        h = self.ln1(x)
        attn_out, attn_stats = self._attn_forward(h, position_offset, budget_ratio)
        x = x + attn_out
        h2 = self.ln2(x)
        ffn_out, moe_stats = self.ffn(h2, budget_ratio=budget_ratio)
        x = x + ffn_out
        stats = {**attn_stats, **moe_stats}
        return x, stats, (None, None)


class AdaptiveRiverLM(nn.Module):
    """Hybrid LM: 2 Mamba blocks, (n_layers - 4) routed blocks, 2 enhanced Mamba blocks."""

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList()
        mamba_layer_counter = 0  # Mamba blocks get their own contiguous layer_idx
        for i in range(cfg.n_layers):
            if i < 2:
                print(f"[model] Layer {i}: Mamba")
                self.blocks.append(MambaBlock(cfg, enhanced=False, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            elif i >= (cfg.n_layers - 2):
                print(f"[model] Layer {i}: Mamba (enhanced)")
                self.blocks.append(MambaBlock(cfg, enhanced=True, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            else:
                if i == 2:
                    print(f"[model] Layers {i}-{cfg.n_layers-3}: MoE Attention + MoE FFN")
                self.blocks.append(RoutedBlock(cfg))
        self.ln_f = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.embed.weight
        self.apply(lambda m: _init_weights(m, cfg.init_std) if isinstance(m, nn.Linear) else None)

    def forward(
        self,
        input_ids: torch.Tensor,
        budget_ratio: Optional[float] = None,
        mamba_states: Optional[List] = None,
        past_kvs: Optional[List] = None,
        position_offset: int | torch.Tensor = 0,
        return_expert_stats: bool = False,
        use_cache: bool = False,
    ):
        """Stateless forward. Returns (logits [B, T, V], dict of mean per-layer stats)."""
        x = self.embed(input_ids)
        b = float(self.cfg.default_budget_ratio if budget_ratio is None else budget_ratio)
        all_stats: Dict[str, List[torch.Tensor]] = {}
        for block in self.blocks:
            x, stats, _ = block(
                x,
                position_offset=position_offset,
                past_kv=None,
                budget_ratio=b,
                use_cache=False,
                mamba_state=None,
            )
            for k, v in stats.items():
                all_stats.setdefault(k, []).append(
                    torch.as_tensor(v.detach() if isinstance(v, torch.Tensor) else v))
        mean_stats = {k: torch.stack(v).mean() for k, v in all_stats.items() if len(v) > 0}
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits, mean_stats


def estimate_1b_config() -> AdaptiveRiverConfig:
    """Config matching the ~1B-parameter training run."""
    return AdaptiveRiverConfig(
        vocab_size=50257,
        d_model=1024,
        n_layers=24,
        d_ff=4096,
        experts_per_layer=4,
        top_k_ffn=1,
        default_budget_ratio=1.0,
        attn_n_experts=6,
        attn_top_k=6,
        attn_n_orig_heads=16,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        gate_temperature=0.7,
        head_entropy_weight=1e-4,
        checkpoint_attn_thresh=0.35,
        checkpoint_ffn_thresh=0.35,
        load_balance_weight=0.01,
        router_z_weight=0.001,
        tie_word_embeddings=False,
    )


# ----------------------------------------------------------------------------
# Inference (stateless) with proper end-of-turn handling
# ----------------------------------------------------------------------------
class FastInferenceTester:
    """Stateless sampler: every decode step re-runs the full forward pass."""

    def __init__(self, model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.eos_id = eos_id
        self.pad_id = pad_id
        self.model.eval()
        torch.set_grad_enabled(False)
        print("Using model's native precision")
        # torch.compile is skipped when mamba-ssm is present (its custom
        # kernels don't compose with the compiler).
        if hasattr(torch, 'compile') and _HAS_MAMBA:
            print("Skipping torch.compile due to mamba-ssm kernels.")
        else:
            try:
                print("Compiling model with torch.compile...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled successfully")
            except Exception as e:
                print(f"Could not compile model: {e}")
                print("Running without compilation")

    def _format_to_training_chat(self, prompt: str) -> torch.Tensor:
        """Wrap the prompt in the tokenizer's chat template and tokenize."""
        messages = [{"role": "user", "content": prompt}]
        formatted = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(
            formatted, add_special_tokens=False, return_tensors="pt"
        ).to(self.device)
        return input_ids

    def _postprocess_like_training(self, text: str) -> str:
        """Extract just the assistant turn, mirroring training-time parsing."""
        if "<|im_start|>assistant" in text:
            return text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        if "assistant\n" in text:
            return text.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
        return text.split("<|im_end|>")[0].strip()

    def _reset_mamba_states(self):
        """Best-effort wipe of any cached Mamba inference state between runs."""
        if not _HAS_MAMBA:
            return
        for block in self.model.blocks:
            if isinstance(block, MambaBlock) and hasattr(block, "mamba"):
                for attr in ("inference_params", "conv_state", "ssm_state"):
                    if hasattr(block.mamba, attr):
                        setattr(block.mamba, attr, None)

    def generate_once(
        self,
        prompt: str,
        max_tokens: int = 2000,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int = 0,
        budget_ratio: float = 1.0,
        show_tokens: bool = False,
        min_new_tokens: int = 3,
    ) -> Dict:
        """Sample one completion.

        Returns a dict with keys: 'output', 'tokens_per_sec', 'decode_tps',
        'total_time', 'num_tokens'.
        """
        self._reset_mamba_states()
        print(f"\n{'='*80}")
        print("FAST GENERATION (no cache)")
        print(f"{'='*80}")
        print(f"Prompt: {prompt}")
        print("─" * 80)
        input_ids = self._format_to_training_chat(prompt)
        generated_tokens: List[int] = []
        token_times: List[float] = []
        stop_ids = set(t for t in [self.im_end_id, self.eos_id] if t is not None)
        ban_initial_ids = set(
            t for t in [self.im_end_id, self.eos_id, self.im_start_id, self.pad_id]
            if t is not None)
        start_time = time.time()
        with torch.inference_mode():
            # Prefill over full prompt
            logits, _ = self.model(
                input_ids, budget_ratio=budget_ratio, position_offset=0, use_cache=False
            )
            next_token_logits = logits[:, -1, :]  # [1, vocab]
            print("Generating...", end=" ", flush=True)
            is_cuda = torch.cuda.is_available()
            buffer = []  # small output buffer for streaming
            for _ in range(max_tokens):
                if is_cuda:
                    torch.cuda.synchronize()
                t0 = time.time()
                # 1D view for sampling/masking
                logits_for_sampling = next_token_logits.squeeze(0).clone() / max(1e-6, temperature)
                vocab_size = logits_for_sampling.size(0)
                # Ban structural tokens at the very start
                if len(generated_tokens) < min_new_tokens and min_new_tokens > 0:
                    for tid in ban_initial_ids:
                        if tid is not None and 0 <= tid < vocab_size:
                            logits_for_sampling[tid] = float("-inf")
                # Top-k
                if top_k and top_k > 0:
                    kth = torch.topk(logits_for_sampling, top_k)[0][-1]
                    logits_for_sampling[logits_for_sampling < kth] = float("-inf")
                # Top-p (nucleus): drop the tail whose cumulative prob exceeds top_p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits_for_sampling, descending=True)
                    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift so the first token above the threshold is kept.
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = False
                    remove_idx = sorted_indices[sorted_indices_to_remove]
                    logits_for_sampling[remove_idx] = float("-inf")
                # Sample
                probs = F.softmax(logits_for_sampling, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()
                generated_tokens.append(next_token_id)
                # Decode + buffered print
                if show_tokens:
                    tok_text = self.tokenizer.decode([next_token_id], skip_special_tokens=False)
                    buffer.append(tok_text)
                    if len(buffer) >= 16:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()
                # Stop on EOT/EOS after min_new_tokens
                if (next_token_id in stop_ids) and (len(generated_tokens) >= max(1, min_new_tokens)):
                    if buffer:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()
                    if show_tokens:
                        print(" [EOT]", flush=True)
                    break
                # Stateless decode: append token and re-run forward
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_token_id]], device=self.device)], dim=1
                )
                logits, _ = self.model(
                    input_ids, budget_ratio=budget_ratio, position_offset=0, use_cache=False
                )
                next_token_logits = logits[:, -1, :]
                if is_cuda:
                    torch.cuda.synchronize()
                token_times.append(time.time() - t0)
            # Flush any remaining buffered tokens
            if buffer:
                print("".join(buffer), end="", flush=True)
                buffer.clear()
        total_time = time.time() - start_time
        text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        text = self._postprocess_like_training(text)
        if show_tokens and (not generated_tokens or (generated_tokens[-1] not in stop_ids)):
            print()
        num_gen = len(generated_tokens)
        if num_gen == 0:
            print("\nNo tokens generated.")
            return {'output': '', 'tokens_per_sec': 0, 'decode_tps': 0,
                    'total_time': total_time, 'num_tokens': 0}
        decode_time = sum(token_times)
        toks_per_sec = num_gen / total_time if total_time > 0 else 0
        decode_tps = num_gen / decode_time if decode_time > 0 else 0
        print("\n" + "─" * 80)
        print("STATISTICS")
        print("─" * 80)
        print(f"Tokens: {num_gen}")
        print(f"Total time: {total_time:.2f}s")
        print(f"Overall speed: {toks_per_sec:.1f} tok/s (includes prompt)")
        print(f"Decode speed: {decode_tps:.1f} tok/s (generation only)")
        print(f"Time/token: {(decode_time/num_gen)*1000:.1f}ms")
        print("─" * 80)
        print(f"Output: {text[:100]}{'...' if len(text) > 100 else ''}")
        print("=" * 80 + "\n")
        self._reset_mamba_states()
        return {
            'output': text,
            'tokens_per_sec': toks_per_sec,
            'decode_tps': decode_tps,
            'total_time': total_time,
            'num_tokens': num_gen,
        }

    def interactive_mode(self):
        """Simple REPL; 'quit'/'exit'/'q' or Ctrl-C/EOF exits."""
        print("\n" + "=" * 80)
        print("INTERACTIVE MODE (no cache, stateless)")
        print("Type 'quit' or your prompt")
        print("=" * 80 + "\n")
        while True:
            try:
                prompt = input("\nYou: ")
            except (EOFError, KeyboardInterrupt):
                print("\nBye.")
                break
            if prompt.lower() in ["quit", "exit", "q"]:
                break
            if not prompt.strip():
                continue
            print("\nAssistant: ", end="", flush=True)
            self.generate_once(prompt, max_tokens=2000, temperature=0.8, show_tokens=True)


def _cast_layernorm_fp32(module: nn.Module):
    """Keep LayerNorms in fp32 for numerical stability under bf16 inference."""
    for m in module.modules():
        if isinstance(m, nn.LayerNorm):
            m.float()


def load_model_and_tokenizer(model_dir: str):
    """
    Load AdaptiveRiverLM model and tokenizer from a folder layout like:
    model_dir/
        checkpoint.pt (or any .pt file)
        tokenizer/
            tokenizer.json
            special_tokens_map.json
            ...
    Automatically finds the .pt file if not explicitly named.

    Returns (model, tokenizer, device).
    """
    print(f"Searching for model checkpoint in: {model_dir}")
    ckpts = glob.glob(os.path.join(model_dir, "*.pt"))
    if not ckpts:
        raise FileNotFoundError(f"No .pt checkpoint found in {model_dir}")
    if len(ckpts) > 1:
        print(f"[Warning] Multiple .pt files found, using: {ckpts[0]}")
    checkpoint_path = ckpts[0]
    tokenizer_path = os.path.join(model_dir, "tokenizer")
    if not os.path.isdir(tokenizer_path):
        raise FileNotFoundError(f"Missing tokenizer directory: {tokenizer_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True,
                                              trust_remote_code=True)
    if tokenizer.pad_token is None:
        print("Tokenizer missing pad_token. Assigning eos_token as pad_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print("Building model (AdaptiveRiverLM)...")
    cfg = estimate_1b_config()
    cfg.vocab_size = len(tokenizer)
    cfg.tie_word_embeddings = False
    model = AdaptiveRiverLM(cfg)
    print(f"Loading checkpoint: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location="cpu")
    # Keep only tensors whose name AND shape match the freshly built model;
    # everything else is reported as missing/unexpected below.
    model_state_dict = model.state_dict()
    converted_state = {}
    for k, param in model_state_dict.items():
        if k in state and state[k].shape == param.shape:
            converted_state[k] = state[k]
    print("Loading weights...")
    load_result = model.load_state_dict(converted_state, strict=False)
    if load_result.missing_keys:
        print("\n--- Missing Keys ---")
        for k in load_result.missing_keys:
            print("  ", k)
    if load_result.unexpected_keys:
        print("\n--- Unexpected Keys ---")
        for k in load_result.unexpected_keys:
            print("  ", k)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    if device == "cuda" and torch.cuda.is_bf16_supported():
        _cast_layernorm_fp32(model)
        model = model.to(torch.bfloat16)
    else:
        model = model.to(torch.float32)
    model.eval()
    print(f"Model and tokenizer loaded successfully from {model_dir} on {device}")
    return model, tokenizer, device


def main():
    """CLI entry point: load model, then run one generation or the REPL."""
    parser = argparse.ArgumentParser(
        description="Stateless inference for AdaptiveRiverLM (no KV cache), proper EOT handling")
    parser.add_argument("--model_dir", type=str, required=True,
                        help="Path to model folder (with checkpoint.pt and tokenizer/)")
    parser.add_argument("--prompt", type=str, default="Hello, my name is")
    parser.add_argument("--max_tokens", type=int, default=2000)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--min_new_tokens", type=int, default=3)
    parser.add_argument("--interactive", action="store_true",
                        help="Interactive mode (stateless)")
    args = parser.parse_args()
    model, tokenizer, device = load_model_and_tokenizer(args.model_dir)
    # Resolve special token IDs for end-of-turn handling (the tester re-derives
    # its stop/ban sets from these IDs inside generate_once).
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    tester = FastInferenceTester(model, tokenizer, device,
                                 im_start_id, im_end_id, eos_id, pad_id)
    if args.interactive:
        tester.interactive_mode()
    else:
        tester.generate_once(
            args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            show_tokens=True,
            min_new_tokens=args.min_new_tokens,
        )


if __name__ == "__main__":
    main()