import argparse
import glob
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import AutoTokenizer
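
# mamba-ssm provides the fused selective-state-space kernels used by
# MambaBlock below. It is an optional dependency: when the import fails we
# fall back to a no-op placeholder so this module still imports, but any
# Mamba layer will simply pass activations through unchanged.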
try:
    from mamba_ssm import Mamba
    from mamba_ssm.utils.generation import InferenceParams
    _HAS_MAMBA = True
except ImportError:
    _HAS_MAMBA = False
    InferenceParams = None
    print("=" * 80)
    print("[WARNING] mamba-ssm not installed. Mamba layers will not function.")
    print("Install with: pip install mamba-ssm")
    print("=" * 80)

    # The placeholder is defined inside the except branch so it cannot shadow
    # the real Mamba class when mamba-ssm is installed.
    class Mamba(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            print("ERROR: Mamba placeholder. mamba-ssm not installed.")

        def forward(self, x, *args, **kwargs):
            print("ERROR: mamba-ssm not installed. Cannot run MambaBlock.")
            return x


@dataclass
class AdaptiveRiverConfig:
    vocab_size: int = 50257
    d_model: int = 1024
    n_layers: int = 24
    d_ff: int = 4096
    dropout: float = 0.0
    rope_theta: float = 10000.0
    rotary_pct: float = 1.0
    layer_norm_eps: float = 1e-5
    rope_scaling_type: str | None = None
    rope_scaling_factor: float = 1.0
    experts_per_layer: int = 4
    top_k_ffn: int = 1
    moe_dropout: float = 0.0
    attn_n_experts: int = 6
    attn_top_k: int = 6
    attn_n_orig_heads: int = 16
    mamba_d_state: int = 16
    mamba_d_conv: int = 4
    mamba_expand: int = 2
    entropy_weight: float = 1e-4
    head_entropy_weight: float = 1e-4
    default_budget_ratio: float = 1.0
    init_std: float = 0.02
    tie_word_embeddings: bool = False
    load_balance_weight: float = 0.01
    router_z_weight: float = 0.001
    gate_temperature: float = 0.7
    checkpoint_attn_thresh: float = 0.35
    checkpoint_ffn_thresh: float = 0.35
    soak_dtype: str = "fp32"


def _init_weights(module: nn.Module, std: float):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def topk_mask_ste(scores: torch.Tensor, k: int) -> torch.Tensor:
    s = scores.float()
    if k >= s.size(-1):
        return torch.ones_like(s)
    topk = torch.topk(s, k=k, dim=-1).indices
    one_hot = torch.zeros_like(s)
    one_hot.scatter_(dim=-1, index=topk, value=1.0)
    probs = F.softmax(s, dim=-1)
    return one_hot + probs - probs.detach()
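
# Straight-through estimator: in the forward pass `probs - probs.detach()` is
# exactly zero, so the returned value equals the hard top-k one-hot mask; in
# the backward pass the detached term drops out and gradients flow through
# the softmax as if the mask were soft. E.g. for
# scores = torch.tensor([[2.0, 1.0, 0.0]]) and k=1 the forward value is
# [[1., 0., 0.]], while d(mask)/d(scores) follows the softmax Jacobian.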


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000.0, scaling_type: str | None = None, scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.scaling_type = scaling_type
        self.scaling_factor = float(scaling_factor)
        base = self._effective_base()
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_sin_cache = None
        self._cos_sin_cache_device = None
        self._cos_sin_cache_dtype = None
        self._cos_sin_max_seq_len = -1

    def _effective_base(self) -> float:
        if not self.scaling_type or self.scaling_factor == 1.0:
            return self.base
        if self.scaling_type in ("ntk", "linear", "yarn"):
            return self.base * self.scaling_factor
        return self.base

    def _get_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        if (seq_len > self._cos_sin_max_seq_len or self._cos_sin_cache is None
                or self._cos_sin_cache_device != device or self._cos_sin_cache_dtype != dtype):
            self._cos_sin_max_seq_len = max(seq_len, 2048)
            t = torch.arange(self._cos_sin_max_seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().to(dtype)
            sin = emb.sin().to(dtype)
            self._cos_sin_cache = (cos, sin)
            self._cos_sin_cache_device = device
            self._cos_sin_cache_dtype = dtype
        return self._cos_sin_cache

    def forward(self, x, seq_len: int, offset: int | torch.Tensor = 0):
        device, dtype = x.device, x.dtype
        # Resolve tensor offsets before touching the cache; int() on a
        # multi-element tensor raises.
        if isinstance(offset, torch.Tensor):
            if offset.numel() > 1:
                # Per-batch offsets are not supported; fall back to positions
                # starting at zero for the current window.
                t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos_val = emb.cos()[None, None, :, :].to(dtype)
                sin_val = emb.sin()[None, None, :, :].to(dtype)
                return cos_val, sin_val
            offset = int(offset.item())
        cos, sin = self._get_cos_sin_cache(seq_len + offset, device, dtype)
        cos = cos[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = sin[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        return cos, sin


def apply_rotary(x, cos, sin):
    # The cos/sin cache lays frequencies out as [f_0..f_{d/2-1}, f_0..f_{d/2-1}]
    # (torch.cat((freqs, freqs)) above), so the matching rotation pairs the
    # two contiguous halves of x; pairing even/odd channels would mix
    # mismatched frequencies within each rotated pair.
    x1, x2 = x.chunk(2, dim=-1)
    x_rot = torch.cat((-x2, x1), dim=-1)
    return x * cos + x_rot * sin
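
# Shape note (illustrative): for q of shape (B, E, T, H) and rotary_dim R,
# rope(q, seq_len=T, offset=0) returns cos/sin of shape (1, 1, T, R); callers
# squeeze dim 1 and the result broadcasts over the batch and expert/head
# dimensions inside apply_rotary.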


class PTLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, x):
        return self.ln(x)


class GlobalSDPAHead(nn.Module):
    def __init__(self, d_model, head_dim, dropout, rope_theta, rotary_pct, cfg):
        super().__init__()
        self.q_proj = nn.Linear(d_model, head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, head_dim, bias=False)
        self.rotary_dim = int(head_dim * rotary_pct)
        self.dropout_p = dropout
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim, base=rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )

    def forward(self, x, position_offset):
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        B, T, C = x.shape
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        if self.rotary_dim > 0:
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1)
            sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        q, k, v = [t.unsqueeze(1) for t in (q, k, v)]
        dropout_p = self.dropout_p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
        return out.squeeze(1)


class AttentionMoERouter(nn.Module):
    def __init__(self, d_model, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate_proj = nn.Linear(d_model, num_experts, bias=False)
        nn.init.normal_(self.gate_proj.weight, mean=0.0, std=0.01)

    def forward(self, x, budget_ratio, temperature):
        seq_embed = x.mean(dim=1)
        logits = self.gate_proj(seq_embed) / max(1e-6, float(temperature))
        logits = logits.clamp(min=-10.0, max=10.0)
        k_target = max(1, int(round(self.top_k * (0.25 + 0.75 * budget_ratio))))
        k_target = min(k_target, logits.size(-1))
        vals, idx = torch.topk(logits, k_target, dim=-1)
        weights = F.softmax(vals.to(torch.float32), dim=-1).to(x.dtype)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(1, idx, True)
        with torch.no_grad():
            p = F.softmax(logits, dim=-1)
            entropy = -(p * (p.clamp_min(1e-12)).log()).sum(dim=-1).mean()
        return mask, weights, idx, entropy, logits
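
# Routing is sequence-level: the gate sees the mean-pooled sequence embedding
# and selects the same expert subset for every token in that sequence.
# budget_ratio maps linearly onto the number of experts kept, e.g. with
# top_k=6, budget_ratio=1.0 keeps round(6 * 1.0) = 6 experts while
# budget_ratio=0.0 keeps round(6 * 0.25) = 2.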


class MoEAttention(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.d_model = cfg.d_model
        self.n_experts = cfg.attn_n_experts
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.attn_n_orig_heads
        self.rotary_dim = int(self.head_dim * cfg.rotary_pct)
        self.router = AttentionMoERouter(cfg.d_model, cfg.attn_n_experts, cfg.attn_top_k)
        self.q_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_experts * self.head_dim, bias=False)
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim, base=cfg.rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )
        self.o_proj = nn.Linear(cfg.attn_n_experts * self.head_dim, cfg.d_model, bias=False)

    def forward(self, x, position_offset, budget_ratio, temperature):
        B, T, C = x.shape
        E, H = self.n_experts, self.head_dim
        sel_mask, gate_w, gate_idx, entropy, gate_logits = self.router(x, budget_ratio, temperature)
        q = self.q_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        k = self.k_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        v = self.v_proj(x).view(B, T, E, H).permute(0, 2, 1, 3)
        if self.rope:
            if isinstance(position_offset, torch.Tensor):
                position_offset = int(position_offset.view(-1)[0].item())
            else:
                position_offset = int(position_offset)
            cos, sin = self.rope(q, seq_len=T, offset=position_offset)
            cos = cos.squeeze(1)
            sin = sin.squeeze(1)
            q_rot = apply_rotary(q[..., :self.rotary_dim], cos, sin)
            k_rot = apply_rotary(k[..., :self.rotary_dim], cos, sin)
            q = torch.cat([q_rot, q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([k_rot, k[..., self.rotary_dim:]], dim=-1)
        q_b = q.reshape(B * E, T, H)
        k_b = k.reshape(B * E, T, H)
        v_b = v.reshape(B * E, T, H)
        dropout_p = self.cfg.dropout if self.training else 0.0
        out_b = F.scaled_dot_product_attention(q_b, k_b, v_b, is_causal=True, dropout_p=dropout_p)
        out = out_b.view(B, E, T, H).permute(0, 2, 1, 3)
        # Scatter the gate weights into a dense (B, E) matrix; unselected
        # experts get weight zero, so their heads drop out of the output.
        W = torch.zeros(B, E, device=x.device, dtype=out.dtype)
        W.scatter_(1, gate_idx, gate_w.to(out.dtype))
        weighted_out = torch.einsum('b t e h, b e -> b t e h', out, W)
        y = weighted_out.reshape(B, T, E * H).to(self.o_proj.weight.dtype)
        y = self.o_proj(y)
        with torch.no_grad():
            usage = sel_mask.float().mean(dim=0)
            expected = sel_mask.float().sum(dim=-1).mean()
            den = torch.clamp(expected, min=1e-6)
            usage_norm = usage / den
            uniform = 1.0 / self.n_experts
            attn_lb = ((usage_norm - uniform) ** 2).sum()
            attn_rz = (gate_logits ** 2).mean()
            head_keep = sel_mask.float().mean()
        return y, {
            "head_entropy": entropy,
            "head_keep_frac": head_keep,
            "attn_load_balance_loss": attn_lb,
            "attn_router_z_loss": attn_rz,
        }


class ExpertFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.w1(x)
        x = F.gelu(x, approximate="tanh")
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.w2(x)
        return x


class MoEFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int, dropout: float, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.n_experts = n_experts
        self.base_top_k = top_k
        self.cfg = cfg
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.w1_stacked = nn.Parameter(torch.empty(n_experts, d_ff, d_model))
        self.w2_stacked = nn.Parameter(torch.empty(n_experts, d_model, d_ff))
        std = cfg.init_std
        nn.init.normal_(self.router.weight, mean=0.0, std=std)
        nn.init.normal_(self.w1_stacked, mean=0.0, std=std)
        nn.init.normal_(self.w2_stacked, mean=0.0, std=std)

    def forward(self, x: torch.Tensor, budget_ratio: float):
        B, T, C = x.shape
        N = B * T
        X = x.reshape(N, C)
        k_target = max(1, int(round(self.base_top_k * (0.5 + budget_ratio / 2.0))))
        k_target = min(k_target, self.n_experts)
        scores = self.router(X).to(torch.float32).clamp(min=-10.0, max=10.0)
        probs = F.softmax(scores, dim=-1).to(X.dtype)
        mask = topk_mask_ste(scores, k=k_target).to(X.dtype)
        gate = mask * probs
        gate = gate / gate.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        # All experts are evaluated densely and mixed by the (mostly sparse)
        # gate; there is no token dispatch.
        x_ff = torch.einsum('n c, e d c -> n e d', X, self.w1_stacked)
        x_act = F.gelu(x_ff, approximate="tanh")
        y_experts = torch.einsum('n e d, e c d -> n e c', x_act, self.w2_stacked)
        y = torch.einsum('n e, n e c -> n c', gate, y_experts).view(B, T, C).to(x.dtype)
        with torch.no_grad():
            entropy = (-probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
            router_z = (scores ** 2).mean().clamp(max=10.0)
            frac = mask.mean(dim=0)
            uniform = 1.0 / self.n_experts
            lb = ((frac - uniform) ** 2).sum()
        return y, {
            "router_entropy": entropy,
            "ffn_expert_usage": frac.detach(),
            "ffn_load_balance_loss": lb,
            "ffn_router_z_loss": router_z,
        }
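
# Note that this MoE trades compute for simplicity: every expert runs on
# every token and routing only affects the mixture weights, so inference
# cost does not drop when k_target shrinks. A dispatch-based implementation
# would be needed to realize compute savings from the budget_ratio knob.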


class MambaBlock(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig, enhanced: bool = False, layer_idx: int | None = None):
        super().__init__()
        if not _HAS_MAMBA:
            print(f"MambaBlock Layer {layer_idx} disabled: mamba-ssm not installed.")
            self.mamba = None
            return
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.mamba = Mamba(
            d_model=cfg.d_model,
            d_state=cfg.mamba_d_state,
            d_conv=cfg.mamba_d_conv,
            expand=cfg.mamba_expand * (2 if enhanced else 1),
            layer_idx=layer_idx,
        )
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_ff * (2 if enhanced else 1), bias=False),
            nn.GELU(approximate="tanh"),
            nn.Linear(cfg.d_ff * (2 if enhanced else 1), cfg.d_model, bias=False),
        )

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        if not _HAS_MAMBA or self.mamba is None:
            stats = {"head_entropy": torch.tensor(0.0, device=x.device),
                     "head_keep_frac": torch.tensor(1.0, device=x.device),
                     "mamba_out_l2": torch.tensor(0.0, device=x.device)}
            return x, stats, (None, None)
        h = self.ln1(x)
        x_m = self.mamba(h)
        m_out_l2 = x_m.float().pow(2).mean()
        x = x + x_m
        h2 = self.ln2(x)
        x = x + self.ffn(h2)
        stats = {
            "head_entropy": torch.tensor(0.0, device=x.device),
            "head_keep_frac": torch.tensor(1.0, device=x.device),
            "mamba_out_l2": m_out_l2.detach(),
        }
        return x, stats, (None, None)
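
# The mamba_state argument is accepted for interface compatibility but is not
# threaded into the Mamba call: this script reprocesses the full sequence on
# every decoding step, so no recurrent state is carried between forwards.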


class RoutedBlock(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MoEAttention(cfg)
        self.ffn = MoEFFN(cfg.d_model, cfg.d_ff, cfg.experts_per_layer, cfg.top_k_ffn, cfg.moe_dropout, cfg)

    def _attn_forward(self, h: torch.Tensor, position_offset: int, budget_ratio: float):
        if isinstance(position_offset, torch.Tensor):
            position_offset = int(position_offset.view(-1)[0].item())
        else:
            position_offset = int(position_offset)
        return self.attn(h, position_offset, budget_ratio, self.cfg.gate_temperature)

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        h = self.ln1(x)
        attn_out, attn_stats = self._attn_forward(h, position_offset, budget_ratio)
        x = x + attn_out
        h2 = self.ln2(x)
        ffn_out, moe_stats = self.ffn(h2, budget_ratio=budget_ratio)
        x = x + ffn_out
        stats = {**attn_stats, **moe_stats}
        return x, stats, (None, None)


class AdaptiveRiverLM(nn.Module):
    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList()
        mamba_layer_counter = 0
        for i in range(cfg.n_layers):
            if i < 2:
                print(f"[model] Layer {i}: Mamba")
                self.blocks.append(MambaBlock(cfg, enhanced=False, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            elif i >= (cfg.n_layers - 2):
                print(f"[model] Layer {i}: Mamba (enhanced)")
                self.blocks.append(MambaBlock(cfg, enhanced=True, layer_idx=mamba_layer_counter))
                mamba_layer_counter += 1
            else:
                if i == 2:
                    print(f"[model] Layers {i}-{cfg.n_layers - 3}: MoE Attention + MoE FFN")
                self.blocks.append(RoutedBlock(cfg))
        self.ln_f = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.embed.weight
        # _init_weights already checks isinstance(m, nn.Linear) internally.
        self.apply(lambda m: _init_weights(m, cfg.init_std))

    def forward(
        self,
        input_ids: torch.Tensor,
        budget_ratio: Optional[float] = None,
        mamba_states: Optional[List] = None,
        past_kvs: Optional[List] = None,
        position_offset: int | torch.Tensor = 0,
        return_expert_stats: bool = False,
        use_cache: bool = False,
    ):
        x = self.embed(input_ids)
        b = float(self.cfg.default_budget_ratio if budget_ratio is None else budget_ratio)
        all_stats: Dict[str, List[torch.Tensor]] = {}
        for block in self.blocks:
            x, stats, _ = block(
                x,
                position_offset=position_offset,
                past_kv=None,
                budget_ratio=b,
                use_cache=False,
                mamba_state=None,
            )
            for k, v in stats.items():
                all_stats.setdefault(k, []).append(torch.as_tensor(v.detach() if isinstance(v, torch.Tensor) else v))
        agg_stats = {k: torch.stack(v).mean() for k, v in all_stats.items() if len(v) > 0}
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits, agg_stats
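
# Minimal smoke test (hypothetical sizes, assuming mamba-ssm is installed):
#
#   cfg = AdaptiveRiverConfig(vocab_size=128, d_model=64, n_layers=6,
#                             d_ff=128, attn_n_orig_heads=4)
#   model = AdaptiveRiverLM(cfg)
#   logits, stats = model(torch.randint(0, 128, (2, 16)))
#   assert logits.shape == (2, 16, 128)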


def estimate_1b_config() -> AdaptiveRiverConfig:
    return AdaptiveRiverConfig(
        vocab_size=50257,
        d_model=1024,
        n_layers=24,
        d_ff=4096,
        experts_per_layer=4,
        top_k_ffn=1,
        default_budget_ratio=1.0,
        attn_n_experts=6,
        attn_top_k=6,
        attn_n_orig_heads=16,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        gate_temperature=0.7,
        head_entropy_weight=1e-4,
        checkpoint_attn_thresh=0.35,
        checkpoint_ffn_thresh=0.35,
        load_balance_weight=0.01,
        router_z_weight=0.001,
        tie_word_embeddings=False,
    )
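
# Rough sizing sketch (back-of-envelope, not measured): each routed layer's
# stacked FFN experts hold 2 * 4 * 4096 * 1024 ≈ 33.6M parameters, so the
# 20 routed layers contribute ≈ 0.67B; untied embeddings plus LM head add
# ≈ 0.10B, with the attention experts and the four Mamba layers making up
# the rest of the ~1B the function name implies.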


class FastInferenceTester:
    def __init__(self, model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.eos_id = eos_id
        self.pad_id = pad_id

        self.model.eval()
        torch.set_grad_enabled(False)
        print("Using model's native precision")

        if _HAS_MAMBA:
            print("Skipping torch.compile due to mamba-ssm kernels.")
        elif hasattr(torch, "compile"):
            try:
                print("Compiling model with torch.compile...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled successfully")
            except Exception as e:
                print(f"Could not compile model: {e}")
                print("Running without compilation")

    def _format_to_training_chat(self, prompt: str) -> torch.Tensor:
        messages = [{"role": "user", "content": prompt}]
        formatted = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(
            formatted, add_special_tokens=False, return_tensors="pt"
        ).to(self.device)
        return input_ids

    def _postprocess_like_training(self, text: str) -> str:
        if "<|im_start|>assistant" in text:
            return text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        if "assistant\n" in text:
            return text.split("assistant\n")[-1].split("<|im_end|>")[0].strip()
        return text.split("<|im_end|>")[0].strip()

    def _reset_mamba_states(self):
        if not _HAS_MAMBA:
            return
        for block in self.model.blocks:
            if isinstance(block, MambaBlock) and hasattr(block, "mamba"):
                # Best-effort reset: clear whichever state attributes the
                # installed mamba-ssm version happens to expose.
                for attr in ("inference_params", "conv_state", "ssm_state"):
                    if hasattr(block.mamba, attr):
                        setattr(block.mamba, attr, None)

    def generate_once(
        self,
        prompt: str,
        max_tokens: int = 2000,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int = 0,
        budget_ratio: float = 1.0,
        show_tokens: bool = False,
        min_new_tokens: int = 3,
    ) -> Dict:
        self._reset_mamba_states()

        print(f"\n{'='*80}")
        print("FAST GENERATION (no cache)")
        print(f"{'='*80}")
        print(f"Prompt: {prompt}")
        print("─" * 80)

        input_ids = self._format_to_training_chat(prompt)

        generated_tokens: List[int] = []
        token_times: List[float] = []
        stop_ids = set(t for t in [self.im_end_id, self.eos_id] if t is not None)
        ban_initial_ids = set(t for t in [self.im_end_id, self.eos_id, self.im_start_id, self.pad_id] if t is not None)

        start_time = time.time()

        with torch.inference_mode():
            logits, _ = self.model(
                input_ids,
                budget_ratio=budget_ratio,
                position_offset=0,
                use_cache=False,
            )
            next_token_logits = logits[:, -1, :]

            print("Generating...", end=" ", flush=True)
            is_cuda = torch.cuda.is_available()
            buffer = []

            for _ in range(max_tokens):
                if is_cuda:
                    torch.cuda.synchronize()
                t0 = time.time()

                logits_for_sampling = next_token_logits.squeeze(0).clone() / max(1e-6, temperature)
                vocab_size = logits_for_sampling.size(0)

                # Keep the model from ending the reply immediately.
                if len(generated_tokens) < min_new_tokens and min_new_tokens > 0:
                    for tid in ban_initial_ids:
                        if 0 <= tid < vocab_size:
                            logits_for_sampling[tid] = float("-inf")

                # Top-k filtering.
                if top_k and top_k > 0:
                    kth = torch.topk(logits_for_sampling, top_k)[0][-1]
                    logits_for_sampling[logits_for_sampling < kth] = float("-inf")

                # Nucleus (top-p) filtering.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits_for_sampling, descending=True)
                    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = False
                    remove_idx = sorted_indices[sorted_indices_to_remove]
                    logits_for_sampling[remove_idx] = float("-inf")

                probs = F.softmax(logits_for_sampling, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()

                generated_tokens.append(next_token_id)

                if show_tokens:
                    tok_text = self.tokenizer.decode([next_token_id], skip_special_tokens=False)
                    buffer.append(tok_text)
                    if len(buffer) >= 16:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()

                if (next_token_id in stop_ids) and (len(generated_tokens) >= max(1, min_new_tokens)):
                    if buffer:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()
                    if show_tokens:
                        print(" [EOT]", flush=True)
                    break

                # Stateless decoding: append the new token and re-run the full
                # forward pass (no KV cache), so each step costs a whole
                # prefill over the growing sequence.
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_token_id]], device=self.device)],
                    dim=1,
                )
                logits, _ = self.model(
                    input_ids,
                    budget_ratio=budget_ratio,
                    position_offset=0,
                    use_cache=False,
                )
                next_token_logits = logits[:, -1, :]

                if is_cuda:
                    torch.cuda.synchronize()
                token_times.append(time.time() - t0)

            if buffer:
                print("".join(buffer), end="", flush=True)
                buffer.clear()

        total_time = time.time() - start_time
        text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        text = self._postprocess_like_training(text)

        if show_tokens and (not generated_tokens or (generated_tokens[-1] not in stop_ids)):
            print()

        num_gen = len(generated_tokens)
        if num_gen == 0:
            print("\nNo tokens generated.")
            return {'output': '', 'tokens_per_sec': 0, 'decode_tps': 0, 'total_time': total_time, 'num_tokens': 0}

        decode_time = sum(token_times)
        toks_per_sec = num_gen / total_time if total_time > 0 else 0
        decode_tps = num_gen / decode_time if decode_time > 0 else 0

        print("\n" + "─" * 80)
        print("STATISTICS")
        print("─" * 80)
        print(f"Tokens: {num_gen}")
        print(f"Total time: {total_time:.2f}s")
        print(f"Overall speed: {toks_per_sec:.1f} tok/s (includes prompt)")
        print(f"Decode speed: {decode_tps:.1f} tok/s (generation only)")
        print(f"Time/token: {(decode_time/num_gen)*1000:.1f}ms")
        print("─" * 80)
        print(f"Output: {text[:100]}{'...' if len(text) > 100 else ''}")
        print("=" * 80 + "\n")

        self._reset_mamba_states()

        return {
            'output': text,
            'tokens_per_sec': toks_per_sec,
            'decode_tps': decode_tps,
            'total_time': total_time,
            'num_tokens': num_gen,
        }

    def interactive_mode(self):
        print("\n" + "=" * 80)
        print("INTERACTIVE MODE (no cache, stateless)")
        print("Type 'quit' or your prompt")
        print("=" * 80 + "\n")
        while True:
            try:
                prompt = input("\nYou: ")
            except (EOFError, KeyboardInterrupt):
                print("\nBye.")
                break
            if prompt.lower() in ["quit", "exit", "q"]:
                break
            if not prompt.strip():
                continue
            print("\nAssistant: ", end="", flush=True)
            self.generate_once(prompt, max_tokens=2000, temperature=0.8, show_tokens=True)


def _cast_layernorm_fp32(module: nn.Module):
    for m in module.modules():
        if isinstance(m, nn.LayerNorm):
            m.float()
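
# Keeping LayerNorm in float32 while the rest of the model runs in bfloat16
# is a common mixed-precision pattern: the mean/variance reduction is
# sensitive to low precision. Note this cast must be applied AFTER any
# low-precision cast of the whole model, or .to(torch.bfloat16) will undo it.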


def load_model_and_tokenizer(model_dir: str):
    """
    Load AdaptiveRiverLM model and tokenizer from a folder layout like:

        model_dir/
            checkpoint.pt (or any .pt file)
            tokenizer/
                tokenizer.json
                special_tokens_map.json
                ...

    Automatically finds the .pt file if not explicitly named.
    """
    print(f"Searching for model checkpoint in: {model_dir}")
    ckpts = glob.glob(os.path.join(model_dir, "*.pt"))
    if not ckpts:
        raise FileNotFoundError(f"No .pt checkpoint found in {model_dir}")
    if len(ckpts) > 1:
        print(f"[Warning] Multiple .pt files found, using: {ckpts[0]}")
    checkpoint_path = ckpts[0]

    tokenizer_path = os.path.join(model_dir, "tokenizer")
    if not os.path.isdir(tokenizer_path):
        raise FileNotFoundError(f"Missing tokenizer directory: {tokenizer_path}")

    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        print("Tokenizer missing pad_token. Assigning eos_token as pad_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    print("Building model (AdaptiveRiverLM)...")
    cfg = estimate_1b_config()
    cfg.vocab_size = len(tokenizer)
    cfg.tie_word_embeddings = False

    model = AdaptiveRiverLM(cfg)

    print(f"Loading checkpoint: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location="cpu")
    # Some training scripts save {"model": state_dict, ...}; unwrap
    # defensively if that key is present (an assumption about the checkpoint
    # format, harmless when it does not apply).
    if isinstance(state, dict) and isinstance(state.get("model"), dict):
        state = state["model"]
    model_state_dict = model.state_dict()
    converted_state = {}

    # Keep only tensors whose name and shape match the freshly built model.
    for k, param in model_state_dict.items():
        if k in state and state[k].shape == param.shape:
            converted_state[k] = state[k]

    print("Loading weights...")
    load_result = model.load_state_dict(converted_state, strict=False)

    if load_result.missing_keys:
        print("\n--- Missing Keys ---")
        for k in load_result.missing_keys:
            print(" ", k)
    if load_result.unexpected_keys:
        print("\n--- Unexpected Keys ---")
        for k in load_result.unexpected_keys:
            print(" ", k)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    if device == "cuda" and torch.cuda.is_bf16_supported():
        # Cast to bf16 first, THEN pin LayerNorms back to fp32; the reverse
        # order would let .to(torch.bfloat16) overwrite the fp32 cast.
        model = model.to(torch.bfloat16)
        _cast_layernorm_fp32(model)
    else:
        model = model.to(torch.float32)

    model.eval()
    print(f"Model and tokenizer loaded successfully from {model_dir} on {device}")
    return model, tokenizer, device
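
# Usage sketch (path is illustrative):
#   model, tokenizer, device = load_model_and_tokenizer("./runs/river-1b")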


def main():
    parser = argparse.ArgumentParser(description="Stateless inference for AdaptiveRiverLM (no KV cache), proper EOT handling")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model folder (with checkpoint.pt and tokenizer/)")
    parser.add_argument("--prompt", type=str, default="Hello, my name is")
    parser.add_argument("--max_tokens", type=int, default=2000)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--min_new_tokens", type=int, default=3)
    parser.add_argument("--interactive", action="store_true", help="Interactive mode (stateless)")
    args = parser.parse_args()

    model, tokenizer, device = load_model_and_tokenizer(args.model_dir)

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    # FastInferenceTester derives its own stop/ban sets from these ids.
    tester = FastInferenceTester(model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id)

    if args.interactive:
        tester.interactive_mode()
    else:
        tester.generate_once(
            args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            show_tokens=True,
            min_new_tokens=args.min_new_tokens,
        )


if __name__ == "__main__":
    main()
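
# Example invocation (script name and paths are illustrative):
#   python river_infer.py --model_dir ./checkpoints/run1 --prompt "Hello" --max_tokens 128
#   python river_infer.py --model_dir ./checkpoints/run1 --interactive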