Commit 3b164a1
Parent(s): 58b82e2

head tying | gated mlp | gate of Mamba3 inside module

- configuration_dragon.py +11 -1
- modeling_dragon.py +61 -36
- training_dragon.py +37 -2

configuration_dragon.py
CHANGED
@@ -92,6 +92,11 @@ class DragonConfig(PretrainedConfig):

    def __init__(
        self,
+        tie_lm_head: bool = False,
+        mlp_type: str = "simple",
+        layer_norm_scaling: bool = False,
+        mamba_d_state: int = 128,
+        mamba_headdim: int = 64,
        mamba3_rope: bool = True,
        mamba3_remove_BC_bias: bool = False,
        mamba3_is_id_rms: bool = True,

@@ -192,6 +197,11 @@ class DragonConfig(PretrainedConfig):
        mlp_linking=False,
        **kwargs,
    ):
+        self.tie_lm_head = tie_lm_head
+        self.mlp_type = mlp_type
+        self.layer_norm_scaling = layer_norm_scaling
+        self.mamba_d_state = mamba_d_state
+        self.mamba_headdim = mamba_headdim
        self.mamba3_rope = mamba3_rope
        self.mamba3_remove_BC_bias = mamba3_remove_BC_bias
        self.mamba3_is_id_rms = mamba3_is_id_rms

@@ -309,7 +319,7 @@ class DragonConfig(PretrainedConfig):
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
-            tie_word_embeddings=
+            tie_word_embeddings=tie_lm_head,
            **kwargs,
        )
        # TODO: better way to handle those?

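
Taken together, the config now exposes head tying, the MLP variant, layer-norm scaling, and the Mamba state/head sizes as first-class options. A minimal usage sketch (the argument names come from this diff; the values and the import path are illustrative only):

    from configuration_dragon import DragonConfig  # import path illustrative; the trainer uses a package-relative import

    config = DragonConfig(
        tie_lm_head=True,           # also forwarded as tie_word_embeddings to super().__init__
        mlp_type="gated",           # "simple" -> DragonMLP, "gated" -> flash_attn GatedMlp
        layer_norm_scaling=True,    # enables the per-layer lns buffer in DragonMonoBlock (ignored under uscaling/muP)
        mamba_d_state=128,          # SSM state size, now read from the config by the mixers
        mamba_headdim=64,           # SSM head dimension, now read from the config by the mixers
    )
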
modeling_dragon.py
CHANGED
@@ -19,6 +19,8 @@ from transformers.utils import ModelOutput, logging

from fla.ops.nsa.parallel import parallel_nsa

+from flash_attn.modules.mlp import GatedMlp
+
try:
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
except ImportError:

@@ -559,7 +561,7 @@ class DragonAttention(nn.Module):
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
-        self.head_dim = config.head_dim if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
+        self.head_dim = config.head_dim # if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
        self.qk_norm = config.qk_norm
        self.window_size = config.sliding_window_size
        self.reuse_kv = reuse_kv

@@ -706,7 +708,7 @@ class DragonAttention(nn.Module):
            if not self.reuse_kv:
                key_states = apply_rotary_emb(key_states, cos, sin)
        elif self.config.rope_type_local == "p-rope":
-            query_states = apply_p_rotary_emb(query_states, cos, sin)
+            query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
            if not self.reuse_kv:
                key_states = apply_p_rotary_emb(key_states, cos, sin)
        else:

@@ -3519,10 +3521,10 @@ class DragonMamba3(nn.Module):
        )

        self.d_model = config.hidden_size
-        self.d_state =
+        self.d_state = config.mamba_d_state
        self.conv_init = None
        self.expand = 2
-        self.headdim =
+        self.headdim = config.mamba_headdim
        self.ngroups = config.mamba_ngroups
        self.activation = "swish"
        self.bias = False

@@ -3547,8 +3549,8 @@ class DragonMamba3(nn.Module):
        if config.mamba3_rope:
            self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)

-        # Order: [x, B, C, dt]
-        d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
+        # Order: [z, x, B, C, dt]
+        d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups + self.nheads

        if self.config.mamba3_is_A_dd:
            self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)

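
Adding the gate branch z to the fused projection widens it by exactly d_inner. A worked example of the before/after width (the sizes below are illustrative, not taken from the repo):

    # Illustrative sizes; the real values come from DragonConfig.
    d_model, expand = 1024, 2
    d_inner = expand * d_model                    # 2048
    d_state, ngroups, headdim = 128, 1, 64
    nheads = d_inner // headdim                   # 32

    d_in_proj_old = d_inner + 2 * d_state * ngroups + nheads      # [x, B, C, dt] -> 2336
    d_in_proj_new = 2 * d_inner + 2 * d_state * ngroups + nheads  # [z, x, B, C, dt] -> 4384
    assert d_in_proj_new - d_in_proj_old == d_inner
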
@@ -3609,10 +3611,11 @@ class DragonMamba3(nn.Module):
        **kwargs
    ):
        # Apply in_proj
-
-        xBC, dd_dt = torch.split(
-
+        zxBCdt = self.in_proj(hidden_states)
+        z, xBC, dd_dt = torch.split(
+            zxBCdt,
            [
+                self.d_inner,
                self.d_inner + 2 * self.d_state * self.ngroups,
                self.nheads,
            ],

@@ -3721,16 +3724,21 @@ class DragonMamba3(nn.Module):
|
|
| 3721 |
else:
|
| 3722 |
y = out
|
| 3723 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3724 |
return y, None, None
|
| 3725 |
|
| 3726 |
class DragonMamba2(nn.Module):
|
| 3727 |
def __init__(self, config: DragonConfig, layer_idx: Optional[int]):
|
| 3728 |
super().__init__()
|
|
|
|
| 3729 |
self.d_model = config.hidden_size
|
| 3730 |
-
self.d_state =
|
| 3731 |
self.expand = 2
|
| 3732 |
self.d_inner = self.expand * self.d_model
|
| 3733 |
-
self.headdim =
|
| 3734 |
self.ngroups = config.mamba_ngroups
|
| 3735 |
assert self.d_inner % self.headdim == 0
|
| 3736 |
self.nheads = self.d_inner // self.headdim
|
|
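
This is the "gate of Mamba3 inside module" part of the commit: the z branch split off the fused in_proj now gates the SSM output head-wise before it leaves the mixer. A minimal self-contained sketch of just that step (random tensors; SiLU stands in for the "swish" activation used by the mixer; the real forward does far more):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    B, L, nheads, headdim = 2, 16, 4, 8
    d_inner = nheads * headdim

    out = torch.randn(B, L, nheads, headdim)    # SSM output, (b, l, h, p)
    z = torch.randn(B, L, d_inner)              # gate branch from the fused in_proj

    y = rearrange(out, "b l h p -> b l (h p)")  # flatten heads to line up with z
    y = y * F.silu(z)                           # in-module gate, as in the hunk above
    y = rearrange(y, "b l (h p) -> b l h p", h=nheads)
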
@@ -3740,16 +3748,17 @@ class DragonMamba2(nn.Module):
        d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)

-
-
-
-
-
-
-
-
-
-
+        if not self.config.mamba3_remove_conv:
+            conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
+            self.conv1d = nn.Conv1d(
+                in_channels=conv_dim,
+                out_channels=conv_dim,
+                bias=False,
+                kernel_size=4,
+                groups=conv_dim,
+                padding=4-1,
+            )
+            self.act = nn.SiLU()

        # Initialize log dt bias
        dt_min=0.001

@@ -3791,18 +3800,19 @@ class DragonMamba2(nn.Module):
        dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)

        # 1D Convolution
-        if
-
-
-
-
-
-
-
-
-
-
+        if not self.config.mamba3_remove_conv:
+            if causal_conv1d_fn is None:
+                xBC = self.act(
+                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
+                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
+                xBC = xBC[:, :seqlen, :]
+            else:
+                xBC = causal_conv1d_fn(
+                    x=xBC.transpose(1, 2),
+                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                    bias=self.conv1d.bias,
+                    activation="swish",
+                ).transpose(1, 2)

        # Split into 3 main branches: X, B, C
        # These correspond to V, K, Q respectively in the SSM/attention duality

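
In the pure-PyTorch fallback above, causality comes from the padding/slicing pair rather than from causal_conv1d: nn.Conv1d with padding=kernel_size-1 pads both ends, so keeping only the first seqlen positions leaves each output depending only on the current and previous kernel_size-1 inputs. A quick standalone check of that property (toy sizes, unrelated to the repo):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    d, L, K = 3, 10, 4
    conv = nn.Conv1d(d, d, kernel_size=K, groups=d, padding=K - 1, bias=False)

    x = torch.randn(1, d, L)
    y = conv(x)[..., :L]            # analogue of the xBC[:, :seqlen, :] slice

    x_pert = x.clone()
    x_pert[..., 5:] = 100.0         # perturb the "future" of position 4
    y_pert = conv(x_pert)[..., :L]
    assert torch.allclose(y[..., :5], y_pert[..., :5])  # earlier outputs unchanged -> causal
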
@@ -4193,7 +4203,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
            self.mixer = DragonMamba3(config, layer_idx=layer_idx)
            head_dim = self.mixer.headdim
            num_attention_heads = self.mixer.nheads
-            use_gate =
+            use_gate = False
        elif layer_type == '2':
            self.mixer = DragonMamba2(config, layer_idx=layer_idx)
            head_dim = self.mixer.headdim

@@ -4249,13 +4259,19 @@ class DragonMonoBlock(GradientCheckpointingLayer):
        self.input_norm = DragonNorm(config, config.hidden_size)
        self.postmixer_norm = DragonNorm(config, config.hidden_size)
        if not config.moe:
-
+            if config.mlp_type == "simple":
+                self.mlp = DragonMLP(config)
+            elif config.mlp_type == "gated":
+                self.mlp = GatedMlp(in_features=config.hidden_size, hidden_features=config.intermediate_size, out_features=config.hidden_size, activation=F.silu, bias1=False, bias2=False)
        else:
            self.mlp = DragonMoE(config)
            global PREVIOUS_MLP
            PREVIOUS_MLP = self.mlp

-
+        if config.use_uscaling or not config.layer_norm_scaling:
+            self.register_buffer("lns", torch.tensor(1.0), persistent=False)
+        else:
+            self.register_buffer("lns", torch.tensor(1. / math.sqrt(layer_idx + (2 if config.old_lns else 1))), persistent=False)
        self.register_buffer("sqrt_tau", torch.sqrt(torch.tensor(self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
        self.register_buffer("sqrt_one_minus_tau", torch.sqrt(torch.tensor(1.0 - self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)

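
For mlp_type == "gated" the block hands the MLP to flash_attn's GatedMlp. As a reference for what that option computes, here is a plain-PyTorch gated (SwiGLU-style) MLP with the same constructor arguments as the call above; it is a sketch of the standard formulation, not the flash_attn implementation itself, and the gate/up ordering is an assumption:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedMlpSketch(nn.Module):
        """y = fc2( silu(gate) * up ), gate and up produced by one fused projection."""
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=False)  # bias1=False
            self.fc2 = nn.Linear(hidden_features, out_features, bias=False)     # bias2=False

        def forward(self, x):
            gate, up = self.fc1(x).chunk(2, dim=-1)
            return self.fc2(F.silu(gate) * up)

    mlp = GatedMlpSketch(in_features=512, hidden_features=1376, out_features=512)  # illustrative sizes
    y = mlp(torch.randn(2, 16, 512))
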
@@ -4575,6 +4591,8 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
        self.vocab_size = config.vocab_size
        self.lm_head = DragonLinear(config, config.hidden_size, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=1/math.sqrt(config.hidden_size))
        self.post_init()
+        if config.tie_lm_head:
+            self.lm_head.weight = self.model.embedding.weight

    def forward(
        self,

@@ -4654,6 +4672,13 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
            past_key_values=outputs.past_key_values if not just_loss else None,
            hidden_states=outputs.hidden_states if not just_loss else None,
        )
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
DragonForCausalLM.register_for_auto_class("AutoModelForCausalLM")

__all__ = ["DragonModel", "DragonForCausalLM", "DragonPreTrainedModel"]

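
The "head tying" part of the commit makes lm_head reuse the embedding's weight tensor, and the new get_output_embeddings/set_output_embeddings accessors let the transformers machinery see that head. A minimal sketch of the tying pattern itself (plain nn modules, not the Dragon classes):

    import torch
    import torch.nn as nn

    vocab_size, hidden_size = 1000, 64
    embedding = nn.Embedding(vocab_size, hidden_size)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    lm_head.weight = embedding.weight                # one shared Parameter, not a copy
    assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()

    ids = torch.randint(0, vocab_size, (2, 8))
    logits = lm_head(embedding(ids))                 # gradients accumulate into the shared tensor
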
training_dragon.py
CHANGED
@@ -18,6 +18,7 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import transformers
+from transformers import get_wsd_schedule

from .configuration_dragon import DragonConfig
from .modeling_dragon import DragonForCausalLM

@@ -59,6 +60,9 @@ class NanoArgs:
    mixer_gn: bool = True
    mlp_linking : bool = False
    final_norm: bool = True
+    layer_norm_scaling: bool = False  # not read when using muP
+    mlp_type: str = "simple"  # simple, gated
+    tie_lm_head: bool = False

    # MoE
    moe: bool = False

@@ -105,6 +109,8 @@ class NanoArgs:
    kda_num_v_heads: Optional[int] = None
    mamba_mimo_dim: Optional[int] = 2
    mamba_ngroups: Optional[int] = 1
+    mamba_d_state: int = 128
+    mamba_headdim: int = 64
    mamba3_rope: bool = True
    mamba3_remove_BC_bias: bool = False
    mamba3_is_id_rms: bool = True

@@ -125,6 +131,7 @@ class NanoArgs:
    adam_eps: float = 1e-8
    warmup_iters: int = 200
    warmdown_iters: int = 3000
+    warmdown_type: str = "linear"  # linear, cosine
    grad_norm_clip: float = 1.0
    uscaling_mult_embed: float = 0
    uscaling_mult_scalar: float = 0

@@ -325,6 +332,15 @@
    args.device_batch_size = 1
    print("!!! Forcing device_batch_size to 1 for intra-document masking !!!")

+if args.mlp_type == "gated":
+    if args.use_uscaling:
+        print("problem: gated MLP with muP is not supported, because we use FA backend")
+        exit(0)
+
+    if args.moe:
+        print("problem: gated MLP with MoE is not supported, because we use FA backend")
+        exit(0)
+
# set up DDP (distributed data parallel).
assert torch.cuda.is_available()
dist.init_process_group(

@@ -425,6 +441,11 @@ print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total}

# load model.
config_hf = DragonConfig(
+    tie_lm_head=args.tie_lm_head,
+    mlp_type=args.mlp_type,
+    layer_norm_scaling=args.layer_norm_scaling,
+    mamba_d_state=args.mamba_d_state,
+    mamba_headdim=args.mamba_headdim,
    mamba3_rope=args.mamba3_rope,
    mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
    mamba3_is_id_rms=args.mamba3_is_id_rms,

@@ -600,8 +621,22 @@ def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
    else:
        decay_ratio = (num_iterations - it) / warmdown_iters
        return decay_ratio
-
-
+if args.warmdown_type == "linear":
+    sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
+    schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
+elif args.warmdown_type == "cosine":
+    sched = get_wsd_schedule(
+        optimizers[0],
+        num_warmup_steps=args.warmup_iters,
+        num_decay_steps=args.warmdown_iters,
+        num_training_steps=args.total_iterations,
+        min_lr_ratio=0.,
+        warmup_type='linear',
+        decay_type='cosine',
+    )
+    schedulers = [sched]
+else:
+    raise ValueError(f"Unknown warmdown type: {args.warmdown_type}")

# resume if necessary.
start_iter = 0

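
The linear branch keeps wrapping the existing get_lr_wsd multiplier in LambdaLR, while the cosine branch delegates to transformers' get_wsd_schedule. Only the tail of get_lr_wsd is visible in the hunk above, so here is a sketch of the warmup-stable-decay shape it presumably implements; the warmup and plateau parts are an assumption, only the final else matches the context lines:

    def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
        # Multiplier for LambdaLR: the returned value scales the base learning rate.
        if it < warmup_iters:
            return (it + 1) / warmup_iters                        # linear warmup (assumed)
        elif it < num_iterations - warmdown_iters:
            return 1.0                                            # stable plateau (assumed)
        else:
            decay_ratio = (num_iterations - it) / warmdown_iters  # linear warmdown, as in the hunk
            return decay_ratio
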