yagizdevre commited on
Commit
83c4388
·
1 Parent(s): be761d6
__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .configuration_minimamba import MiniMambaConfig
2
- from .modeling_minimamba import Mamba2
 
1
  from .configuration_minimamba import MiniMambaConfig
2
+ from .modeling_minimamba import MiniMamba
__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/__pycache__/__init__.cpython-312.pyc and b/__pycache__/__init__.cpython-312.pyc differ
 
__pycache__/causal_conv1d_compilable.cpython-312.pyc ADDED
Binary file (10.7 kB). View file
 
__pycache__/configuration_minimamba.cpython-312.pyc ADDED
Binary file (3.85 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (39 kB). View file
 
__pycache__/modeling_minimamba.cpython-312.pyc ADDED
Binary file (7.94 kB). View file
 
__pycache__/norms.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
__pycache__/ssm_compilable.cpython-312.pyc ADDED
Binary file (12.3 kB). View file
 
modeling_minimamba.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
@@ -6,118 +7,84 @@ from transformers import PreTrainedModel
6
  from transformers.modeling_outputs import CausalLMOutput
7
 
8
  from .configuration_minimamba import MiniMambaConfig
9
- from enum import Enum
10
- from dataclasses import dataclass, field
11
 
12
 
13
- from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
14
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
15
- from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
16
-
17
- from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update
18
- from .ssm_compilable import mamba_chunk_scan_combined
19
- from .norms import build_norm
20
-
21
-
22
- class InitStdFactor(Enum):
23
- DISABLED = "disabled" # Init std is divided by 1.0
24
- GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*num_layers)
25
- CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
26
- DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
27
-
28
- @dataclass
29
- class InitConfig:
30
- dt_max: float = 0.1
31
- dt_min: float = 0.001
32
-
33
- dt_init_floor: float = 1e-4
34
-
35
- A_init_min: float = 1
36
- A_init_max: float = 16
37
-
38
-
39
- DEFAULT_INIT_CONFIG = InitConfig()
40
 
 
 
 
 
 
 
 
 
41
 
42
- class MiniSTU(PreTrainedModel):
43
- config_class = MiniSTUConfig
 
 
 
44
 
45
- def __init__(self, config) -> None:
46
- super(MiniSTU, self).__init__(config)
47
- self.n_layers = config.n_layers
48
- self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
49
-
50
- if isinstance(config.torch_dtype, torch.dtype):
51
- torch_dtype = config.torch_dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  else:
53
- torch_dtype = getattr(torch, config.torch_dtype)
54
-
55
- device = torch.device(config.device)
56
-
57
- self.phi = get_spectral_filters(
58
- config.seq_len,
59
- config.num_eigh,
60
- config.use_hankel_L,
61
- device=device,
62
- dtype=torch_dtype,
63
- )
64
 
65
- self.use_approx = config.use_approx
66
- self.use_hankel_L = config.use_hankel_L
67
-
68
- self.tok_emb = nn.Embedding(
69
- config.vocab_size, config.n_embd, dtype=torch_dtype, device=device
70
- )
71
- self.dropout = nn.Dropout(config.dropout)
72
-
73
- self.layers = nn.ModuleList()
74
- for layer_idx in range(self.n_layers):
75
- if layer_idx % 2 == 0:
76
- self.layers.append(STULayer(config, self.phi, self.n))
77
- else:
78
- self.layers.append(
79
- AttentionLayer(config)
80
- if config.use_attn
81
- else STULayer(config, self.phi, self.n)
82
- )
83
-
84
- self.norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd)
85
-
86
- self.lm_head = nn.Linear(
87
- config.n_embd, config.vocab_size, bias=config.bias, dtype=torch_dtype, device=device
88
- )
89
- self.tok_emb.weight = self.lm_head.weight
90
-
91
- self.std = (config.n_embd) ** -0.5
92
  self.apply(self._init_weights)
93
- print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
94
 
95
  def forward(
96
  self,
97
- input_ids: torch.Tensor,
98
- labels: torch.Tensor = None,
99
  **kwargs
100
  ) -> CausalLMOutput:
101
- # Compute embeddings
102
- tok_emb = self.tok_emb(input_ids)
103
- x = self.dropout(tok_emb)
104
-
105
- # Pass through layers
106
- for layer in self.layers:
107
- x = layer(x)
108
-
109
- # Normalize and project to vocabulary
110
- x = self.norm(x)
111
- logits = self.lm_head(x)
112
 
113
  loss = None
114
  if labels is not None:
115
- # Shift so that tokens predict the next token
116
  shift_logits = logits[..., :-1, :].contiguous()
117
  shift_labels = labels[..., 1:].contiguous()
118
  loss_fct = nn.CrossEntropyLoss()
119
  loss = loss_fct(
120
- shift_logits.view(-1, shift_logits.size(-1)),
121
  shift_labels.view(-1)
122
  )
123
 
@@ -126,73 +93,7 @@ class MiniSTU(PreTrainedModel):
126
  logits=logits,
127
  )
128
 
129
- def _get_num_params(self):
130
- n_params = sum(p.numel() for p in self.parameters())
131
- if hasattr(self, "pos_emb") and self.pos_emb is not None:
132
- n_params -= self.pos_emb.weight.numel()
133
- if self.tok_emb.weight is not self.lm_head.weight:
134
- n_params -= self.tok_emb.weight.numel()
135
- return n_params
136
-
137
- def _init_weights(self, module):
138
- if isinstance(module, nn.Linear):
139
- if hasattr(module, "SCALE_INIT"):
140
- self.std *= (2 * self.n_layers) ** -0.5
141
- torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
142
- if module.bias is not None:
143
- torch.nn.init.zeros_(module.bias)
144
- elif isinstance(module, nn.Embedding):
145
- torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
146
- elif isinstance(module, STU):
147
- if self.use_approx:
148
- torch.nn.init.xavier_normal_(module.M_inputs)
149
- torch.nn.init.xavier_normal_(module.M_filters)
150
- else:
151
- torch.nn.init.xavier_normal_(module.M_phi_plus)
152
- if not self.use_hankel_L:
153
- torch.nn.init.xavier_normal_(module.M_phi_minus)
154
- elif isinstance(module, Attention):
155
- torch.nn.init.xavier_normal_(module.c_attn.weight)
156
- torch.nn.init.xavier_normal_(module.c_proj.weight)
157
- if module.c_attn.bias is not None:
158
- torch.nn.init.zeros_(module.c_attn.bias)
159
- if module.c_proj.bias is not None:
160
- torch.nn.init.zeros_(module.c_proj.bias)
161
- @staticmethod
162
- def top_k_top_p_filtering(
163
- logits: torch.Tensor,
164
- top_k: int = 50,
165
- top_p: float = 0.95,
166
- filter_value: float = float("-inf"),
167
- ):
168
- """
169
- Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.
170
- """
171
- # top_k
172
- if top_k > 0:
173
- top_k = min(top_k, logits.size(-1))
174
- # Remove all logits that are not in the top k
175
- indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
176
- logits[indices_to_remove] = filter_value
177
-
178
- # top_p (nucleus)
179
- if 0 < top_p < 1.0:
180
- sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
181
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
182
-
183
- # Remove tokens with cumulative probability above the threshold
184
- sorted_indices_to_remove = cumulative_probs > top_p
185
- # Shift the indices to the right to keep also the first token above the threshold
186
- sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
187
- sorted_indices_to_remove[:, 0] = False
188
-
189
- indices_to_remove = sorted_indices_to_remove.scatter(
190
- dim=1, index=sorted_indices, src=sorted_indices_to_remove
191
- )
192
- logits[indices_to_remove] = filter_value
193
-
194
- return logits
195
-
196
  def generate(
197
  self,
198
  input_ids: torch.LongTensor,
@@ -205,27 +106,9 @@ class MiniSTU(PreTrainedModel):
205
  **kwargs
206
  ):
207
  """
208
- Naive token-by-token generation loop that uses top-k/top-p filtering and optional temperature.
209
-
210
- Args:
211
- input_ids (torch.LongTensor): shape (batch_size, sequence_length).
212
- max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
213
- temperature (float): sampling temperature (>=0).
214
- top_k (int): Top-K sampling cutoff.
215
- top_p (float): Nucleus sampling cutoff.
216
- eos_token_id (int): If set, stop generation when this token is produced.
217
- pad_token_id (int): If set, can be used to pad sequences. (Not fully used here.)
218
- kwargs: Unused arguments (like num_beams) for compatibility.
219
-
220
- Returns:
221
- torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
222
  """
223
- device = input_ids.device
224
- print("1=====================")
225
- print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
226
- print("1=====================")
227
-
228
- # We'll accumulate new tokens into generated_ids
229
  generated_ids = input_ids.clone()
230
 
231
  for _ in range(max_new_tokens):
@@ -233,857 +116,80 @@ class MiniSTU(PreTrainedModel):
233
  outputs = self.forward(generated_ids)
234
  logits = outputs.logits[:, -1, :] # shape: (batch_size, vocab_size)
235
 
236
- # Scale logits by temperature
237
  if temperature != 1.0:
238
  logits = logits / temperature
239
 
240
- # Filter logits using top-k and/or top-p
241
  logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
242
 
243
- # Convert to probabilities
244
- probabilities = F.softmax(logits, dim=-1)
245
-
246
- # Sample from the distribution
247
- next_token = torch.multinomial(probabilities, num_samples=1) # (batch_size, 1)
248
 
249
- # Append next token
250
  generated_ids = torch.cat([generated_ids, next_token], dim=1)
251
 
252
- # If eos_token_id is set and any sample produced it, we optionally could break early
253
- if eos_token_id is not None:
254
- # Check if all sequences in the batch ended
255
- # or if you want to do a more fine-grained approach
256
- if (next_token == eos_token_id).all():
257
- break
258
- print("2=====================")
259
- print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
260
- print("2=====================")
261
- return generated_ids
262
-
263
-
264
- @dataclass
265
- class BaseMambaConfig:
266
- """
267
- Configuration for the Mamba family of models.
268
- """
269
- dim: int = 512
270
- num_layers: int = 8
271
- num_heads: int = 8
272
-
273
- state_dim: int = 128
274
- num_groups: int = 1
275
- conv_size: int | None = 4
276
-
277
- bias: bool = False # Linear bias
278
- conv_bias: bool = True # Convolutional bias
279
- dt_bias: bool = False
280
- D_has_head_dim: bool = False
281
- learnable_init_states: bool = False
282
-
283
- ffn_dim_multiplier: float = 2.0
284
- multiple_of: int = 256 # Enforce that MLP hidden layer size is multiple of a large power of 2
285
-
286
- norm_eps: float = 1e-6
287
- norm_type: str = "rmsnorm"
288
-
289
- # CUDA-related items
290
- ssm_chunk_size: int = 256
291
- use_mem_eff_path: bool = False
292
-
293
- # Initialization-related items
294
- init_use_depth: bool = False
295
- init_base_std: float | None = None
296
- init_std_factor: str = "disabled" # e.g. "global_depth"
297
- init_config: InitConfig = field(default_factory=InitConfig)
298
-
299
-
300
- class SSM(nn.Module):
301
- """
302
- State Space Model (SSM) implementation with selective state updates and convolution.
303
-
304
- Implements the core SSM computation with support for both training and inference modes.
305
- During inference, uses cached states for efficient token-by-token generation.
306
- """
307
- def __init__(self, config: BaseMambaConfig) -> None:
308
- """Initialize SSM parameters and layers.
309
- Args:
310
- config: Configuration containing model hyperparameters
311
- """
312
- super().__init__()
313
- self.config = config
314
- vars(self).update(vars(config))
315
-
316
- assert self.dim > 0, "Model dimension (config.dim) must be positive"
317
- assert self.num_heads > 0, "Number of heads (config.num_heads) must be positive"
318
- assert self.state_dim > 0, "State dimension (config.state_dim) must be positive"
319
-
320
- if self.ffn_dim_multiplier is None:
321
- raise ValueError(
322
- "ffn_dim_multiplier must be set to a valid float (e.g. 2.0) "
323
- "to determine hidden_dim in SSM."
324
- )
325
- assert self.ffn_dim_multiplier > 0, "ffn_dim_multiplier must be > 0"
326
-
327
- self.hidden_dim = int(self.ffn_dim_multiplier * self.dim)
328
- self.hidden_dim = config.multiple_of * ( # Round up to multiple_of
329
- (self.hidden_dim + self.multiple_of - 1) // self.multiple_of
330
- )
331
-
332
- assert self.hidden_dim % self.num_heads == 0, (
333
- f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}."
334
- )
335
-
336
- self.head_dim = self.hidden_dim // self.num_heads
337
-
338
- self.dt_limit_kwargs = {}
339
- dt_limit = (self.init_config.dt_min, self.init_config.dt_max)
340
- if dt_limit != (0.0, float("inf")):
341
- self.dt_limit_kwargs = dict(dt_limit=dt_limit)
342
-
343
- # Order: [z, x, B, C, dt]
344
- d_input = (
345
- 2 * self.hidden_dim
346
- + 2 * self.num_groups * self.state_dim
347
- + self.num_heads
348
- )
349
-
350
- self.input = nn.Linear(self.dim, d_input, bias=self.bias)
351
-
352
- # Only create Conv1d if self.conv_size is specified
353
- if self.conv_size is not None:
354
- conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim
355
-
356
- # Depthwise-ish conv (groups = out_channels)
357
- # TODO: Check that this is used if causal_conv1d_fn and causal_conv1d_update cannot be imported
358
- self.conv1d = nn.Conv1d(
359
- in_channels=conv_dim,
360
- out_channels=conv_dim,
361
- kernel_size=self.conv_size,
362
- groups=conv_dim,
363
- bias=self.conv_bias, # <- This is a boolean in your config, so pass that or True/False
364
- padding=self.conv_size - 1 # for "causal" style
365
- )
366
-
367
- if config.dt_bias:
368
- self.dt_bias = nn.Parameter(torch.empty(self.num_heads))
369
- else:
370
- self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False)
371
-
372
- self.A_log = nn.Parameter(torch.empty(self.num_heads))
373
-
374
- if config.D_has_head_dim:
375
- self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim))
376
- else:
377
- self.D = nn.Parameter(torch.ones(self.num_heads))
378
-
379
- if self.learnable_init_states:
380
- self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim))
381
-
382
- # Can also just use nn.RMSNorm
383
- self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps)
384
-
385
- self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias)
386
-
387
- def _causal_conv(
388
- self,
389
- zxbcdt: torch.Tensor,
390
- tok_idx: torch.Tensor | None = None,
391
- cu_seqlens: torch.Tensor | None = None,
392
- ssm_impl: str = "ssm"
393
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
394
- # TODO: Make slightly less verbose
395
- """Processes input through causal convolution path, handling both full sequence and incremental cases.
396
-
397
- This function implements two processing modes:
398
- 1. Full sequence ("ssm"): Used during training and initial prompt processing.
399
- 2. Incremental ("ssm_update"): Used during token-by-token generation.
400
 
401
- Args:
402
- zxbcdt: Input tensor containing concatenated [z, x, B, C, dt] components
403
- tok_idx: Token indices for sequence processing. Required for "ssm" mode.
404
- Defaults to None.
405
- cu_seqlens: Cumulative sequence lengths for variable length processing.
406
- Used only in "ssm" mode with caching. Defaults to None.
407
- ssm_impl: Implementation mode, either "ssm" for full sequence processing
408
- or "ssm_update" for incremental generation. Defaults to "ssm".
409
-
410
- Returns:
411
- tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
412
- Tuple containing separated components (z, x, B, C, dt), where:
413
- - z: Gating branch
414
- - x: Main branch
415
- - B, C: SSM state matrices (analogous to K, Q in attention)
416
- - dt: Time delta values
417
-
418
- Notes:
419
- - When using "ssm" mode during inference, a cache should be pre-initialized
420
- externally. This design allows for flexible caching strategies without
421
- modifying model code.
422
- - The "ssm_update" mode requires a cache to exist and will use it for
423
- incremental state updates during generation.
424
- - B, C components correspond to Key, Query in the SSM/attention duality.
425
- """
426
- # Split input into components
427
- z, xBC, dt = torch.split(
428
- zxbcdt,
429
- [
430
- self.hidden_dim,
431
- self.hidden_dim + 2 * self.num_groups * self.state_dim,
432
- self.num_heads,
433
- ],
434
- dim=-1,
435
- )
436
-
437
- if ssm_impl == "ssm":
438
- if hasattr(self, "cache"):
439
- conv_varlen_states = causal_conv1d_varlen_states(
440
- xBC.squeeze(0),
441
- cu_seqlens,
442
- state_len=self.cache.conv_cache.shape[-1],
443
- )
444
- self.cache.conv_cache.copy_(conv_varlen_states)
445
-
446
- xBC = causal_conv1d_fn(
447
- x=xBC.transpose(1, 2),
448
- weight=self.conv1d.weight.squeeze(1),
449
- bias=self.conv1d.bias,
450
- activation="silu",
451
- seq_idx=tok_idx,
452
- ).transpose(1, 2)
453
- elif ssm_impl == "ssm_update":
454
- xBC = causal_conv1d_update(
455
- x=xBC.squeeze(0),
456
- conv_state=self.cache.conv_cache,
457
- weight=self.conv1d.weight.squeeze(1),
458
- bias=self.conv1d.bias,
459
- activation="silu",
460
- ).unsqueeze(0)
461
- else:
462
- raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")
463
-
464
- # Split processed tensor into components
465
- x, B, C = torch.split(
466
- xBC,
467
- [
468
- self.hidden_dim,
469
- self.num_groups * self.state_dim,
470
- self.num_groups * self.state_dim,
471
- ],
472
- dim=-1,
473
- )
474
-
475
- return z, x, B, C, dt
476
-
477
- def _non_causal_conv(self, zxbcdt: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
478
- z, x, B, C, dt = torch.split(
479
- zxbcdt,
480
- [
481
- self.hidden_dim,
482
- self.hidden_dim,
483
- self.num_groups * self.state_dim,
484
- self.num_groups * self.state_dim,
485
- self.num_heads,
486
- ],
487
- dim=-1,
488
- )
489
- return z, x, B, C, dt
490
-
491
- def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states):
492
- """
493
- For training
494
-
495
- Returns:
496
- (bsz, seq_len, num_heads, head_dim)
497
- """
498
- y = mamba_chunk_scan_combined(
499
- x,
500
- dt,
501
- A,
502
- B,
503
- C,
504
- dt_bias=self.dt_bias,
505
- dt_softplus=True,
506
- chunk_size=self.ssm_chunk_size,
507
- D=self.D,
508
- z=None,
509
- seq_idx=tok_idx,
510
- cu_seqlens=cu_seqlens,
511
- initial_states=initial_states,
512
- **self.dt_limit_kwargs,
513
- )
514
-
515
- if hasattr(self, "cache"):
516
- y, varlen_states = y
517
- self.cache.state_cache.copy_(varlen_states)
518
 
519
- return y
520
-
521
- def _step(self, x, seq_len, dt, A, B, C):
 
 
 
 
522
  """
523
- For inference / generation.
524
  """
525
- x = x.squeeze(0)
526
- A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim)
527
- dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim)
528
- D = self.D
529
- if D is not None and D.dim() == 1:
530
- D = D.unsqueeze(1).expand(self.num_heads, self.head_dim)
531
- B, C = B.squeeze(0), C.squeeze(0)
532
- y = selective_state_update(
533
- self.cache.state_cache,
534
- x,
535
- dt,
536
- A,
537
- B,
538
- C,
539
- D,
540
- z=None,
541
- dt_bias=(
542
- torch.zeros(self.num_heads, self.head_dim).to(x)
543
- if self.dt_bias is None
544
- else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim)
545
- ),
546
- dt_softplus=True,
547
- ).unsqueeze(0)
548
-
549
- return y
550
-
551
- def forward(
552
- self,
553
- x: torch.Tensor,
554
- tok_idx: torch.Tensor | None = None,
555
- cu_seqlens: torch.Tensor | None = None,
556
- ssm_impl: str = "ssm",
557
- ) -> torch.Tensor:
558
- bsz, seq_len, _ = x.shape
559
-
560
- zxbcdt = self.input(x)
561
-
562
- A = -torch.exp(self.A_log.float())
563
- initial_states = (
564
- self.init_states.expand(bsz, -1, -1, -1)
565
- if self.learnable_init_states else None
566
- )
567
-
568
- # Causal conv path
569
- if self.conv_size is not None:
570
-
571
- # Memory-efficient Triton kernel path
572
- if self.use_mem_eff_path:
573
- out = mamba_split_conv1d_scan_combined(
574
- zxbcdt,
575
- self.conv1d.weight.squeeze(1),
576
- self.conv1d.bias,
577
- self.dt_bias,
578
- A,
579
- D=self.D,
580
- chunk_size=self.ssm_chunk_size,
581
- seq_idx=tok_idx,
582
- activation="silu",
583
- rmsnorm_weight=self.norm.weight,
584
- rmsnorm_eps=self.norm.eps,
585
- outproj_weight=self.output.weight,
586
- outproj_bias=self.output.bias,
587
- headdim=self.head_dim,
588
- ngroups=self.num_groups,
589
- norm_before_gate=False, # Post-norm, y = self.norm(y * F.silu(z))
590
- initial_states=initial_states,
591
- **self.dt_limit_kwargs,
592
- )
593
- return out
594
- else:
595
- # CUDA kernel path
596
- z, x, B, C, dt = self._causal_conv(zxbcdt)
597
- else:
598
- # Non-causal conv path
599
- z, x, B, C, dt = self._non_causal_conv(zxbcdt)
600
-
601
- x = x.view(bsz, seq_len, self.num_heads, self.head_dim)
602
- B = B.view(bsz, seq_len, self.num_groups, self.state_dim)
603
- C = C.view(bsz, seq_len, self.num_groups, self.state_dim)
604
-
605
- # Chunked SSM scan
606
- if ssm_impl == "ssm":
607
- # (bsz, seq_len, num_heads, head_dim)
608
- y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states)
609
- elif ssm_impl == "ssm_update":
610
- y = self._step(x, seq_len, dt, A, B, C)
611
- else:
612
- raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")
613
-
614
- y = y.view(bsz, seq_len, self.hidden_dim)
615
-
616
- # Could be different activation function, including None.
617
- # Mamba people post_norm here also (sometimes norm(z)*y or norm(z*y))
618
- # y = self.norm(y) * F.silu(z)
619
- y = self.norm(y * F.silu(z))
620
- out = self.output(y)
621
-
622
- return out
623
-
624
- @torch.inference_mode()
625
- def reset_parameters(self, init_std, factor) -> None:
626
- config = self.config
627
- init_config = config.init_config
628
- if init_config is None:
629
- init_config = DEFAULT_INIT_CONFIG
630
-
631
- # Linear layers
632
- in_init_std = init_std or (self.dim ** (-0.5))
633
- out_init_std = init_std or (self.hidden_dim ** (-0.5))
634
- out_init_std = out_init_std / factor
635
-
636
- nn.init.trunc_normal_(
637
- self.input.weight,
638
- mean=0.0,
639
- std=in_init_std,
640
- a=-3 * in_init_std,
641
- b=3 * in_init_std,
642
- )
643
-
644
- nn.init.trunc_normal_(
645
- self.output.weight,
646
- mean=0.0,
647
- std=out_init_std,
648
- a=-3 * out_init_std,
649
- b=3 * out_init_std,
650
- )
651
-
652
- # SSM
653
- if self.dt_bias is not None and self.dt_bias.requires_grad:
654
- log_dt_min = math.log(init_config.dt_min)
655
- log_dt_max = math.log(init_config.dt_max)
656
-
657
- # Sample log_dt ~ Uniform[log_dt_min, log_dt_max]
658
- log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min
659
- dt = torch.exp(log_dt)
660
- dt = torch.clamp(dt, min=init_config.dt_init_floor)
661
-
662
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
663
- inv_dt = dt + torch.log(-torch.expm1(-dt))
664
- self.dt_bias.copy_(inv_dt)
665
-
666
- elif self.dt_bias is not None:
667
- # If dt_bias is not trainable, we can just keep it zero or set to any constant
668
- self.dt_bias.fill_(0.0)
669
-
670
- # Convolution
671
- if self.conv_size is not None:
672
- conv_std = init_std or (self.conv_size ** (-0.5))
673
- nn.init.trunc_normal_(
674
- self.conv1d.weight,
675
- mean=0.0,
676
- std=conv_std,
677
- a=-3 * conv_std,
678
- b=3 * conv_std,
679
- )
680
- if self.conv1d.bias is not None:
681
- nn.init.zeros_(self.conv1d.bias)
682
-
683
- # Learnable init states
684
- if self.learnable_init_states:
685
- self.init_states.zero_()
686
-
687
- # Initialize A_log ~ log( Uniform(A_init_min, A_init_max) )
688
- self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max)
689
- self.A_log.log_()
690
-
691
- if self.D is not None:
692
- self.D.data.fill_(1.0)
693
-
694
- # Reset norm parameters
695
- self.norm.reset_parameters()
696
-
697
-
698
- class MambaBlock(nn.Module):
699
- def __init__(self, config: BaseMambaConfig):
700
- super().__init__()
701
- self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps)
702
- self.ssm = SSM(config)
703
-
704
- def forward(
705
- self,
706
- x: torch.Tensor,
707
- tok_idx: torch.Tensor | None,
708
- cu_seqlens: torch.Tensor | None,
709
- ssm_impl: str = "ssm",
710
- ) -> torch.Tensor:
711
- x = x + self.ssm(self.norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
712
- return x
713
-
714
- @torch.inference_mode()
715
- def init_weights(self, init_std=None, factor=1.0):
716
- self.norm.reset_parameters()
717
- self.ssm.reset_parameters(init_std, factor)
718
-
719
-
720
- class BaseMamba(nn.Module):
721
- def __init__(self, config: BaseMambaConfig):
722
- super().__init__()
723
- self.model_dim = config.dim
724
- self.init_base_std = config.init_base_std
725
-
726
- self.init_config = config.init_config
727
- self.init_std_factor = InitStdFactor(config.init_std_factor)
728
-
729
- self.layers = nn.ModuleList()
730
- for _ in range(config.num_layers):
731
- self.layers.append(MambaBlock(config))
732
-
733
- def forward(
734
- self,
735
- h: torch.Tensor,
736
- tok_idx: torch.Tensor | None,
737
- cu_seqlens: torch.Tensor | None,
738
- ssm_impl: str = "ssm",
739
- ) -> torch.Tensor:
740
- for layer in self.layers:
741
- h = layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
742
- return h
743
-
744
- @torch.inference_mode()
745
- def reset_parameters(self):
746
- pass
747
-
748
- @torch.inference_mode()
749
- def init_weights(self):
750
- self.reset_parameters()
751
- for depth, layer in enumerate(self.layers):
752
- factor = {
753
- InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
754
- InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
755
- InitStdFactor.DIM_RATIO: self.model_dim / 4096,
756
- InitStdFactor.DISABLED: 1.0,
757
- }[self.init_std_factor]
758
-
759
- layer.init_weights(self.init_base_std, factor)
760
-
761
-
762
- @dataclass
763
- class Mamba2Config(BaseMambaConfig):
764
- seed: int = 1337
765
-
766
- vocab_size: int = -1 # Will error if unchanged, makes you double check!
767
- weight_tying: bool = False
768
- torch_dtype: torch.dtype = torch.bfloat16
769
-
770
- loss_reduction: str = "mean"
771
-
772
- use_attn: bool = False
773
- softcap: float = 50.0
774
-
775
-
776
- class Mamba2(BaseMamba):
777
- def __init__(self, config: Mamba2Config) -> None:
778
- super().__init__(config)
779
- if isinstance(config.torch_dtype, torch.dtype):
780
- torch_dtype = config.torch_dtype
781
- else:
782
- torch_dtype = getattr(torch, config.torch_dtype)
783
- self.weight_tying = config.weight_tying
784
- self.loss_reduction = config.loss_reduction
785
-
786
- assert config.vocab_size > 0, "vocab_size must be set and > 0"
787
-
788
- self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim)
789
-
790
- self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
791
-
792
- self.output = nn.Linear(
793
- config.dim,
794
- config.vocab_size,
795
- bias=False,
796
- )
797
-
798
- if config.weight_tying:
799
- self.output.weight = self.tok_emb.weight
800
 
801
- print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
 
 
 
802
 
803
- def _get_num_params(self):
804
- n_params = sum(p.numel() for p in self.parameters())
805
- if hasattr(self, "pos_emb") and self.pos_emb is not None:
806
- n_params -= self.pos_emb.weight.numel()
807
- if self.tok_emb.weight is not self.output.weight:
808
- n_params -= self.tok_emb.weight.numel()
809
- return n_params
810
 
811
- def forward(
812
- self,
813
- input_ids: torch.Tensor,
814
- target: torch.Tensor | None = None,
815
- tok_idx: torch.Tensor | None = None,
816
- cu_seqlens: torch.Tensor | None = None,
817
- ssm_impl: str = "ssm",
818
- labels: torch.Tensor = None,
819
- **kwargs
820
- ) -> CausalLMOutput:
821
- h = self.tok_emb(input_ids)
822
- h = super().forward(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
823
- logits = self.output(self.norm(h))
824
- loss = None
825
- if labels is not None:
826
- # By default, huggingface GPT-like models shift the logits by one
827
- shift_logits = logits[..., :-1, :].contiguous()
828
- shift_labels = labels[..., 1:].contiguous()
829
- loss_fct = nn.CrossEntropyLoss()
830
- loss = loss_fct(
831
- shift_logits.view(-1, shift_logits.size(-1)),
832
- shift_labels.view(-1)
833
- )
834
- return CausalLMOutput(
835
- loss=loss,
836
- logits=logits,
837
- )
838
 
839
- @torch.inference_mode()
840
- def reset_parameters(self, init_std=None):
841
- # Either use fixed base std or sqrt model dim
842
- super().reset_parameters()
843
- init_std = init_std or (self.model_dim ** (-0.5))
844
- self.norm.reset_parameters()
845
- nn.init.trunc_normal_(
846
- self.tok_emb.weight,
847
- mean=0.0,
848
- std=init_std,
849
- a=-3 * init_std,
850
- b=3 * init_std,
851
- )
852
- if not self.weight_tying:
853
- nn.init.trunc_normal_(
854
- self.output.weight,
855
- mean=0.0,
856
- std=init_std,
857
- a=-3 * init_std,
858
- b=3 * init_std,
859
  )
 
860
 
861
- @torch.inference_mode()
862
- def init_weights(self, buffer_device: torch.device = None):
863
- """
864
- Initialize model parameters and optionally compute buffers on a specific device.
865
-
866
- Args:
867
- buffer_device (torch.device, optional): If provided, any large or precomputed
868
- buffers (like RoPE frequency tensors) will be allocated or re-created on
869
- this device during initialization. This can avoid overhead from transferring
870
- buffers between CPU and GPU after creation. If None, buffers default to the
871
- device of the first parameter or CPU.
872
-
873
- Usage:
874
- - Pass a GPU device (e.g., ``torch.device('cuda')``) when you want to ensure
875
- buffers are created directly on GPU, preventing extra transfers.
876
- - Pass a CPU device (e.g., ``torch.device('cpu')``) if you want to keep
877
- large buffers in CPU memory (common in CPU-offload or pipeline-parallel setups).
878
- - Leave it as ``None`` to rely on the model’s existing parameter device or
879
- the default PyTorch device context.
880
 
881
- When / Why:
882
- - Useful in distributed or pipeline-parallel training where parameters may
883
- initially live on CPU, but you still need certain buffers on GPU to avoid
884
- overhead during forward passes.
885
- - Prevents large re-allocations or re-copies when big buffers (like RoPE
886
- frequency tables) are needed per rank.
887
  """
888
- super().init_weights()
889
-
890
- @classmethod
891
- def from_model_args(cls, config: Mamba2Config) -> "Mamba2":
892
  """
893
- Initialize a Mamba model from a MambaConfig object.
894
-
895
- Args:
896
- config (MambaConfig): Mamba configuration arguments.
897
-
898
- Returns:
899
- Mamba: Mamba-2 model.
900
- """
901
- return cls(config)
902
-
903
-
904
def get_mamba2_flops(
    seq_len: int,
    dim: int,
    num_layers: int,
    vocab_size: int,
    ffn_multiplier: float = 2.0,
    state_dim: int = 128,
    conv_size: int = 4,
    num_heads: int = 8,
    num_groups: int = 1,
    multiple_of: int = 256,
    include_input_embedding: bool = True,
    include_output_logits: bool = True,
    forward_backward_multiplier: float = 1.0,
) -> int:
    """
    Shape-based ("Chinchilla-like") FLOP estimate for a Mamba-2 style model.

    The default returns the forward-pass cost only; pass
    ``forward_backward_multiplier=3.0`` for the usual forward+backward
    rule-of-thumb.

    Counted per layer (times ``seq_len * num_layers``):
      1) input projection  [dim -> 2*hidden + 2*(groups*state_dim) + num_heads]
      2) depthwise conv1d  over conv_dim = hidden + 2*groups*state_dim
      3) selective scan    ~9 * state_dim * dim (from the Mamba dev discussion)
      4) output projection [hidden -> dim]
    Optionally adds the embedding matmul (seq_len x vocab x dim) and the final
    logits projection [dim -> vocab_size]. The hidden dimension is rounded up
    to a multiple of ``multiple_of`` as in the Mamba reference code.

    Args:
        seq_len: Number of tokens.
        dim: Model (embedding) dimension.
        num_layers: Number of Mamba layers.
        vocab_size: Vocabulary size for the final projection.
        ffn_multiplier: Expansion ratio, e.g. 2.0 => hidden ~ 2*dim.
        state_dim: SSM state dimension (commonly 128).
        conv_size: Depthwise conv1d kernel size (default 4).
        num_heads: Number of heads (adds to the input-projection width).
        num_groups: Grouped-state count for some Mamba variants (usually 1).
        multiple_of: Rounding granularity for the hidden dimension.
        include_input_embedding: Count the embedding matmul if True.
        include_output_logits: Count the final [dim -> vocab] matmul if True.
        forward_backward_multiplier: 1.0 forward only; 2.0/3.0 for fwd+bwd.

    Returns:
        int: Approximate total FLOPs (multiply-adds) for the selected pass(es).
    """
    # Round the hidden dimension up to a multiple of `multiple_of`.
    raw_hidden = int(ffn_multiplier * dim)
    hidden = multiple_of * ((raw_hidden + multiple_of - 1) // multiple_of)

    grouped_state = num_groups * state_dim

    # Per-token, per-layer cost: input proj + conv + scan + output proj.
    per_token_layer = (
        2 * dim * (2 * hidden + 2 * grouped_state + num_heads)  # input linear
        + 2 * (hidden + 2 * grouped_state) * conv_size          # depthwise conv
        + 9 * state_dim * dim                                   # selective scan
        + 2 * hidden * dim                                      # output linear
    )

    total = per_token_layer * num_layers * seq_len
    if include_input_embedding:
        total += 2 * seq_len * vocab_size * dim
    if include_output_logits:
        total += 2 * seq_len * dim * vocab_size

    return int(total * forward_backward_multiplier)
990
-
991
def get_mamba2_flops_per_token(**kwargs) -> float:
    """
    Estimate FLOPs per token for a Mamba-2 style model.

    Pulls the model-shape parameters out of ``kwargs`` (filling in defaults
    for the optional ones) and divides the total FLOP estimate by the
    sequence length.

    Args:
        **kwargs: Model configuration parameters. ``seq_len``, ``dim``,
            ``num_layers`` and ``vocab_size`` are mandatory.

    Returns:
        float: Approximate FLOPs per token.

    Raises:
        ValueError: If a mandatory parameter is missing.
    """
    for required in ("seq_len", "dim", "num_layers", "vocab_size"):
        if required not in kwargs:
            raise ValueError(f"Missing required parameter: {required}")

    # Optional knobs, overridden by anything the caller supplied.
    params = {
        "ffn_dim_multiplier": 2.0,
        "state_dim": 128,
        "conv_size": 4,
        "num_heads": 8,
        "num_groups": 1,
        "multiple_of": 256,
        "include_input_embedding": True,
        "include_output_logits": True,
        "forward_backward_multiplier": 1.0,
    }
    params.update(kwargs)

    total_flops = get_mamba2_flops(
        seq_len=params["seq_len"],
        dim=params["dim"],
        num_layers=params["num_layers"],
        vocab_size=params["vocab_size"],
        ffn_multiplier=params["ffn_dim_multiplier"],
        state_dim=params["state_dim"],
        conv_size=params["conv_size"],
        num_heads=params["num_heads"],
        num_groups=params["num_groups"],
        multiple_of=params["multiple_of"],
        include_input_embedding=params["include_input_embedding"],
        include_output_logits=params["include_output_logits"],
        forward_backward_multiplier=params["forward_backward_multiplier"],
    )
    return total_flops / params["seq_len"]
1042
-
1043
-
1044
def get_no_recompute_ops():
    """
    Optional policy for activation checkpointing: the set of ops whose
    results are always saved (never recomputed). Returning None would fall
    back to the default policy (distributed.py: default_no_recompute_ops).
    """
    matmul_ops = {
        torch.ops.aten.mm.default,
        torch.ops.aten._scaled_mm.default,
    }
    comm_and_scan_ops = {
        torch.ops.c10d_functional.reduce_scatter_tensor.default,
        torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,
    }
    # For low-precision training it's useful to always save max(abs(tensor)).
    amax_ops = {
        torch.ops.aten.abs.default,
        torch.ops.aten.max.default,
    }
    return matmul_ops | comm_and_scan_ops | amax_ops
1056
-
1057
-
1058
def main():
    """
    CUDA smoke test: run the reference mamba_ssm Mamba2 block and this
    repo's Mamba2 LM on random inputs and print summary statistics.
    """
    from mamba_ssm import Mamba2 as MambaRef

    x = torch.randn(2, 64, 192).cuda()

    # Reference block from the mamba_ssm package.
    ref = MambaRef(
        d_model=192,
        expand=2,
        d_conv=4,
        d_state=64,
        headdim=48,
    ).cuda()
    ref_out = ref(x)
    print("Mamba reference output: ", ref_out)
    print("Mean of MambaRef output: ", ref_out.mean().item())
    print("Stddev of MambaRef output: ", ref_out.std().item())

    # This repo's model takes token ids, so draw LongTensor indices.
    cfg = Mamba2Config(vocab_size=200064, use_mem_eff_path=True)
    model = Mamba2(config=cfg).cuda()
    token_ids = torch.randint(0, cfg.vocab_size, (2, 64), dtype=torch.long).cuda()

    out = model(token_ids)
    print("Mamba output: ", out)
    print("Mean of Mamba output: ", out.mean().item())
    print("Stddev of Mamba output: ", out.std().item())
1089
 
 
 
 
 
1
+ import math
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
7
  from transformers.modeling_outputs import CausalLMOutput
8
 
9
  from .configuration_minimamba import MiniMambaConfig
10
+ from .model import Mamba2, Mamba2Config
 
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ class MiniMamba(PreTrainedModel):
15
+ """
16
+ A Hugging Face–style wrapper around a Mamba2 model, providing:
17
+ • forward(...) returning a CausalLMOutput
18
+ • support for HF training loops
19
+ • a naive generate(...) method with top-k/top-p sampling
20
+ """
21
+ config_class = MiniMambaConfig # Tells HF which config class to use
22
 
23
def __init__(self, config: MiniMambaConfig) -> None:
    """
    Bridge a Mamba2 backbone into Hugging Face's PreTrainedModel interface.
    """
    super().__init__(config)

    # Translate the HF config into the backbone's native config.
    backbone_cfg = Mamba2Config(
        vocab_size=config.vocab_size,
        num_layers=config.n_layers,
        dim=config.n_embd,
        use_mem_eff_path=True,
        weight_tying=getattr(config, "weight_tying", False),
        torch_dtype=(
            getattr(torch, config.torch_dtype)
            if isinstance(config.torch_dtype, str)
            else config.torch_dtype
        ),
    )

    # Mamba2 already owns the token embedding and output head (and ties
    # their weights itself when config.weight_tying is True), so no
    # separate nn.Embedding / lm_head is created at this level.
    self.mamba = Mamba2(config=backbone_cfg)

    # Cache the requested device / dtype for convenience.
    self.device_ = torch.device(config.device)
    if isinstance(config.torch_dtype, str):
        self.dtype_ = getattr(torch, config.torch_dtype)
    else:
        self.dtype_ = config.torch_dtype

    # Parameter initialization (HF calls _init_weights in some flows).
    self.apply(self._init_weights)

    print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
65
 
66
  def forward(
67
  self,
68
+ input_ids: torch.LongTensor,
69
+ labels: torch.LongTensor = None,
70
  **kwargs
71
  ) -> CausalLMOutput:
72
+ """
73
+ Forward pass for causal language modeling.
74
+ Returns a CausalLMOutput that includes loss (if labels is provided) and logits.
75
+ """
76
+ # Mamba2's forward expects (x: torch.Tensor, target: torch.Tensor|None, ...)
77
+ # but we only need the logits from the simple call:
78
+ logits = self.mamba(input_ids) # shape: [batch, seq_len, vocab_size]
 
 
 
 
79
 
80
  loss = None
81
  if labels is not None:
82
+ # By default, huggingface GPT-like models shift the logits by one
83
  shift_logits = logits[..., :-1, :].contiguous()
84
  shift_labels = labels[..., 1:].contiguous()
85
  loss_fct = nn.CrossEntropyLoss()
86
  loss = loss_fct(
87
+ shift_logits.view(-1, shift_logits.size(-1)),
88
  shift_labels.view(-1)
89
  )
90
 
 
93
  logits=logits,
94
  )
95
 
96
+ @torch.no_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def generate(
98
  self,
99
  input_ids: torch.LongTensor,
 
106
  **kwargs
107
  ):
108
  """
109
+ A naive token-by-token generation loop (greedy + top-k/top-p + temperature).
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  """
111
+ # We'll accumulate new tokens in generated_ids
 
 
 
 
 
112
  generated_ids = input_ids.clone()
113
 
114
  for _ in range(max_new_tokens):
 
116
  outputs = self.forward(generated_ids)
117
  logits = outputs.logits[:, -1, :] # shape: (batch_size, vocab_size)
118
 
119
+ # Scale by temperature
120
  if temperature != 1.0:
121
  logits = logits / temperature
122
 
123
+ # Filter
124
  logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
125
 
126
+ # Sample next token
127
+ probs = F.softmax(logits, dim=-1)
128
+ next_token = torch.multinomial(probs, num_samples=1) # shape: (batch, 1)
 
 
129
 
130
+ # Append
131
  generated_ids = torch.cat([generated_ids, next_token], dim=1)
132
 
133
+ # If we have an EOS token, we can break early if all sequences have ended
134
+ if eos_token_id is not None and (next_token == eos_token_id).all():
135
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ return generated_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
@staticmethod
def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 50,
    top_p: float = 0.95,
    filter_value: float = float("-inf"),
):
    """
    Filter a batch of next-token logits with top-k and/or nucleus (top-p)
    filtering.

    Args:
        logits: ``[batch, vocab]`` scores for the next token.
        top_k: Keep only the k highest logits per row (0 disables).
        top_p: Keep the smallest prefix of the sorted distribution whose
            cumulative probability exceeds ``top_p`` (values outside (0, 1)
            disable nucleus filtering).
        filter_value: Value written into removed positions.

    Returns:
        A new tensor of the same shape with filtered positions set to
        ``filter_value``. The caller's tensor is left unmodified.
    """
    # Fix: work on a copy. `logits` may be a view into the model output
    # (generate passes `outputs.logits[:, -1, :]` unchanged when
    # temperature == 1.0), and masking it in place would silently corrupt
    # the caller's tensor.
    logits = logits.clone()

    # --- top-k ---
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Threshold each row at its k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
        logits[indices_to_remove] = filter_value

    # --- top-p (nucleus) ---
    if 0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens whose cumulative probability exceeds the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p

        # ...but shift right so the first token crossing it is kept.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = False

        # Scatter the mask back to original (unsorted) vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
def _init_weights(self, module):
    """
    Per-module initializer invoked via ``self.apply`` (and by HF in some
    flows). Delegates to Mamba2's own init routine for the backbone and
    applies a standard normal(0, 0.02) init to Linear/Embedding layers.
    """
    if isinstance(module, Mamba2):
        # Use Mamba2's internal initialization for the whole submodel.
        module.init_weights()
        return
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        return
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # Other module types are left to their PyTorch defaults.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ def _get_num_params(self):
194
+ # Count trainable params, subtract duplicates if tying weights, etc.
195
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)