from transformers import PretrainedConfig
class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.
    Inherits from HF's PretrainedConfig so we can do:
        model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.
    This config includes all fields from the provided config.json.
    """

    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],
        # Key Mamba architecture hyperparameters:
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,
        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,
        # The init_config block nested in JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...
        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=False,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=3.0e-4,
        min_lr=3.0e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,  # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,  # e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=False,
        # Now accept arbitrary additional kwargs, to remain flexible:
        **kwargs
    ):
        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias
        self.torch_dtype = torch_dtype
        self.seed = seed
        # Nested init_args (dt_max, dt_min, etc.).
        # Could store it as a dict, or parse out the fields individually:
        self.init_args = init_args or {}
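        # For illustration only: the nested block is expected to look roughly like
        #   init_args = {"dt_max": ..., "dt_min": ..., "dt_init_floor": ...}
        # (field names follow the comment above; the actual values live in this
        # repo's config.json, not here).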
        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile
        # If you want to store any leftover kwargs:
        self.extra_args = kwargs
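

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): because MiniMambaConfig
    # subclasses PretrainedConfig, it round-trips through the standard
    # config.json helpers. The directory name below is a placeholder, not a
    # path used elsewhere in this repo.
    cfg = MiniMambaConfig(dim=1024, num_layers=54)
    cfg.save_pretrained("mini_mamba_config")  # writes mini_mamba_config/config.json
    reloaded = MiniMambaConfig.from_pretrained("mini_mamba_config")
    assert reloaded.dim == 1024 and reloaded.num_layers == 54

    # Optionally, registering the config lets AutoConfig.from_pretrained(...)
    # resolve model_type="minimamba" (the model class would still need to be
    # registered separately with the matching AutoModel class):
    # from transformers import AutoConfig
    # AutoConfig.register("minimamba", MiniMambaConfig)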