alexandretl committed
Commit 3b164a1 · 1 Parent(s): 58b82e2

head tying | gated mlp | gate of Mamba3 inside module

Files changed (3)
  1. configuration_dragon.py +11 -1
  2. modeling_dragon.py +61 -36
  3. training_dragon.py +37 -2
configuration_dragon.py CHANGED
@@ -92,6 +92,11 @@ class DragonConfig(PretrainedConfig):
 
     def __init__(
         self,
+        tie_lm_head: bool = False,
+        mlp_type: str = "simple",
+        layer_norm_scaling: bool = False,
+        mamba_d_state: int = 128,
+        mamba_headdim: int = 64,
         mamba3_rope: bool = True,
         mamba3_remove_BC_bias: bool = False,
         mamba3_is_id_rms: bool = True,
@@ -192,6 +197,11 @@ class DragonConfig(PretrainedConfig):
         mlp_linking=False,
         **kwargs,
     ):
+        self.tie_lm_head = tie_lm_head
+        self.mlp_type = mlp_type
+        self.layer_norm_scaling = layer_norm_scaling
+        self.mamba_d_state = mamba_d_state
+        self.mamba_headdim = mamba_headdim
         self.mamba3_rope = mamba3_rope
         self.mamba3_remove_BC_bias = mamba3_remove_BC_bias
         self.mamba3_is_id_rms = mamba3_is_id_rms
@@ -309,7 +319,7 @@ class DragonConfig(PretrainedConfig):
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
+            tie_word_embeddings=tie_lm_head,
             **kwargs,
         )
         # TODO: better way to handle those?
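
Note (not part of the commit): the new configuration knobs above can be exercised as in the sketch below. The field names and their meaning come from this diff; the import path and the chosen values are illustrative assumptions only.

# Illustrative sketch, assuming the module is importable as configuration_dragon.
from configuration_dragon import DragonConfig

config = DragonConfig(
    tie_lm_head=True,          # share lm_head weights with the input embedding
    mlp_type="gated",          # "simple" keeps DragonMLP, "gated" uses flash_attn's GatedMlp
    layer_norm_scaling=True,   # enables the 1/sqrt(layer_idx + 1) "lns" buffer (ignored under uscaling/muP)
    mamba_d_state=128,         # was hard-coded to 128 before this commit
    mamba_headdim=64,          # was hard-coded to 64 before this commit
)
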
modeling_dragon.py CHANGED
@@ -19,6 +19,8 @@ from transformers.utils import ModelOutput, logging
 
 from fla.ops.nsa.parallel import parallel_nsa
 
+from flash_attn.modules.mlp import GatedMlp
+
 try:
     from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 except ImportError:
@@ -559,7 +561,7 @@ class DragonAttention(nn.Module):
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.hidden_size = config.hidden_size
-        self.head_dim = config.head_dim if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
+        self.head_dim = config.head_dim # if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
         self.qk_norm = config.qk_norm
         self.window_size = config.sliding_window_size
         self.reuse_kv = reuse_kv
@@ -706,7 +708,7 @@
             if not self.reuse_kv:
                 key_states = apply_rotary_emb(key_states, cos, sin)
         elif self.config.rope_type_local == "p-rope":
-            query_states = apply_p_rotary_emb(query_states, cos, sin)
+            query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
             if not self.reuse_kv:
                 key_states = apply_p_rotary_emb(key_states, cos, sin)
         else:
@@ -3519,10 +3521,10 @@ class DragonMamba3(nn.Module):
         )
 
         self.d_model = config.hidden_size
-        self.d_state = 128
+        self.d_state = config.mamba_d_state
         self.conv_init = None
         self.expand = 2
-        self.headdim = 64
+        self.headdim = config.mamba_headdim
         self.ngroups = config.mamba_ngroups
         self.activation = "swish"
         self.bias = False
@@ -3547,8 +3549,8 @@
         if config.mamba3_rope:
             self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
 
-        # Order: [x, B, C, dt]
-        d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
+        # Order: [z, x, B, C, dt]
+        d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
 
         if self.config.mamba3_is_A_dd:
             self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
@@ -3609,10 +3611,11 @@
         **kwargs
     ):
         # Apply in_proj
-        xBCdt = self.in_proj(hidden_states)
-        xBC, dd_dt = torch.split(
-            xBCdt,
+        zxBCdt = self.in_proj(hidden_states)
+        z, xBC, dd_dt = torch.split(
+            zxBCdt,
             [
+                self.d_inner,
                 self.d_inner + 2 * self.d_state * self.ngroups,
                 self.nheads,
             ],
@@ -3721,16 +3724,21 @@
         else:
             y = out
 
+        y = rearrange(y, "b l h p -> b l (h p)")
+        y = y*self.act(z)
+        y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads).to(x.dtype)
+
         return y, None, None
 
 class DragonMamba2(nn.Module):
     def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
         super().__init__()
+        self.config = config
         self.d_model = config.hidden_size
-        self.d_state = 128
+        self.d_state = config.mamba_d_state
         self.expand = 2
         self.d_inner = self.expand * self.d_model
-        self.headdim = 64
+        self.headdim = config.mamba_headdim
         self.ngroups = config.mamba_ngroups
         assert self.d_inner % self.headdim == 0
         self.nheads = self.d_inner // self.headdim
@@ -3740,16 +3748,17 @@
         d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
         self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
 
-        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
-        self.conv1d = nn.Conv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
-            bias=False,
-            kernel_size=4,
-            groups=conv_dim,
-            padding=4-1,
-        )
-        self.act = nn.SiLU()
+        if not self.config.mamba3_remove_conv:
+            conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
+            self.conv1d = nn.Conv1d(
+                in_channels=conv_dim,
+                out_channels=conv_dim,
+                bias=False,
+                kernel_size=4,
+                groups=conv_dim,
+                padding=4-1,
+            )
+            self.act = nn.SiLU()
 
         # Initialize log dt bias
         dt_min=0.001
@@ -3791,18 +3800,19 @@
         dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
 
         # 1D Convolution
-        if causal_conv1d_fn is None:
-            xBC = self.act(
-                self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
-            ) # (B, L, self.d_inner + 2 * ngroups * d_state)
-            xBC = xBC[:, :seqlen, :]
-        else:
-            xBC = causal_conv1d_fn(
-                x=xBC.transpose(1, 2),
-                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                bias=self.conv1d.bias,
-                activation="swish",
-            ).transpose(1, 2)
+        if not self.config.mamba3_remove_conv:
+            if causal_conv1d_fn is None:
+                xBC = self.act(
+                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
+                ) # (B, L, self.d_inner + 2 * ngroups * d_state)
+                xBC = xBC[:, :seqlen, :]
+            else:
+                xBC = causal_conv1d_fn(
+                    x=xBC.transpose(1, 2),
+                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                    bias=self.conv1d.bias,
+                    activation="swish",
+                ).transpose(1, 2)
 
         # Split into 3 main branches: X, B, C
         # These correspond to V, K, Q respectively in the SSM/attention duality
@@ -4193,7 +4203,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
             self.mixer = DragonMamba3(config, layer_idx=layer_idx)
             head_dim = self.mixer.headdim
            num_attention_heads = self.mixer.nheads
-            use_gate = config.gate_gdn
+            use_gate = False
         elif layer_type == '2':
             self.mixer = DragonMamba2(config, layer_idx=layer_idx)
             head_dim = self.mixer.headdim
@@ -4249,13 +4259,19 @@
         self.input_norm = DragonNorm(config, config.hidden_size)
         self.postmixer_norm = DragonNorm(config, config.hidden_size)
         if not config.moe:
-            self.mlp = DragonMLP(config)
+            if config.mlp_type == "simple":
+                self.mlp = DragonMLP(config)
+            elif config.mlp_type == "gated":
+                self.mlp = GatedMlp(in_features=config.hidden_size, hidden_features=config.intermediate_size, out_features=config.hidden_size, activation=F.silu, bias1=False, bias2=False)
         else:
             self.mlp = DragonMoE(config)
             global PREVIOUS_MLP
            PREVIOUS_MLP = self.mlp
 
-        self.register_buffer("lns", torch.tensor(1.0 if config.use_uscaling else 1. / math.sqrt(layer_idx + (2 if config.old_lns else 1))), persistent=False)
+        if config.use_uscaling or not config.layer_norm_scaling:
+            self.register_buffer("lns", torch.tensor(1.0), persistent=False)
+        else:
+            self.register_buffer("lns", torch.tensor(1. / math.sqrt(layer_idx + (2 if config.old_lns else 1))), persistent=False)
         self.register_buffer("sqrt_tau", torch.sqrt(torch.tensor(self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
         self.register_buffer("sqrt_one_minus_tau", torch.sqrt(torch.tensor(1.0 - self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
 
@@ -4575,6 +4591,8 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
         self.vocab_size = config.vocab_size
         self.lm_head = DragonLinear(config, config.hidden_size, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=1/math.sqrt(config.hidden_size))
         self.post_init()
+        if config.tie_lm_head:
+            self.lm_head.weight = self.model.embedding.weight
 
     def forward(
         self,
@@ -4654,6 +4672,13 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
             past_key_values=outputs.past_key_values if not just_loss else None,
             hidden_states=outputs.hidden_states if not just_loss else None,
         )
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
 DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
 __all__ = ["DragonModel", "DragonForCausalLM", "DragonPreTrainedModel"]
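
Note (not part of the commit): with mlp_type="gated" the block swaps DragonMLP for flash_attn's GatedMlp (silu activation, no biases). The sketch below shows, in plain PyTorch, the usual value/gate split that such a gated (SwiGLU-style) MLP computes; the exact internals of flash_attn's implementation are assumed here, not copied from this repository.

# Minimal sketch of a silu-gated MLP; only illustrative, class name is hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        # value and gate produced by a single fused projection
        self.fc1 = nn.Linear(d_model, 2 * d_hidden, bias=False)
        self.fc2 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(value * F.silu(gate))  # silu-gated hidden state
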
training_dragon.py CHANGED
@@ -18,6 +18,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import transformers
+from transformers import get_wsd_schedule
 
 from .configuration_dragon import DragonConfig
 from .modeling_dragon import DragonForCausalLM
@@ -59,6 +60,9 @@ class NanoArgs:
     mixer_gn: bool = True
     mlp_linking : bool = False
     final_norm: bool = True
+    layer_norm_scaling: bool = False # not read when using muP
+    mlp_type: str = "simple" # simple, gated
+    tie_lm_head: bool = False
 
     # MoE
     moe: bool = False
@@ -105,6 +109,8 @@ class NanoArgs:
     kda_num_v_heads: Optional[int] = None
     mamba_mimo_dim: Optional[int] = 2
     mamba_ngroups: Optional[int] = 1
+    mamba_d_state: int = 128
+    mamba_headdim: int = 64
     mamba3_rope: bool = True
     mamba3_remove_BC_bias: bool = False
     mamba3_is_id_rms: bool = True
@@ -125,6 +131,7 @@ class NanoArgs:
     adam_eps: float = 1e-8
     warmup_iters: int = 200
     warmdown_iters: int = 3000
+    warmdown_type: str = "linear" # linear, cosine
     grad_norm_clip: float = 1.0
     uscaling_mult_embed: float = 0
     uscaling_mult_scalar: float = 0
@@ -325,6 +332,15 @@
     args.device_batch_size = 1
     print("!!! Forcing device_batch_size to 1 for intra-document masking !!!")
 
+if args.mlp_type == "gated":
+    if args.use_uscaling:
+        print("problem: gated MLP with muP is not supported, because we use FA backend")
+        exit(0)
+
+    if args.moe:
+        print("problem: gated MLP with MoE is not supported, because we use FA backend")
+        exit(0)
+
 # set up DDP (distributed data parallel).
 assert torch.cuda.is_available()
 dist.init_process_group(
@@ -425,6 +441,11 @@ print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total}
 
 # load model.
 config_hf = DragonConfig(
+    tie_lm_head=args.tie_lm_head,
+    mlp_type=args.mlp_type,
+    layer_norm_scaling=args.layer_norm_scaling,
+    mamba_d_state=args.mamba_d_state,
+    mamba_headdim=args.mamba_headdim,
     mamba3_rope=args.mamba3_rope,
     mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
     mamba3_is_id_rms=args.mamba3_is_id_rms,
@@ -600,8 +621,22 @@ def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
     else:
         decay_ratio = (num_iterations - it) / warmdown_iters
         return decay_ratio
-sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
-schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
+if args.warmdown_type == "linear":
+    sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
+    schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
+elif args.warmdown_type == "cosine":
+    sched = get_wsd_schedule(
+        optimizers[0],
+        num_warmup_steps=args.warmup_iters,
+        num_decay_steps=args.warmdown_iters,
+        num_training_steps=args.total_iterations,
+        min_lr_ratio=0.,
+        warmup_type='linear',
+        decay_type='cosine',
+    )
+    schedulers = [sched]
+else:
+    raise ValueError(f"Unknown warmdown type: {args.warmdown_type}")
 
 # resume if necessary.
 start_iter = 0
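
Note (not part of the commit): warmdown_type="cosine" delegates to transformers' get_wsd_schedule. The sketch below only illustrates the assumed warmup-stable-decay learning-rate shape that branch requests (linear warmup, flat plateau, cosine warmdown to min_lr_ratio); it is not the scheduler the script actually uses.

# Minimal sketch of the assumed WSD learning-rate multiplier; function name is hypothetical.
import math

def wsd_lr_ratio(it: int, total: int, warmup: int, warmdown: int, min_ratio: float = 0.0) -> float:
    if it < warmup:                      # linear warmup
        return (it + 1) / warmup
    if it < total - warmdown:            # stable plateau
        return 1.0
    progress = (it - (total - warmdown)) / warmdown  # cosine warmdown
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))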