File size: 4,997 Bytes
be761d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from transformers import PretrainedConfig

class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.

    Inherits from HF's PretrainedConfig so we can do:
      model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.

    This config includes all fields from the provided config.json:
    core Mamba architecture hyperparameters, torch/training settings,
    the nested ``init_args`` block, and FSDP/DDP distributed-training
    options. Unknown keyword arguments are forwarded to
    ``PretrainedConfig`` and also kept in ``self.extra_args``.
    """
    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=None,  # None sentinel: avoids a shared mutable list default

        # Key Mamba architecture hyperparameters:
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,

        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,

        # The init_config block nested in JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...

        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=False,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=3.0e-4,
        min_lr=3.0e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,   # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,# e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=False,

        # Now accept arbitrary additional kwargs, to remain flexible:
        **kwargs
    ):
        # Resolve the sentinel here so every instance gets its own list;
        # a literal list default would be shared (and mutable) across calls.
        if architectures is None:
            architectures = ["MiniMamba"]

        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )

        # --- Core Mamba architecture hyperparameters ---
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias

        # --- Torch / training ---
        self.torch_dtype = torch_dtype
        self.seed = seed

        # Nested init_args (dt_max, dt_min, etc.).
        # Stored as a dict; normalized so that a missing value becomes {}.
        self.init_args = init_args or {}

        # --- Additional Mamba / training fields ---
        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation

        # --- Distributed training (FSDP/DDP) options ---
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        # Keep leftover kwargs around too (in addition to whatever the HF
        # parent did with them), so nothing passed in is silently lost.
        self.extra_args = kwargs