Alienanthony committed on
Commit
9cd89a6
·
verified ·
1 Parent(s): f470e7a

Upload of model inferencing and svg

Browse files
Files changed (3) hide show
  1. ROE_Build.svg +1010 -0
  2. ROE_EDU_BASE_Undercooked.pt +3 -0
  3. inference_tester.py +839 -0
ROE_Build.svg ADDED
ROE_EDU_BASE_Undercooked.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c766bfe668c1523101abb66a76405a7ffa5cbd243c080e2f080709d48cf1131f
3
+ size 4415289553
inference_tester.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Optional, Tuple
8
+ import glob
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from transformers import AutoTokenizer
13
+ import torch.utils.checkpoint as cp
14
+ import os
15
+
16
+ # ----------------------------------------------------------------------------
17
+ # mamba-ssm dependency
18
+ # ----------------------------------------------------------------------------
19
try:
    from mamba_ssm import Mamba
    from mamba_ssm.utils.generation import InferenceParams
except ImportError:
    # Optional dependency is absent: record that and install inert stand-ins
    # so the rest of this module can still be imported.
    _HAS_MAMBA = False
    InferenceParams = None
    print(
        "\n".join(
            (
                "=" * 80,
                "[WARNING] mamba-ssm not installed. Mamba layers will not function.",
                "Install with: pip install mamba-ssm",
                "=" * 80,
            )
        )
    )

    class Mamba(nn.Module):
        """Placeholder used when mamba-ssm is missing; returns its input unchanged."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            print("ERROR: Mamba placeholder. mamba-ssm not installed.")

        def forward(self, x, *args, **kwargs):
            print("ERROR: mamba-ssm not installed. Cannot run MambaBlock.")
            return x
else:
    _HAS_MAMBA = True
38
+
39
+ # ----------------------------------------------------------------------------
40
+ # Model
41
+ # ----------------------------------------------------------------------------
42
+
43
+ @dataclass
44
+ class AdaptiveRiverConfig:
45
+ vocab_size: int = 50257
46
+ d_model: int = 1024
47
+ n_layers: int = 24
48
+ d_ff: int = 4096
49
+ dropout: float = 0.0
50
+ rope_theta: float = 10000.0
51
+ rotary_pct: float = 1.0
52
+ layer_norm_eps: float = 1e-5
53
+ rope_scaling_type: str | None = None
54
+ rope_scaling_factor: float = 1.0
55
+ experts_per_layer: int = 4
56
+ top_k_ffn: int = 1
57
+ moe_dropout: float = 0.0
58
+ attn_n_experts: int = 6
59
+ attn_top_k: int = 6
60
+ attn_n_orig_heads: int = 16
61
+ mamba_d_state: int = 16
62
+ mamba_d_conv: int = 4
63
+ mamba_expand: int = 2
64
+ entropy_weight: float = 1e-4
65
+ head_entropy_weight: float = 1e-4
66
+ default_budget_ratio: float = 1.0
67
+ init_std: float = 0.02
68
+ tie_word_embeddings: bool = False # untied head (matches training)
69
+ load_balance_weight: float = 0.01
70
+ router_z_weight: float = 0.001
71
+ gate_temperature: float = 0.7
72
+ checkpoint_attn_thresh: float = 0.35
73
+ checkpoint_ffn_thresh: float = 0.35
74
+ soak_dtype: str = "fp32"
75
+
76
+ def _init_weights(module: nn.Module, std: float):
77
+ if isinstance(module, nn.Linear):
78
+ nn.init.normal_(module.weight, mean=0.0, std=std)
79
+ if module.bias is not None:
80
+ nn.init.zeros_(module.bias)
81
+
82
def topk_mask_ste(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k selection mask over the last dim with a straight-through gradient.

    The forward value is (numerically) a 0/1 mask marking the ``k`` largest
    scores; the backward pass sees the gradient of ``softmax(scores)``.
    Computation is done in float32. If ``k`` covers the whole last dimension
    an all-ones tensor is returned.
    """
    s = scores.float()
    if k >= s.size(-1):
        return torch.ones_like(s)
    _, top_idx = torch.topk(s, k=k, dim=-1)
    hard = torch.zeros_like(s).scatter_(-1, top_idx, 1.0)
    soft = F.softmax(s, dim=-1)
    # hard + (soft - stop_grad(soft)): value of `hard`, gradient of `soft`.
    return hard + soft - soft.detach()
91
+
92
+ class RotaryEmbedding(nn.Module):
93
+ def __init__(self, dim, base=10000.0, scaling_type: str | None = None, scaling_factor: float = 1.0):
94
+ super().__init__()
95
+ self.dim = dim
96
+ self.base = float(base)
97
+ self.scaling_type = scaling_type
98
+ self.scaling_factor = float(scaling_factor)
99
+ base = self._effective_base()
100
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+ self._cos_sin_cache = None
103
+ self._cos_sin_cache_device = None
104
+ self._cos_sin_cache_dtype = None
105
+ self._cos_sin_max_seq_len = -1
106
+ def _effective_base(self) -> float:
107
+ if not self.scaling_type or self.scaling_factor == 1.0:
108
+ return self.base
109
+ if self.scaling_type in ("ntk", "linear", "yarn"):
110
+ return self.base * self.scaling_factor
111
+ return self.base
112
+ def _get_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
113
+ if (seq_len > self._cos_sin_max_seq_len or self._cos_sin_cache is None
114
+ or self._cos_sin_cache_device != device or self._cos_sin_cache_dtype != dtype):
115
+ self._cos_sin_max_seq_len = max(seq_len, 2048)
116
+ t = torch.arange(self._cos_sin_max_seq_len, device=device, dtype=self.inv_freq.dtype)
117
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
118
+ emb = torch.cat((freqs, freqs), dim=-1)
119
+ cos = emb.cos().to(dtype)
120
+ sin = emb.sin().to(dtype)
121
+ self._cos_sin_cache = (cos, sin)
122
+ self._cos_sin_cache_device = device
123
+ self._cos_sin_cache_dtype = dtype
124
+ return self._cos_sin_cache
125
+ def forward(self, x, seq_len: int, offset: int | torch.Tensor = 0):
126
+ device, dtype = x.device, x.dtype
127
+ cos, sin = self._get_cos_sin_cache(seq_len + int(offset), device, dtype)
128
+ if isinstance(offset, torch.Tensor):
129
+ if offset.numel() > 1:
130
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype).float()
131
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
132
+ emb = torch.cat((freqs, freqs), dim=-1)
133
+ cos_val = emb.cos()[None, None, :, :].to(dtype)
134
+ sin_val = emb.sin()[None, None, :, :].to(dtype)
135
+ return cos_val, sin_val
136
+ else:
137
+ offset = int(offset.item())
138
+ cos = cos[offset:offset+seq_len].unsqueeze(0).unsqueeze(0)
139
+ sin = sin[offset:offset+seq_len].unsqueeze(0).unsqueeze(0)
140
+ return cos, sin
141
+
142
def apply_rotary(x, cos, sin):
    """Rotate adjacent channel pairs of ``x`` by the given cos/sin tables.

    For every pair ``(a, b)`` at even/odd positions of the last dimension the
    companion vector holds ``(-b, a)``; the result is ``x * cos + companion * sin``.

    NOTE(review): RotaryEmbedding builds cos/sin as ``cat(freqs, freqs)`` while
    this function pairs *adjacent* channels. Confirm this matches training
    before changing either side.
    """
    even, odd = x[..., ::2], x[..., 1::2]
    companion = torch.stack((-odd, even), dim=-1).flatten(-2)
    return x * cos + companion * sin
146
+
147
class PTLayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm``; parameters live under ``ln.*``."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, x):
        normed = self.ln(x)
        return normed
153
+
154
class GlobalSDPAHead(nn.Module):
    """A single causal attention head with optional rotary embeddings.

    Projects ``[B, T, d_model]`` to ``head_dim`` for q/k/v and returns the
    attention output of shape ``[B, T, head_dim]``.
    """

    def __init__(self, d_model, head_dim, dropout, rope_theta, rotary_pct, cfg):
        super().__init__()
        self.q_proj = nn.Linear(d_model, head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, head_dim, bias=False)
        self.rotary_dim = int(head_dim * rotary_pct)
        self.dropout_p = dropout
        # Rotary tables are only needed when some channels are rotated.
        self.rope = (
            RotaryEmbedding(
                self.rotary_dim,
                base=rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )
            if self.rotary_dim > 0
            else None
        )

    def forward(self, x, position_offset):
        """Run causal attention over ``x``; ``position_offset`` shifts the rotary positions.

        A tensor offset is reduced to its first element.
        """
        offset = (
            int(position_offset.view(-1)[0].item())
            if isinstance(position_offset, torch.Tensor)
            else int(position_offset)
        )
        _, T, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        rd = self.rotary_dim
        if rd > 0:
            cos, sin = self.rope(q, seq_len=T, offset=offset)
            cos, sin = cos.squeeze(1), sin.squeeze(1)
            # Rotate the first `rd` channels, pass the remainder through.
            q = torch.cat([apply_rotary(q[..., :rd], cos, sin), q[..., rd:]], dim=-1)
            k = torch.cat([apply_rotary(k[..., :rd], cos, sin), k[..., rd:]], dim=-1)
        # Insert a singleton head axis for scaled_dot_product_attention.
        q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
        drop = self.dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop)
        return attended.squeeze(1)
187
+
188
class AttentionMoERouter(nn.Module):
    """Sequence-level router that picks attention experts for each sample.

    The gate is applied to the mean of ``x`` over dim 1, so every position in
    a sequence shares the same expert selection.
    """

    def __init__(self, d_model, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate_proj = nn.Linear(d_model, num_experts, bias=False)
        nn.init.normal_(self.gate_proj.weight, mean=0.0, std=0.01)

    def forward(self, x, budget_ratio, temperature):
        """Return ``(mask, weights, idx, entropy, logits)``.

        mask:    bool [B, num_experts], True for selected experts.
        weights: softmax over the selected logits, cast to ``x.dtype``.
        idx:     indices of the selected experts.
        entropy: mean entropy of the full gate distribution (no gradient).
        logits:  temperature-scaled gate logits clamped to [-10, 10].
        """
        pooled = x.mean(dim=1)
        temp = max(1e-6, float(temperature))
        logits = (self.gate_proj(pooled) / temp).clamp(min=-10.0, max=10.0)

        # The budget scales the expert count between 25% and 100% of top_k,
        # never below one and never above the number of experts.
        wanted = int(round(self.top_k * (0.25 + 0.75 * budget_ratio)))
        k = min(max(1, wanted), logits.size(-1))

        top_vals, top_idx = torch.topk(logits, k, dim=-1)
        weights = F.softmax(top_vals.to(torch.float32), dim=-1).to(x.dtype)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(1, top_idx, True)

        with torch.no_grad():
            dist = F.softmax(logits, dim=-1)
            entropy = -(dist * dist.clamp_min(1e-12).log()).sum(dim=-1).mean()
        return mask, weights, top_idx, entropy, logits
209
+
210
class MoEAttention(nn.Module):
    """Causal attention whose heads ("experts") are gated per sequence.

    All ``attn_n_experts`` heads are computed; the router's weights then scale
    each head's output (unselected heads get weight zero) before the output
    projection.
    """

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.d_model = cfg.d_model
        self.n_experts = cfg.attn_n_experts
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.attn_n_orig_heads
        self.rotary_dim = int(self.head_dim * cfg.rotary_pct)
        inner = self.n_experts * self.head_dim
        self.router = AttentionMoERouter(cfg.d_model, cfg.attn_n_experts, cfg.attn_top_k)
        self.q_proj = nn.Linear(cfg.d_model, inner, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, inner, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, inner, bias=False)
        self.rope = None
        if self.rotary_dim > 0:
            self.rope = RotaryEmbedding(
                self.rotary_dim,
                base=cfg.rope_theta,
                scaling_type=cfg.rope_scaling_type,
                scaling_factor=cfg.rope_scaling_factor,
            )
        self.o_proj = nn.Linear(cfg.attn_n_experts * self.head_dim, cfg.d_model, bias=False)

    def forward(self, x, position_offset, budget_ratio, temperature):
        """Return ``(y, stats)`` where ``y`` is [B, T, d_model].

        ``stats`` holds the router entropy, the fraction of selected heads and
        two diagnostic router losses, all computed without gradient.
        """
        B, T, _ = x.shape
        E, H = self.n_experts, self.head_dim
        sel_mask, gate_w, gate_idx, entropy, gate_logits = self.router(x, budget_ratio, temperature)

        def heads(proj):
            # [B, T, E*H] -> [B, E, T, H]
            return proj(x).view(B, T, E, H).permute(0, 2, 1, 3)

        q, k, v = heads(self.q_proj), heads(self.k_proj), heads(self.v_proj)

        if self.rope is not None:
            offset = (
                int(position_offset.view(-1)[0].item())
                if isinstance(position_offset, torch.Tensor)
                else int(position_offset)
            )
            rd = self.rotary_dim
            cos, sin = self.rope(q, seq_len=T, offset=offset)
            cos, sin = cos.squeeze(1), sin.squeeze(1)
            q = torch.cat([apply_rotary(q[..., :rd], cos, sin), q[..., rd:]], dim=-1)
            k = torch.cat([apply_rotary(k[..., :rd], cos, sin), k[..., rd:]], dim=-1)

        # Fold experts into the batch axis so one SDPA call covers all heads.
        drop = self.cfg.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q.reshape(B * E, T, H),
            k.reshape(B * E, T, H),
            v.reshape(B * E, T, H),
            is_causal=True,
            dropout_p=drop,
        )
        out = attended.view(B, E, T, H).permute(0, 2, 1, 3)

        # Dense [B, E] gate matrix: router weights at the selected experts,
        # zero everywhere else.
        gate = torch.zeros(B, E, device=x.device, dtype=out.dtype)
        gate.scatter_(1, gate_idx, gate_w.to(out.dtype))
        gated = torch.einsum('b t e h, b e -> b t e h', out, gate)
        y = self.o_proj(gated.reshape(B, T, E * H).to(self.o_proj.weight.dtype))

        with torch.no_grad():
            picked = sel_mask.float()
            usage = picked.mean(dim=0)
            expected = picked.sum(dim=-1).mean()
            usage_norm = usage / torch.clamp(expected, min=1e-6)
            uniform = 1.0 / self.n_experts
            attn_lb = ((usage_norm - uniform) ** 2).sum() * self.n_experts / self.n_experts
            attn_rz = (gate_logits ** 2).mean()
            head_keep = picked.mean()

        stats = {
            "head_entropy": entropy,
            "head_keep_frac": head_keep,
            "attn_load_balance_loss": attn_lb,
            "attn_router_z_loss": attn_rz,
        }
        return y, stats
274
+
275
class ExpertFFN(nn.Module):
    """Bias-free two-layer MLP: Linear -> tanh-GELU -> dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.w1(x), approximate="tanh")
        # Dropout is only active in training mode.
        hidden = F.dropout(hidden, p=self.dropout_p, training=self.training)
        return self.w2(hidden)
287
+
288
class MoEFFN(nn.Module):
    """Token-level mixture-of-experts FFN with stacked expert weights.

    Every expert is evaluated for every token (dense einsum); the top-k gate
    then zeroes the contribution of unselected experts. The ``dropout``
    argument is accepted but not used by this class.
    """

    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int, dropout: float, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.n_experts = n_experts
        self.base_top_k = top_k
        self.cfg = cfg
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.w1_stacked = nn.Parameter(torch.empty(n_experts, d_ff, d_model))
        self.w2_stacked = nn.Parameter(torch.empty(n_experts, d_model, d_ff))
        # Same N(0, init_std) initialisation for the router and both stacks.
        for weight in (self.router.weight, self.w1_stacked, self.w2_stacked):
            nn.init.normal_(weight, mean=0.0, std=cfg.init_std)

    def forward(self, x: torch.Tensor, budget_ratio: float):
        """Return ``(y, stats)`` with ``y`` shaped like ``x`` ([B, T, C])."""
        B, T, C = x.shape
        tokens = x.reshape(B * T, C)

        # The budget scales the expert count between 50% and 100% of base_top_k.
        wanted = int(round(self.base_top_k * (0.5 + budget_ratio / 2.0)))
        k = min(max(1, wanted), self.n_experts)

        scores = self.router(tokens).to(torch.float32).clamp(min=-10.0, max=10.0)
        probs = F.softmax(scores, dim=-1).to(tokens.dtype)
        mask = topk_mask_ste(scores, k=k).to(tokens.dtype)

        # Renormalise the selected experts' probabilities so they sum to one.
        gate = mask * probs
        gate = gate / gate.sum(dim=-1, keepdim=True).clamp_min(1e-6)

        hidden = torch.einsum('n c, e d c -> n e d', tokens, self.w1_stacked)
        hidden = F.gelu(hidden, approximate="tanh")
        per_expert = torch.einsum('n e d, e c d -> n e c', hidden, self.w2_stacked)
        y = torch.einsum('n e, n e c -> n c', gate, per_expert).view(B, T, C).to(x.dtype)

        with torch.no_grad():
            entropy = (-probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
            router_z = (scores ** 2).mean().clamp(max=10.0)
            frac = mask.mean(dim=0)
            uniform = 1.0 / self.n_experts
            lb = ((frac - uniform) ** 2).sum() * self.n_experts / self.n_experts

        stats = {
            "router_entropy": entropy,
            "ffn_expert_usage": frac.detach(),
            "ffn_load_balance_loss": lb,
            "ffn_router_z_loss": router_z,
        }
        return y, stats
328
+
329
class MambaBlock(nn.Module):
    """Pre-norm residual block: Mamba mixer followed by a dense GELU FFN.

    With ``enhanced=True`` both the Mamba expansion factor and the FFN width
    are doubled. When mamba-ssm is unavailable the block owns no sub-modules
    and its forward pass returns the input unchanged.
    """

    def __init__(self, cfg: AdaptiveRiverConfig, enhanced: bool = False, layer_idx: int | None = None):
        super().__init__()
        if not _HAS_MAMBA:
            print(f"MambaBlock Layer {layer_idx} disabled: mamba-ssm not installed.")
            self.mamba = None
            return

        widen = 2 if enhanced else 1
        ff_width = cfg.d_ff * widen

        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.mamba = Mamba(
            d_model=cfg.d_model,
            d_state=cfg.mamba_d_state,
            d_conv=cfg.mamba_d_conv,
            expand=cfg.mamba_expand * widen,
            layer_idx=layer_idx,
        )
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.d_model, ff_width, bias=False),
            nn.GELU(approximate="tanh"),
            nn.Linear(ff_width, cfg.d_model, bias=False),
        )

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        """Return ``(x, stats, (None, None))``.

        Only ``x`` is used; the remaining arguments exist so the signature
        matches RoutedBlock.forward.
        """
        dev = x.device
        if self.mamba is None or not _HAS_MAMBA:
            # Disabled block: identity with neutral statistics.
            passthrough_stats = {
                "head_entropy": torch.tensor(0.0, device=dev),
                "head_keep_frac": torch.tensor(1.0, device=dev),
                "mamba_out_l2": torch.tensor(0.0, device=dev),
            }
            return x, passthrough_stats, (None, None)

        mixed = self.mamba(self.ln1(x))  # stateless path
        mixed_l2 = mixed.float().pow(2).mean()
        x = x + mixed
        x = x + self.ffn(self.ln2(x))

        stats = {
            "head_entropy": torch.tensor(0.0, device=x.device),
            "head_keep_frac": torch.tensor(1.0, device=x.device),
            "mamba_out_l2": mixed_l2.detach(),
        }
        return x, stats, (None, None)
378
+
379
class RoutedBlock(nn.Module):
    """Pre-norm residual block: MoE attention followed by a MoE FFN."""

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.ln1 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ln2 = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MoEAttention(cfg)
        self.ffn = MoEFFN(cfg.d_model, cfg.d_ff, cfg.experts_per_layer, cfg.top_k_ffn, cfg.moe_dropout, cfg)

    def _attn_forward(self, h: torch.Tensor, position_offset: int, budget_ratio: float):
        """Call the attention module with a plain-int offset and the configured gate temperature."""
        offset = (
            int(position_offset.view(-1)[0].item())
            if isinstance(position_offset, torch.Tensor)
            else int(position_offset)
        )
        return self.attn(h, offset, budget_ratio, self.cfg.gate_temperature)

    def forward(
        self,
        x,
        attn_mask=None,
        position_offset: int | torch.Tensor = 0,
        past_kv=None,
        budget_ratio: float = 1.0,
        use_cache: bool = False,
        mamba_state: Optional[InferenceParams] = None,
    ):
        """Return ``(x, stats, (None, None))``; ``stats`` merges attention and FFN statistics.

        ``attn_mask``, ``past_kv``, ``use_cache`` and ``mamba_state`` are not used.
        """
        attn_out, attn_stats = self._attn_forward(self.ln1(x), position_offset, budget_ratio)
        x = x + attn_out
        ffn_out, moe_stats = self.ffn(self.ln2(x), budget_ratio=budget_ratio)
        x = x + ffn_out
        merged = dict(attn_stats)
        merged.update(moe_stats)
        return x, merged, (None, None)
411
+
412
class AdaptiveRiverLM(nn.Module):
    """Language model: Mamba blocks at both ends, routed MoE blocks in between.

    Layers 0-1 are plain Mamba blocks, the last two layers are "enhanced"
    Mamba blocks, and every other layer is a RoutedBlock.
    """

    def __init__(self, cfg: AdaptiveRiverConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList()

        tail_start = cfg.n_layers - 2
        mamba_count = 0
        for i in range(cfg.n_layers):
            leading = i < 2
            if leading or i >= tail_start:
                label = "Mamba" if leading else "Mamba (enhanced)"
                print(f"[model] Layer {i}: {label}")
                self.blocks.append(MambaBlock(cfg, enhanced=not leading, layer_idx=mamba_count))
                mamba_count += 1
                continue
            if i == 2:
                print(f"[model] Layers {i}-{cfg.n_layers-3}: MoE Attention + MoE FFN")
            self.blocks.append(RoutedBlock(cfg))

        self.ln_f = PTLayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.embed.weight

        def reinit(module):
            # Applied to every sub-module; only nn.Linear layers are touched.
            if isinstance(module, nn.Linear):
                _init_weights(module, cfg.init_std)

        self.apply(reinit)

    def forward(
        self,
        input_ids: torch.Tensor,
        budget_ratio: Optional[float] = None,
        mamba_states: Optional[List] = None,
        past_kvs: Optional[List] = None,
        position_offset: int | torch.Tensor = 0,
        return_expert_stats: bool = False,
        use_cache: bool = False,
    ):
        """Return ``(logits, stats)``.

        ``stats`` maps each statistic name to its mean over the blocks that
        reported it. ``mamba_states``, ``past_kvs``, ``return_expert_stats``
        and ``use_cache`` are accepted but not used: every call is a full
        stateless forward pass.
        """
        hidden = self.embed(input_ids)
        budget = float(self.cfg.default_budget_ratio if budget_ratio is None else budget_ratio)

        collected: Dict[str, List[torch.Tensor]] = {}
        for block in self.blocks:
            hidden, block_stats, _ = block(
                hidden,
                position_offset=position_offset,
                past_kv=None,
                budget_ratio=budget,
                use_cache=False,
                mamba_state=None,
            )
            for name, value in block_stats.items():
                if isinstance(value, torch.Tensor):
                    value = value.detach()
                collected.setdefault(name, []).append(torch.as_tensor(value))

        summary = {name: torch.stack(vals).mean() for name, vals in collected.items() if len(vals) > 0}
        logits = self.lm_head(self.ln_f(hidden))
        return logits, summary
463
+
464
def estimate_1b_config() -> AdaptiveRiverConfig:
    """Return the configuration this script builds the checkpointed model with."""
    settings = {
        # core dimensions
        "vocab_size": 50257,
        "d_model": 1024,
        "n_layers": 24,
        "d_ff": 4096,
        # FFN experts
        "experts_per_layer": 4,
        "top_k_ffn": 1,
        "default_budget_ratio": 1.0,
        # attention experts
        "attn_n_experts": 6,
        "attn_top_k": 6,
        "attn_n_orig_heads": 16,
        # Mamba
        "mamba_d_state": 16,
        "mamba_d_conv": 4,
        "mamba_expand": 2,
        # routing / loss knobs
        "gate_temperature": 0.7,
        "head_entropy_weight": 1e-4,
        "checkpoint_attn_thresh": 0.35,
        "checkpoint_ffn_thresh": 0.35,
        "load_balance_weight": 0.01,
        "router_z_weight": 0.001,
        "tie_word_embeddings": False,
    }
    return AdaptiveRiverConfig(**settings)
487
+
488
+ # ----------------------------------------------------------------------------
489
+ # Inference (stateless) with proper end-of-turn handling
490
+ # ----------------------------------------------------------------------------
491
+
492
class FastInferenceTester:
    """Stateless (no KV cache) text generation driver for AdaptiveRiverLM.

    Every decoding step re-runs the model on the full token sequence.
    Constructing an instance switches the model to eval mode and disables
    gradient tracking globally via ``torch.set_grad_enabled(False)``.

    Args:
        model: callable returning ``(logits, stats)`` for ``[1, T]`` token ids.
        tokenizer: tokenizer providing ``apply_chat_template``, ``encode`` and
            ``decode``.
        device: device the input ids are moved to.
        im_start_id, im_end_id, eos_id, pad_id: special token ids; any of them
            may be ``None`` and is then ignored.
    """

    def __init__(self, model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.eos_id = eos_id
        self.pad_id = pad_id

        self.model.eval()
        torch.set_grad_enabled(False)
        print("Using model's native precision")

        if hasattr(torch, 'compile') and _HAS_MAMBA:
            print("Skipping torch.compile due to mamba-ssm kernels.")
        else:
            # Compilation is best-effort; any failure falls back to eager mode.
            try:
                print("Compiling model with torch.compile...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled successfully")
            except Exception as e:
                print(f"Could not compile model: {e}")
                print("Running without compilation")

    def _format_to_training_chat(self, prompt: str) -> torch.Tensor:
        """Wrap ``prompt`` as a single user turn and return its token ids on ``self.device``."""
        messages = [{"role": "user", "content": prompt}]
        formatted = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(
            formatted, add_special_tokens=False, return_tensors="pt"
        ).to(self.device)
        return input_ids

    def _postprocess_like_training(self, text: str) -> str:
        """Return the assistant's reply: text after the last assistant marker, up to ``<|im_end|>``."""
        for marker in ("<|im_start|>assistant", "assistant\n"):
            if marker in text:
                text = text.split(marker)[-1]
                break
        return text.split("<|im_end|>")[0].strip()

    def _reset_mamba_states(self):
        """Set any cached inference attributes on the Mamba modules to ``None``."""
        if not _HAS_MAMBA:
            return
        for block in self.model.blocks:
            if isinstance(block, MambaBlock) and hasattr(block, "mamba"):
                for attr in ("inference_params", "conv_state", "ssm_state"):
                    if hasattr(block.mamba, attr):
                        setattr(block.mamba, attr, None)

    @staticmethod
    def _sample_next_token(logits, temperature, top_k=0, top_p=1.0, banned_ids=()):
        """Sample one token id from a 1-D logits vector.

        Args:
            logits: 1-D tensor of vocabulary logits; it is not modified.
            temperature: softmax temperature, floored at 1e-6.
            top_k: keep only the ``top_k`` highest logits (0 disables). Values
                larger than the vocabulary are clamped instead of raising.
            top_p: nucleus threshold; 1.0 disables.
            banned_ids: token ids that must not be sampled; out-of-range ids
                are ignored.

        Returns:
            The sampled token id as a Python int.
        """
        # Work in float32: the model may emit bf16 logits, and softmax/cumsum
        # over a large vocabulary are too coarse in bf16.
        scaled = logits.float() / max(1e-6, temperature)
        vocab_size = scaled.size(0)

        for tid in banned_ids:
            if tid is not None and 0 <= tid < vocab_size:
                scaled[tid] = float("-inf")

        if top_k and top_k > 0:
            kth = torch.topk(scaled, min(int(top_k), vocab_size)).values[-1]
            scaled[scaled < kth] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
            cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            remove = cumulative > top_p
            # Shift right so the first token that crosses top_p is kept, and
            # always keep the most likely token.
            remove[1:] = remove[:-1].clone()
            remove[0] = False
            scaled[sorted_indices[remove]] = float("-inf")

        probs = F.softmax(scaled, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

    def generate_once(
        self,
        prompt: str,
        max_tokens: int = 2000,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int = 0,
        budget_ratio: float = 1.0,
        show_tokens: bool = False,
        min_new_tokens: int = 3,
    ) -> Dict:
        """Generate a reply to ``prompt`` and print timing statistics.

        Structural tokens (end-of-turn, EOS, start-of-turn, pad) cannot be
        sampled until ``min_new_tokens`` tokens have been produced; after
        that, sampling an end-of-turn or EOS token stops generation.

        Returns:
            Dict with keys ``output``, ``tokens_per_sec``, ``decode_tps``,
            ``total_time`` and ``num_tokens``.
        """
        self._reset_mamba_states()

        print(f"\n{'='*80}")
        print("FAST GENERATION (no cache)")
        print(f"{'='*80}")
        print(f"Prompt: {prompt}")
        print("─" * 80)

        input_ids = self._format_to_training_chat(prompt)

        generated_tokens: List[int] = []
        token_times: List[float] = []
        stop_ids = {t for t in (self.im_end_id, self.eos_id) if t is not None}
        ban_initial_ids = {
            t for t in (self.im_end_id, self.eos_id, self.im_start_id, self.pad_id) if t is not None
        }

        start_time = time.time()

        with torch.inference_mode():
            # Prefill over the full prompt.
            logits, _ = self.model(
                input_ids,
                budget_ratio=budget_ratio,
                position_offset=0,
                use_cache=False
            )
            next_token_logits = logits[:, -1, :]  # [1, vocab]

            print("Generating...", end=" ", flush=True)
            is_cuda = torch.cuda.is_available()
            buffer = []  # small output buffer for streaming

            for _ in range(max_tokens):
                if is_cuda:
                    torch.cuda.synchronize()
                t0 = time.time()

                # Structural tokens are banned only at the very start.
                banned = ban_initial_ids if len(generated_tokens) < min_new_tokens else ()
                next_token_id = self._sample_next_token(
                    next_token_logits.squeeze(0), temperature, top_k, top_p, banned
                )
                generated_tokens.append(next_token_id)

                # Decode + buffered print.
                if show_tokens:
                    buffer.append(self.tokenizer.decode([next_token_id], skip_special_tokens=False))
                    if len(buffer) >= 16:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()

                # Stop on EOT/EOS once min_new_tokens have been produced.
                if (next_token_id in stop_ids) and (len(generated_tokens) >= max(1, min_new_tokens)):
                    if buffer:
                        print("".join(buffer), end="", flush=True)
                        buffer.clear()
                    if show_tokens:
                        print(" [EOT]", flush=True)
                    break

                # Stateless decode: append the token and re-run the forward pass.
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_token_id]], device=self.device)],
                    dim=1
                )
                logits, _ = self.model(
                    input_ids,
                    budget_ratio=budget_ratio,
                    position_offset=0,
                    use_cache=False
                )
                next_token_logits = logits[:, -1, :]

                if is_cuda:
                    torch.cuda.synchronize()
                token_times.append(time.time() - t0)

            # Flush any remaining buffered tokens.
            if buffer:
                print("".join(buffer), end="", flush=True)
                buffer.clear()

        total_time = time.time() - start_time
        text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        text = self._postprocess_like_training(text)

        if show_tokens and (not generated_tokens or (generated_tokens[-1] not in stop_ids)):
            print()

        num_gen = len(generated_tokens)
        if num_gen == 0:
            print("\nNo tokens generated.")
            return {'output': '', 'tokens_per_sec': 0, 'decode_tps': 0, 'total_time': total_time, 'num_tokens': 0}

        decode_time = sum(token_times)
        toks_per_sec = num_gen / total_time if total_time > 0 else 0
        decode_tps = num_gen / decode_time if decode_time > 0 else 0

        print("\n" + "─" * 80)
        print("STATISTICS")
        print("─" * 80)
        print(f"Tokens: {num_gen}")
        print(f"Total time: {total_time:.2f}s")
        print(f"Overall speed: {toks_per_sec:.1f} tok/s (includes prompt)")
        print(f"Decode speed: {decode_tps:.1f} tok/s (generation only)")
        print(f"Time/token: {(decode_time/num_gen)*1000:.1f}ms")
        print("─" * 80)
        print(f"Output: {text[:100]}{'...' if len(text) > 100 else ''}")
        print("=" * 80 + "\n")

        self._reset_mamba_states()

        return {
            'output': text,
            'tokens_per_sec': toks_per_sec,
            'decode_tps': decode_tps,
            'total_time': total_time,
            'num_tokens': num_gen,
        }

    def interactive_mode(self):
        """Read prompts from stdin until 'quit', 'exit', 'q', EOF or Ctrl-C."""
        print("\n" + "=" * 80)
        print("INTERACTIVE MODE (no cache, stateless)")
        print("Type 'quit' or your prompt")
        print("=" * 80 + "\n")
        while True:
            try:
                prompt = input("\nYou: ")
            except (EOFError, KeyboardInterrupt):
                print("\nBye.")
                break
            if prompt.lower() in ["quit", "exit", "q"]:
                break
            if not prompt.strip():
                continue
            print("\nAssistant: ", end="", flush=True)
            self.generate_once(prompt, max_tokens=2000, temperature=0.8, show_tokens=True)
718
+
719
+ def _cast_layernorm_fp32(module: nn.Module):
720
+ for m in module.modules():
721
+ if isinstance(m, nn.LayerNorm):
722
+ m.float()
723
+
724
def load_model_and_tokenizer(model_dir: str):
    """
    Load AdaptiveRiverLM model and tokenizer from a folder layout like:

        model_dir/
            checkpoint.pt (or any .pt file)
            tokenizer/
                tokenizer.json
                special_tokens_map.json
                ...

    Automatically finds the .pt file if not explicitly named. When several
    .pt files are present the alphabetically first one is used.

    Only checkpoint tensors whose name and shape match the freshly built model
    are loaded; everything else is reported on stdout and skipped.

    Returns:
        (model, tokenizer, device) where device is "cuda" or "cpu".

    Raises:
        FileNotFoundError: no .pt file or no tokenizer/ directory in model_dir.
        TypeError: the checkpoint is not a dict of tensors.
    """
    print(f"Searching for model checkpoint in: {model_dir}")
    # glob returns files in arbitrary order; sort so the choice is deterministic.
    ckpts = sorted(glob.glob(os.path.join(model_dir, "*.pt")))
    if not ckpts:
        raise FileNotFoundError(f"No .pt checkpoint found in {model_dir}")
    if len(ckpts) > 1:
        print(f"[Warning] Multiple .pt files found, using: {ckpts[0]}")
    checkpoint_path = ckpts[0]

    tokenizer_path = os.path.join(model_dir, "tokenizer")
    if not os.path.isdir(tokenizer_path):
        raise FileNotFoundError(f"Missing tokenizer directory: {tokenizer_path}")

    print(f"Loading tokenizer from: {tokenizer_path}")
    # NOTE(security): trust_remote_code=True lets the tokenizer directory run
    # arbitrary Python. Only point this at directories you trust.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        print("Tokenizer missing pad_token. Assigning eos_token as pad_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    print("Building model (AdaptiveRiverLM)...")
    cfg = estimate_1b_config()
    cfg.vocab_size = len(tokenizer)
    cfg.tie_word_embeddings = False

    model = AdaptiveRiverLM(cfg)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from sources you trust.
    state = torch.load(checkpoint_path, map_location="cpu")
    if not isinstance(state, dict):
        raise TypeError(
            f"Expected a state dict in {checkpoint_path}, got {type(state).__name__}"
        )

    model_state_dict = model.state_dict()
    converted_state = {}
    shape_mismatches = []
    for k, param in model_state_dict.items():
        if k not in state:
            continue
        ckpt_shape = getattr(state[k], "shape", None)
        if ckpt_shape == param.shape:
            converted_state[k] = state[k]
        else:
            shape_mismatches.append((k, ckpt_shape, param.shape))
    # Checkpoint entries the model has no parameter for. load_state_dict never
    # sees these because only model keys are passed to it.
    unused_keys = [k for k in state if k not in model_state_dict]

    print("Loading weights...")
    load_result = model.load_state_dict(converted_state, strict=False)

    if not converted_state:
        print("\n[Warning] No checkpoint tensor matched the model; all weights are randomly initialised.")
    if shape_mismatches:
        print("\n--- Shape Mismatches (skipped) ---")
        for k, ckpt_shape, model_shape in shape_mismatches:
            print(" ", k, "checkpoint", ckpt_shape, "model", tuple(model_shape))
    if load_result.missing_keys:
        print("\n--- Missing Keys ---")
        for k in load_result.missing_keys:
            print(" ", k)
    if unused_keys:
        print("\n--- Unexpected Keys ---")
        for k in unused_keys:
            print(" ", k)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    if device == "cuda" and torch.cuda.is_bf16_supported():
        # NOTE(review): the bfloat16 cast below converts the LayerNorm
        # parameters back to bf16, so this fp32 cast has no lasting effect.
        # The order is kept because PTLayerNorm does not upcast its input and
        # fp32 LayerNorm weights would then meet bf16 activations.
        _cast_layernorm_fp32(model)
        model = model.to(torch.bfloat16)
    else:
        model = model.to(torch.float32)

    model.eval()
    print(f"Model and tokenizer loaded successfully from {model_dir} on {device}")
    return model, tokenizer, device
796
+
797
+
798
def main():
    """Command-line entry point: load the model, then run one prompt or an interactive session."""
    parser = argparse.ArgumentParser(description="Stateless inference for AdaptiveRiverLM (no KV cache), proper EOT handling")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model folder (with checkpoint.pt and tokenizer/)")
    parser.add_argument("--prompt", type=str, default="Hello, my name is")
    parser.add_argument("--max_tokens", type=int, default=2000)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--min_new_tokens", type=int, default=3)
    parser.add_argument("--interactive", action="store_true", help="Interactive mode (stateless)")
    args = parser.parse_args()

    model, tokenizer, device = load_model_and_tokenizer(args.model_dir)

    def special_token_id(token: str):
        """Return the id of ``token``, or None when the tokenizer does not know it."""
        token_id = tokenizer.convert_tokens_to_ids(token)
        unk_id = getattr(tokenizer, "unk_token_id", None)
        unk_token = getattr(tokenizer, "unk_token", None)
        # An unknown token resolves to the unk id. Treating that id as an
        # end-of-turn marker would stop generation on any unknown token.
        if token_id is None or (token_id == unk_id and token != unk_token):
            print(f"[Warning] Tokenizer has no '{token}' token; it will be ignored.")
            return None
        return token_id

    # Resolve special token IDs for end-of-turn handling.
    im_end_id = special_token_id("<|im_end|>")
    im_start_id = special_token_id("<|im_start|>")
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    tester = FastInferenceTester(model, tokenizer, device, im_start_id, im_end_id, eos_id, pad_id)

    if args.interactive:
        tester.interactive_mode()
    else:
        tester.generate_once(
            args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            show_tokens=True,
            min_new_tokens=args.min_new_tokens,
        )
836
+
837
+ if __name__ == "__main__":
838
+ main()
839
+