zhiqiang2025 committed on
Commit
a103aea
·
verified ·
1 Parent(s): fdd0cca

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/datas/model/hf_model/MiniMind-449M-final",
3
+ "architectures": [
4
+ "MiniMindForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "attention_layer_norm_with_affine": true,
8
+ "attn_implementation": null,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_minimind.ModelConfig",
11
+ "AutoModelForCausalLM": "model_minimind.MiniMindForCausalLM"
12
+ },
13
+ "aux_loss_alpha": 0.1,
14
+ "block_group_size": 1,
15
+ "block_type": "sequential",
16
+ "bos_token_id": 1,
17
+ "clip_qkv": null,
18
+ "dropout": 0.0,
19
+ "embedding_layer_norm": false,
20
+ "embedding_size": 50304,
21
+ "eos_token_id": 50256,
22
+ "flash_attn": true,
23
+ "hidden_act": "silu",
24
+ "hidden_size": 768,
25
+ "init_device": "cuda:0",
26
+ "intermediate_size": 3072,
27
+ "layer_norm_type": "default",
28
+ "layer_norm_with_affine": true,
29
+ "max_position_embeddings": 32768,
30
+ "max_seq_len": 512,
31
+ "model_type": "minimind",
32
+ "multiple_of": 64,
33
+ "n_routed_experts": 4,
34
+ "n_shared_experts": true,
35
+ "norm_topk_prob": true,
36
+ "num_attention_heads": 8,
37
+ "num_experts_per_tok": 2,
38
+ "num_hidden_layers": 48,
39
+ "num_key_value_heads": 2,
40
+ "pad_token_id": 50256,
41
+ "precision": null,
42
+ "rms_norm_eps": 1e-05,
43
+ "rope": false,
44
+ "rope_full_precision": true,
45
+ "rope_theta": 1000000.0,
46
+ "scoring_func": "softmax",
47
+ "seq_aux": true,
48
+ "torch_dtype": "float32",
49
+ "transformers_version": "4.48.0",
50
+ "use_moe": false,
51
+ "vocab_size": 50257
52
+ }
configuration_minimind.py ADDED
@@ -0,0 +1,1498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import asdict, dataclass, field
5
+ from glob import glob
6
+ from pathlib import Path
7
+ from typing import (
8
+ Any,
9
+ Dict,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ cast,
18
+ )
19
+ from transformers import PretrainedConfig
20
+ import numpy as np
21
+ import torch
22
+ from omegaconf import DictConfig, ListConfig
23
+ from omegaconf import OmegaConf as om
24
+ from omegaconf.errors import OmegaConfBaseException
25
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
26
+ from os import PathLike
27
+ from typing import Union
28
+ from enum import Enum
29
+ # from .aliases import PathOrStr
30
+ from .exceptions import OLMoConfigurationError
31
+ # from .util import StrEnum
32
+
33
+
34
PathOrStr = Union[str, PathLike]


class StrEnum(str, Enum):
    """String-valued enum whose members print as their raw value.

    Equivalent to Python's :class:`enum.StrEnum` (added in 3.11); kept here
    for compatibility with older interpreters.
    """

    def __str__(self) -> str:
        # Members render as their underlying string value.
        return self.value

    def __repr__(self) -> str:
        # Quote the value so the repr reads like a string literal, e.g. 'rms'.
        return "'{}'".format(self)
46
+
47
+
48
+
49
+
50
# Public API of this module.
# BUGFIX: "WandbConfig" was listed twice; the duplicate entry has been removed.
__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "OptimizerType",
    "OptimizerConfig",
    "SchedulerType",
    "SchedulerConfig",
    "DataConfig",
    "InstanceFilterConfig",
    "EvaluatorConfig",
    "TokenizerConfig",
    "TrainConfig",
    "PaddingDirection",
    "TruncationDirection",
    "SpeedMonitorConfig",
    "WandbConfig",
    "CompilerConfig",
    "DDPConfig",
    "DistributedStrategy",
    "DDPGradSyncMode",
    "FSDPPrecision",
    "FSDPWrapStrategy",
    "FSDPConfig",
    "SingleGPUConfig",
    "CheckpointType",
]
81
+
82
# Type variables used by the generic classmethod constructors below:
# C is any BaseConfig subclass; D is an OmegaConf container (dict or list).
C = TypeVar("C", bound="BaseConfig")
D = TypeVar("D", bound="DictConfig|ListConfig")
84
+
85
+
86
class BaseConfig(PretrainedConfig):
    """Base class for every config object in this module.

    Combines Hugging Face's ``PretrainedConfig`` with OmegaConf-based
    helpers for constructing, loading, merging, and saving YAML configs.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every kwarg onto the instance so arbitrary fields coming
        # from AutoConfig / YAML survive as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    # NOTE(review): the OmegaConf path-resolver registration
    # (``_register_resolvers``: path.glob / path.choose / path.last_checkpoint)
    # that used to live here is commented out in the original source; restore
    # it from git history if resolver support is needed again.

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        """
        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
        """
        return config

    @classmethod
    def new(cls: Type[C], **kwargs) -> C:
        """Build a structured config of type ``cls`` from keyword overrides.

        Raises:
            OLMoConfigurationError: if OmegaConf rejects the merge.
        """
        # BUGFIX: the original called ``cls._register_resolvers()``
        # unconditionally, but that classmethod is commented out above, so
        # this always raised AttributeError. Only call it when it exists.
        register = getattr(cls, "_register_resolvers", None)
        if register is not None:
            register()
        conf = om.structured(cls)
        try:
            if kwargs:
                conf = om.merge(conf, kwargs)
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    @classmethod
    def load(
        cls: Type[C],
        path: PathOrStr,
        overrides: Optional[List[str]] = None,
        key: Optional[str] = None,
        validate_paths: bool = True,
    ) -> C:
        """Load a config from a YAML file.

        :param path: the YAML file to read.
        :param overrides: optional dotlist overrides (e.g. ``["a.b=1"]``).
        :param key: optional top-level key to select within the YAML document.
        :param validate_paths: kept for backward compatibility; only used by the
            (currently disabled) path resolvers.
        :raises OLMoConfigurationError: if OmegaConf fails to load or merge.
        """
        # cls._register_resolvers(validate_paths=validate_paths)  # disabled
        schema = om.structured(cls)
        try:
            raw = om.load(str(path))
            if key is not None:
                raw = raw[key]  # type: ignore
            raw = cls.update_legacy_settings(raw)
            conf = om.merge(schema, raw)
            if overrides:
                conf = om.merge(conf, om.from_dotlist(overrides))
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    def save(self, path: PathOrStr) -> None:
        """Save to a YAML file."""
        om.save(config=self, f=str(path))

    def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
        """Return this config as a plain dict, dropping any keys named in ``exclude``.

        NOTE(review): relies on ``dataclasses.asdict``, so this only works on
        dataclass subclasses — confirm for non-dataclass configs.
        """
        out = asdict(self)  # type: ignore
        if exclude is not None:
            for name in exclude:
                out.pop(name, None)
        return out

    def update_with(self, **kwargs):
        """Return a deep copy of this config with ``kwargs`` applied as attributes."""
        result = deepcopy(self)
        for key, value in kwargs.items():
            setattr(result, key, value)
        return result
192
+
193
+
194
class LayerNormType(StrEnum):
    """Which LayerNorm implementation the model blocks should use."""

    default = "default"
    """The default LayerNorm implementation, equivalent to PyTorch's built-in version."""

    low_precision = "low_precision"
    """A low-precision variant of the default LayerNorm."""

    rms = "rms"
    """An RMSNorm implementation; probably the fastest option under ``torch.compile``."""
210
+
211
+
212
class ActivationType(StrEnum):
    """Activation functions available for the MLP layers."""
    gelu = "gelu"
    relu = "relu"
    swiglu = "swiglu"
    silu = 'silu'
217
+
218
+
219
class BlockType(StrEnum):
    """Transformer block implementations."""

    sequential = "sequential"

    llama = "llama"
    """Like ``sequential`` but with slightly different implementations of
    operations such as attention, to imitate Llama's behavior."""
227
+
228
+
229
class InitFnType(StrEnum):
    """Weight-initialization strategies."""

    mitchell = "mitchell"
    """Truncated-normal init (suggested to us by Mitchell Wortsman, UW) whose
    adaptive standard deviation depends on both the weight size and the layer depth."""

    normal = "normal"
    """All weights drawn from one shared normal distribution."""

    kaiming_normal = "kaiming_normal"
    """Kaiming init from a normal distribution. Note: currently won't work with FSDP."""

    fan_in = "fan_in"
    """Fan-in variance scaling: normal with std ``1/sqrt(d_in)``, where ``d_in``
    is the input dimensionality of the kernel."""

    full_megatron = "full_megatron"
    """What metaseq calls "full megatron init"; the init used for Llama 2."""
258
+
259
class ModelConfig(BaseConfig):
    """Configuration for the MiniMind model.

    Holds architecture hyperparameters (sizes, attention layout, norm and RoPE
    settings, optional MoE fields) and is loadable via HF ``AutoConfig``
    through the ``auto_map`` entry in config.json.
    """

    model_type = "minimind"  # important: lets HF AutoConfig identify this model type

    def __init__(
        self,
        bos_token_id: int = 1,
        eos_token_id: int = 50256,
        hidden_size: int = 768,
        num_attention_heads: int = 8,
        num_key_value_heads: int = 2,
        clip_qkv: Optional[float] = None,
        num_hidden_layers: int = 16,
        intermediate_size: int = 3072,
        multiple_of: int = 64,
        hidden_act: str = "silu",
        block_type: str = "sequential",
        block_group_size: int = 1,
        rope: bool = False,
        rope_full_precision: bool = True,
        rope_theta: float = 1e6,
        flash_attn: bool = True,
        attention_dropout: float = 0.1,
        embedding_layer_norm: bool = False,
        layer_norm_type: str = "default",
        layer_norm_with_affine: bool = True,
        rms_norm_eps: float = 1e-5,
        attention_layer_norm_with_affine: bool = True,
        max_position_embeddings: int = 32768,
        max_seq_len: int = 1024,
        vocab_size: int = 50257,
        embedding_size: int = 50304,
        dropout: float = 0.0,
        pad_token_id: int = 50256,
        init_device: Optional[str] = None,
        precision: Optional[str] = None,
        use_moe: bool = False,
        num_experts_per_tok: int = 2,
        n_routed_experts: int = 4,
        n_shared_experts: bool = True,
        scoring_func: str = "softmax",
        aux_loss_alpha: float = 0.1,
        seq_aux: bool = True,
        norm_topk_prob: bool = True,
        **kwargs,  # key: captures any extra fields passed in by AutoConfig / HF
    ):
        """Create a MiniMind config.

        Token IDs (``bos``/``eos``/``pad``) are forwarded to
        ``PretrainedConfig`` so HF generation utilities see them; every other
        argument is stored verbatim as an instance attribute below.
        The ``n_*``/``num_experts*``/``scoring_func``/``aux_loss``/``seq_aux``
        fields are only meaningful when ``use_moe`` is True — the MoE wiring
        lives in model_minimind.py (not visible here).
        """
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs
        )

        # Register all custom fields as instance attributes.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.clip_qkv = clip_qkv
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.multiple_of = multiple_of
        self.hidden_act = hidden_act
        self.block_type = block_type
        self.block_group_size = block_group_size
        self.rope = rope
        self.rope_full_precision = rope_full_precision
        self.rope_theta = rope_theta
        self.flash_attn = flash_attn
        self.attention_dropout = attention_dropout
        self.embedding_layer_norm = embedding_layer_norm
        self.layer_norm_type = layer_norm_type
        self.layer_norm_with_affine = layer_norm_with_affine
        self.rms_norm_eps = rms_norm_eps
        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
        self.max_position_embeddings = max_position_embeddings
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.dropout = dropout
        self.init_device = init_device
        self.precision = precision
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        self.norm_topk_prob = norm_topk_prob
347
+ # @dataclass
348
+ # class ModelConfig(BaseConfig):
349
+ # bos_token_id: int = 1
350
+ # eos_token_id: int = 2
351
+
352
+ # hidden_size: int = 768
353
+ # num_attention_heads: int = 8
354
+ # num_key_value_heads: Optional[int] = 2
355
+ # # n_kv_heads: Optional[int] = None
356
+ # """
357
+ # The number of heads to use for keys and values. Defaults to `n_heads`.
358
+ # Set this to ``None`` or ``n_heads`` for normal multi-head attention.
359
+ # Set this to 1 for multi-query attention.
360
+ # Set it to some in-between value for Llama2-style grouped query attention.
361
+ # """
362
+
363
+ # clip_qkv: Optional[float] = None
364
+ # """
365
+ # Clip QKV to this value when set.
366
+ # """
367
+
368
+ # num_hidden_layers: int = 16
369
+ # """
370
+ # The number of layers/blocks.
371
+ # """
372
+
373
+ # # mlp_ratio: int = 4
374
+ # """
375
+ # The ratio of the inner MLP dimensionality to ``d_model``.
376
+ # This is only used when ``mlp_hidden_size`` is not set.
377
+ # """
378
+
379
+ # #hidden_dim: Optional[int] = None
380
+ # intermediate_size: Optional[int] = 3072
381
+ # """
382
+ # Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
383
+ # """
384
+ # multiple_of: int = 64
385
+ # """
386
+ # refer to MiniMind architecture
387
+ # """
388
+ # hidden_act: ActivationType = ActivationType.silu
389
+ # #activation_type: ActivationType = ActivationType.swiglu
390
+ # """
391
+ # The activation function to use within the MLP layers.
392
+ # """
393
+
394
+ # block_type: BlockType = BlockType.sequential
395
+ # """
396
+ # The transformer block implementation.
397
+ # """
398
+
399
+ # block_group_size: int = 1
400
+ # """
401
+ # The number of blocks to group together into a single parent block.
402
+ # This has no affect on the number of parameters in the model and is only used to wrap groups
403
+ # of blocks together with a single FSDP wrapper during training.
404
+ # """
405
+
406
+ # #alibi: bool = False
407
+ # """
408
+ # If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
409
+ # """
410
+
411
+ # #alibi_bias_max: float = 8.0
412
+ # """
413
+ # Maximum absolute value of ALiBi bias.
414
+ # """
415
+
416
+ # rope: bool = False
417
+ # """
418
+ # Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
419
+ # """
420
+
421
+ # rope_full_precision: bool = True
422
+ # """
423
+ # If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
424
+ # apply RoPE at the precision of the input.
425
+ # """
426
+
427
+ # rope_theta: float = 1e6
428
+ # """
429
+ # The theta setting for RoPE.
430
+ # """
431
+
432
+ # flash_attn: bool = True
433
+ # """
434
+ # If ``True``, use ``FlashAttention``.
435
+ # """
436
+
437
+ # attention_dropout: float = 0.1
438
+ # """
439
+ # The dropout probability within the attention modules.
440
+ # """
441
+
442
+ # #multi_query_attention: Optional[bool] = None
443
+ # """
444
+ # Deprecated. Use n_kv_heads instead.
445
+ # """
446
+
447
+ # #attention_layer_norm: bool = False
448
+ # """
449
+ # Apply layer norm to the keys and queries within the attention mechanism.
450
+ # This can help stabilize training.
451
+ # """
452
+
453
+ # #residual_dropout: float = 0.1
454
+ # """
455
+ # The dropout probability for the MLP and attention output within each block.
456
+ # """
457
+
458
+ # #embedding_dropout: float = 0.1
459
+ # """
460
+ # The dropout probability for embeddings.
461
+ # """
462
+
463
+ # embedding_layer_norm: bool = False
464
+ # """
465
+ # Apply layer norm directly to the embeddings.
466
+ # """
467
+
468
+ # layer_norm_type: LayerNormType = LayerNormType.default
469
+ # """
470
+ # The layernorm implementation to use.
471
+ # """
472
+
473
+ # layer_norm_with_affine: bool = True
474
+ # """
475
+ # Whether to include bias and weight parameters for the layer norms.
476
+ # This only affects layer norms that are immediately followed by a linear layer in the forward pass,
477
+ # so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
478
+ # to ``False``.
479
+ # """
480
+ # rms_norm_eps: float = 1e-05
481
+
482
+ # # norm_eps: float = 1e-05
483
+
484
+ # attention_layer_norm_with_affine: bool = True
485
+ # """
486
+ # Toggle affine transform for the QK norms.
487
+ # """
488
+ # max_position_embeddings: int = 32768
489
+
490
+ # max_seq_len: int = 1024
491
+ # """
492
+ # The maximum input sequence length supported by the model.
493
+ # """
494
+
495
+ # #include_bias: bool = True
496
+ # """
497
+ # Whether or not to include bias parameters in linear layers.
498
+ # In PaLM, they got rid of all bias terms because they found that large
499
+ # models tend to have near 0 bias terms anyway.
500
+ # """
501
+
502
+ # #bias_for_layer_norm: Optional[bool] = None
503
+ # """
504
+ # Whether or not to include bias parameters in layer norm.
505
+ # This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
506
+ # layer norm.
507
+ # When this is None (the default), it inherits the setting from include_bias.
508
+ # """
509
+
510
+ # #scale_logits: bool = False
511
+ # """
512
+ # If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
513
+ # """
514
+
515
+ # vocab_size: int = 50257
516
+ # """
517
+ # Vocabulary size of the model.
518
+ # """
519
+
520
+ # embedding_size: Optional[int] = 50304
521
+ # """
522
+ # The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
523
+ # to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
524
+ # next multiple of 128 that's greater than ``vocab_size`` can improve throughput
525
+ # substantially.
526
+ # """
527
+ # dropout: float = 0.0
528
+
529
+ # #weight_tying: bool = True
530
+ # """
531
+ # Whether to tie output linear weights to the input embedding.
532
+ # """
533
+
534
+ # eos_token_id: int = 50256
535
+ # """
536
+ # The ID of the end-of-sentence special token.
537
+ # """
538
+
539
+ # pad_token_id: int = 50256
540
+ # """
541
+ # The ID of the token to use for padding. Defaults to the ID of the EOS token.
542
+ # """
543
+
544
+ # init_device: Optional[str] = None
545
+ # """
546
+ # The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
547
+ # """
548
+
549
+ # #init_fn: InitFnType = InitFnType.normal
550
+ # """
551
+ # The weight initialization strategy.
552
+ # """
553
+
554
+ # #init_std: float = 0.02
555
+ # """
556
+ # The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
557
+ # as "normal".
558
+ # """
559
+
560
+ # #init_cutoff_factor: Optional[float] = None
561
+ # """
562
+ # A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
563
+ # as "normal". Setting this to None means values are not cutoff.
564
+ # """
565
+
566
+ # precision: Optional[str] = None
567
+ # """
568
+ # Precision used to train/evaluate with. You shouldn't set this directly.
569
+ # See :data:`TrainConfig.precision` instead.
570
+ # """
571
+
572
+ # #scale_emb_init: bool = False
573
+ # """
574
+ # If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization.
575
+ # Currently this is only used with `full_megatron` init when ``emb_init_std`` is unset.
576
+ # """
577
+
578
+ # #emb_init_std: Optional[float] = None
579
+ # """
580
+ # Override the standard deviation to use when initializing the embedding weights.
581
+ # """
582
+
583
+ # #norm_after: bool = False
584
+ # """
585
+ # Apply norm after the attention/feedforward layers rather than before, as introduced in the Swin transformer paper (Liu et al).
586
+ # """
587
+
588
+
589
+ # use_moe: bool = False
590
+ # num_experts_per_tok: int = 2
591
+ # n_routed_experts: int = 4
592
+ # n_shared_experts: bool = True
593
+ # scoring_func: str = "softmax"
594
+
595
+
596
+ # aux_loss_alpha: float = 0.1
597
+ # seq_aux: bool = True
598
+ # norm_topk_prob: bool = True
599
+
600
+ # @property
601
+ # def effective_n_kv_heads(self) -> int:
602
+ # if self.n_kv_heads is None:
603
+ # if self.multi_query_attention is True:
604
+ # return 1
605
+ # else:
606
+ # return self.n_heads
607
+ # else:
608
+ # if self.multi_query_attention is None:
609
+ # return self.n_kv_heads
610
+ # if self.multi_query_attention:
611
+ # n_kv_heads_should_be = 1
612
+ # else:
613
+ # n_kv_heads_should_be = self.n_heads
614
+ # if self.n_kv_heads == n_kv_heads_should_be:
615
+ # return n_kv_heads_should_be
616
+ # else:
617
+ # raise OLMoConfigurationError(
618
+ # "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
619
+ # )
620
+
621
+
622
class OptimizerType(StrEnum):
    """Supported optimizer implementations."""
    lionw = "lionw"
    adamw = "adamw"
625
+
626
+
627
@dataclass
class OptimizerConfig(BaseConfig):
    """Optimizer settings for training."""

    name: OptimizerType = OptimizerType.lionw
    learning_rate: float = 1.0e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    eps: float = 1e-5

    # FIX: the original left bare string "docstrings" dangling under
    # commented-out fields (no-op expressions in the class body); they are
    # converted to real comments here. Disabled legacy fields, for reference:
    #
    # no_decay_norm_and_bias: Optional[bool] = None
    #   Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
    # selective_updates: bool = False
    #   If True, optimizer parameter and state updates are skipped when the
    #   corresponding gradient is 0.
    # decay_norm_and_bias: bool = False
    # decay_embeddings: bool = False
    # metrics_log_interval: Optional[int] = None
    #   The interval with which to collect and log detailed parameter-specific
    #   metrics. Only applies when logging to W&B; defaults to the wandb
    #   ``log_interval`` when unset.
    # record_update_metrics: bool = False
    #   Whether to record detailed metrics about the optimizer's parameter
    #   updates, like the norm and max of the update with AdamW.

    def __post_init__(self):
        # YAML/OmegaConf yields a list; normalize betas to a tuple.
        self.betas = tuple(self.betas)  # type: ignore[assignment]

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        """Migrate legacy schemas: rename ``decoupled_lionw`` and drop ``eps``."""
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
                new_config.name = "lionw"
            if hasattr(new_config, "eps"):
                # Older configs carried an eps key the current schema replaces.
                del new_config.eps

        return new_config
675
+
676
+
677
class SchedulerType(StrEnum):
    """Supported learning-rate schedules."""
    cosine_with_warmup = "cosine_with_warmup"
    cosine_annealing = "cosine_annealing"
    step_law_with_warmup = "step_law_with_warmup"
    linear_with_warmup = "linear_with_warmup"
    inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
    max_scheduler = "max_scheduler"
    constant = "constant"
    cosine_linear_envelope = "cosine_linear_envelope"
    constant_with_warmup = "constant_with_warmup"
687
+
688
+
689
class SchedulerUnits(StrEnum):
    """Units in which scheduler durations (e.g. warmup) are measured."""
    steps = "steps"
    tokens = "tokens"
692
+
693
+
694
@dataclass
class SchedulerConfig(BaseConfig):
    """Learning-rate schedule settings."""

    name: SchedulerType = SchedulerType.cosine_with_warmup
    # units: SchedulerUnits = SchedulerUnits.steps
    t_warmup: Union[int, float] = 2000
    # t_max: Optional[Union[int, float]] = None
    # alpha_f: float = 0.1

    # FIX: the original left dangling string "docstrings" under commented-out
    # fields with mismatched delimiters (a ``"""`` closed inside a ``# """``
    # line); converted to real comments. Disabled legacy fields, for reference:
    #
    # grad_clip_warmup_steps: Optional[Union[int, float]] = None
    #   The warmup period for which the max grad norm (or norm ratio) is set
    #   to its warmup value of ``max_grad_norm * grad_clip_warmup_factor``.
    # grad_clip_warmup_factor: Optional[float] = None
    #   The ratio of the max allowed gradient norm (or norm ratio) for
    #   clipping during the warmup period vs after the warmup period.
    # warmup_min_lr: Optional[float] = None
    #   The starting LR during the warmup period. If not set this defaults to
    #   10% of the target LR.
719
+
720
+
721
class PaddingDirection(StrEnum):
    """Which side of a sequence padding is applied to."""
    right = "right"
    left = "left"
724
+
725
+
726
@dataclass
class InstanceFilterConfig(BaseConfig):
    """Repetition-based instance filtering thresholds.

    NOTE(review): the filter implementation is not in this file; these fields
    presumably bound the repetition period range and maximum repeat count —
    confirm against the data-loading code.
    """
    repetition_max_period: int = 13
    repetition_min_period: int = 1
    repetition_max_count: int = 32
731
+
732
+
733
@dataclass
class DataConfig(BaseConfig):
    """Settings describing a training/eval data source and its loader."""

    data_name: Optional[str] = None
    paths: Optional[str] = None
    memmap_dtype: str = "uint16"
    datasets: Optional[Dict[str, List[str]]] = None
    # Several loader options from the original project are disabled here:
    # label_mask_paths, pad_direction, generate_attention_mask,
    # generate_doc_lengths, drop_last, pin_memory, prefetch_factor,
    # persistent_workers, timeout, seed, instance_filter, custom_dataset.
    num_workers: int = 0

    @property
    def effective_memmap_dtype(self):
        """Resolve ``memmap_dtype`` to the corresponding numpy dtype object.

        Raises ``TypeError`` when the name is not a valid numpy dtype.
        """
        try:
            # The name must exist on the numpy module...
            dtype = getattr(np, self.memmap_dtype)
            # ...and denote a valid numpy dtype.
            np.dtype(dtype)
        except (AttributeError, TypeError) as e:
            raise TypeError(f"Value {self.memmap_dtype} is not a valid numpy type") from e
        return dtype
762
+
763
+
764
@dataclass
class CustomDatasetCollatorConfig(BaseConfig):
    """Field-name mapping used when collating items from a custom dataset."""

    input_id_field: str = "input_ids"  #: The field in the dataset items that contains the input token IDs.
    attention_mask_field: Optional[str] = None  #: The field in the dataset items that contains the attention mask.
    attention_bias_field: Optional[str] = None  #: The field in the dataset items that contains the attention bias.
    label_mask_field: Optional[str] = None  #: The field in the dataset items that contains the label mask.
    index_field: Optional[str] = None  #: The field in the dataset items that contains the index of the item.
    instance_mask_field: Optional[str] = None  #: The field in the dataset items that contains the instance mask.
    doc_lens_field: Optional[str] = None  #: The field in the dataset items that contains the document lengths.
    metadata_field: Optional[str] = None  #: The field in the dataset items that contains the metadata.
774
+
775
+
776
@dataclass
class CustomDatasetConfig(BaseConfig):
    """How to locate and instantiate a user-supplied dataset class/function."""

    name: str  #: The name of the custom dataset class or function that will be used to load the dataset.
    module: Optional[
        str
    ] = None  #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
    args: Optional[Dict[str, Any]] = None  #: The arguments to pass to the custom dataset class or function
    collate_fn: Optional[
        str
    ] = None  #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
    token_field: Optional[str] = None  #: The field in the dataset items that contains the tokenized text.
    collate_config: Optional[CustomDatasetCollatorConfig] = field(
        default_factory=CustomDatasetCollatorConfig
    )  #: The configuration for the collate function to use for the custom dataset.
790
+
791
+
792
class EvaluatorType(StrEnum):
    """Kinds of in-loop evaluators."""
    downstream = "downstream"
    lm = "lm"
795
+
796
+
797
@dataclass
class EvaluatorConfig(BaseConfig):
    """Configuration for a single in-loop evaluation task."""

    label: str  # display label for this evaluator (required, no default)
    type: EvaluatorType = EvaluatorType.lm
    data: DataConfig = field(default_factory=DataConfig)
    device_eval_batch_size: Optional[int] = None  # presumably falls back to a trainer-level batch size when None — confirm
    subset_num_batches: Optional[int] = None  # presumably limits evaluation to this many batches when set — confirm
804
+
805
+
806
class TruncationDirection(StrEnum):
    """Which side of a sequence is truncated when it is too long."""
    right = "right"
    left = "left"
809
+
810
+
811
@dataclass
class TokenizerConfig(BaseConfig):
    """Tokenizer selection."""

    identifier: str = "gpt2"  # tokenizer name/path (presumably a HF identifier — confirm against loader)
    # truncate_direction: TruncationDirection = TruncationDirection.right
815
+
816
+
817
@dataclass
class WandbConfig(BaseConfig):
    """Weights & Biases logging settings."""

    project: Optional[str] = None
    entity: Optional[str] = "ai2-llm"  # NOTE(review): default entity is hard-coded to "ai2-llm"
    group: Optional[str] = None
    name: Optional[str] = None
    tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
    log_artifacts: bool = False
    rank_zero_only: bool = True
    log_interval: int = 1
827
+
828
+
829
@dataclass
class SpeedMonitorConfig(BaseConfig):
    """Throughput monitoring settings."""

    window_size: int = 100  #: Number of recent batches in the moving-average window.
    #: Peak device FLOPs used by the monitor; presumably determined automatically
    #: when None — verify against the speed-monitor implementation.
    gpu_flops_available: Optional[Union[float, int]] = None
833
+
834
+
835
@dataclass
class CompilerConfig(BaseConfig):
    """Settings for compiling the model with ``torch.compile()``."""

    mode: Optional[str] = None
    """
    Compile mode. Currently one of "default", "reduce-overhead" (useful for smaller
    models/batches), or "max-autotune" (fastest for larger models, but takes a long
    time to compile).
    """

    fullgraph: bool = False
    """
    Whether it is OK to break the model into several subgraphs when compiling.
    Note that this is not compatible with FSDP.
    """

    backend: str = "inductor"
    """
    The ``torch.compile`` backend to use.
    """

    dynamic: Optional[bool] = None
    """
    Dynamic-shape tracing (from the torch docs): when True, up-front generate a kernel
    as dynamic as possible to avoid recompiles on size changes (some operations may
    still force specialization; use TORCH_LOGS=dynamic to debug). When False, never
    generate dynamic kernels — always specialize. By default (None), automatically
    detect dynamism and compile a more dynamic kernel upon recompile.
    """
865
+
866
+
867
class DistributedStrategy(StrEnum):
    """How (or whether) training is distributed across ranks."""

    ddp = "ddp"
    """Wrap OLMo in ``torch.nn.parallel.DistributedDataParallel`` to train across ranks."""

    fsdp = "fsdp"
    """Wrap OLMo in ``torch.distributed.fsdp.FullyShardedDataParallel`` to train across ranks."""

    single = "single"
    """Train on a single device (no distribution); for development and debugging."""
882
+
883
+
884
class DDPGradSyncMode(StrEnum):
    """When DDP synchronizes gradients across ranks."""

    batch = "batch"
    """
    Synchronize gradients for each bucket only at the last micro-batch. Slightly
    faster than per-micro-batch syncs but consumes more memory. Only valid when
    ``find_unused_params`` is False.
    """

    micro_batch = "micro_batch"
    """
    Synchronize gradients for each bucket on every micro-batch. Slightly slower than
    syncing only at the last micro-batch, but consumes less memory. Works with either
    setting of ``find_unused_params``, and is specifically recommended when it is
    True, to prevent errors.
    """
899
+
900
+
901
@dataclass
class DDPConfig(BaseConfig):
    """DistributedDataParallel settings."""

    grad_sync_mode: DDPGradSyncMode = DDPGradSyncMode.batch
    """
    Gradient sync mode for DDP.

    Note: when ``find_unused_params`` is set, use ``micro_batch``, as different
    micro-batches might activate different parts of the model (e.g. MoEs).
    """

    find_unused_params: bool = False
    """
    (From the torch documentation.)

    Allows running backward on a subgraph of the model: DDP traverses the autograd
    graph from the model output and marks all unused parameters as ready for
    reduction. The traversal introduces extra overhead, so only set this to True
    when necessary.
    """
920
+
921
+
922
class FSDPWrapStrategy(StrEnum):
    """How modules are grouped into FSDP instances."""

    by_block = "by_block"
    """Wrap each OLMo block with its own FSDP instance."""

    by_block_and_size = "by_block_and_size"
    """Like ``by_block``, but ``wte`` and ``ff_out`` are wrapped separately as well."""

    by_block_group = "by_block_group"
    """
    Wrap each block group together into its own FSDP instance.
    Requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
    """

    by_block_group_and_size = "by_block_group_and_size"
    """Like ``by_block_group``, but ``wte`` and ``ff_out`` are wrapped separately as well."""

    size_based = "size_based"
    """Use PyTorch's default size-based auto wrap policy."""

    one_in_two = "one_in_two"
    one_in_three = "one_in_three"
    one_in_four = "one_in_four"
    one_in_five = "one_in_five"
953
+
954
+
955
class FSDPPrecision(StrEnum):
    """Mixed-precision policies for FSDP."""

    pure = "pure"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``,
    ``reduce_dtype``, and ``buffer_dtype`` all set to the autocast precision data type.
    """

    mixed = "mixed"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``
    and ``buffer_dtype`` set to the autocast precision data type, while
    ``reduce_dtype`` is set to fp32.
    """
967
+
968
+
969
@dataclass
class FSDPConfig(BaseConfig):
    """Fully sharded data parallel settings."""

    use_orig_params: bool = True
    """
    Must be ``True`` if using ``compile`` or if you want to track the parameter norm
    during training.
    """

    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  #: FSDP sharding strategy.

    wrapping_strategy: Optional[FSDPWrapStrategy] = None
    """
    The wrapping strategy to use. With ``None`` (the default) the model is wrapped in
    a single top-level FSDP instance.
    """

    precision: Optional[FSDPPrecision] = FSDPPrecision.pure  #: Mixed-precision policy.

    hybrid_sharding_num_model_replicas: Optional[int] = None
    """
    The number of model instances when using a hybrid sharding strategy. If not
    ``None``, it must divide the total number of nodes. When ``None`` (the default),
    one model instance is used per node (as determined by
    ``get_world_size() // get_local_world_size()``), which matches PyTorch's default
    HSDP behavior.
    """
993
+
994
+
995
@dataclass
class SingleGPUConfig(BaseConfig):
    """Settings for single-device (non-distributed) training."""

    device: str = "auto"
    """
    Device to run single-device training.
    """

    def get_device(self):
        """Resolve ``self.device`` to a concrete ``torch.device``.

        "auto" prefers MPS, then CUDA, then CPU. Raises ``OLMoConfigurationError``
        when "mps" or "cuda" is requested explicitly but unavailable.
        """
        if self.device == "auto":
            # Preference order: MPS (Apple silicon), CUDA, CPU.
            if torch.backends.mps.is_available():
                return torch.device("mps")
            if torch.cuda.is_available():
                return torch.device("cuda")
            return torch.device("cpu")
        if self.device == "mps" and not torch.backends.mps.is_available():
            raise OLMoConfigurationError("MPS not available.")
        if self.device == "cuda" and not torch.cuda.is_available():
            raise OLMoConfigurationError("CUDA not available.")
        return torch.device(self.device)
1016
+
1017
+
1018
class CheckpointType(StrEnum):
    """Kinds of training checkpoints."""

    sharded = "sharded"
    unsharded = "unsharded"
    sharded_ephemeral = "sharded_ephemeral"
1022
+
1023
+
1024
class ShardedCheckpointerType(StrEnum):
    """Backends for saving/loading sharded checkpoints."""

    torch_new = "torch_new"
    torch_legacy = "torch_legacy"
    local = "local"
    olmo_core = "olmo_core"
1029
+
1030
+
1031
class ActivationCheckpointingStrategy(StrEnum):
    """Which transformer layers are activation-checkpointed."""

    whole_layer = "whole_layer"
    """Checkpoint every transformer layer."""

    one_in_two = "one_in_two"
    """Checkpoint one in two transformer layers."""

    one_in_three = "one_in_three"
    """Checkpoint one in three transformer layers."""

    one_in_four = "one_in_four"
    """Checkpoint one in four transformer layers."""

    one_in_eight = "one_in_eight"
    """Checkpoint one in eight transformer layers."""

    two_in_three = "two_in_three"
    """Checkpoint two out of every three transformer layers."""

    three_in_four = "three_in_four"
    """Checkpoint three out of every four transformer layers."""

    fine_grained = "fine_grained"
    """Focus checkpointing where recomputation is cheap and the memory savings are largest."""
1071
+
1072
+
1073
@dataclass
class TrainConfig(BaseConfig):
    """
    OLMo training configuration.

    NOTE(review): this fork disabled many upstream OLMo trainer fields by commenting
    them out (checkpoint retention, eval loop, torch.compile, FSDP, profiling,
    restart/fast-forward, etc.); only the fields below are active. In particular the
    ``fsdp`` field no longer exists — see :attr:`fsdp_precision`.
    """

    run_name: Optional[str] = None
    """The name of the run."""

    epochs: Optional[int] = None
    """Increment this when starting a new epoch."""

    dry_run: bool = False
    """If ``True``, don't actually train."""

    model: ModelConfig = field(default_factory=ModelConfig)
    """OLMo model configuration."""

    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    """Optimizer configuration."""

    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    """Learning rate scheduler configuration."""

    data: DataConfig = field(default_factory=DataConfig)
    """Training data configuration."""

    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
    """Tokenizer configuration."""

    save_folder: str = "./"
    """The directory to save checkpoints to."""

    canceled_check_interval: int = 50
    """How often (in batches) to check if the run has been canceled or reached its time limit."""

    save_interval: Optional[int] = 1000
    """How often (in terms of steps) to save sharded training state checkpoints."""

    load_checkpoint: Optional[str] = None
    """
    The path to a training checkpoint to restore/resume from. If not set, training
    begins from scratch.

    The "path.last_checkpoint" OmegaConf YAML resolver can be used here; it takes a
    local or remote directory and resolves to the latest checkpoint (sharded or
    unsharded) in that directory. For example::

        --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
    """

    batch_size: int = 64
    """
    Training batch size. NOTE(review): presumably per device — the upstream
    global/per-device split fields were disabled; confirm against the trainer.
    """

    grad_clip: float = 1.0
    """Clip gradient norms to this value."""

    precision: Optional[str] = None
    """Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32")."""

    wandb: Optional[WandbConfig] = None
    """Weights & Biases configuration."""

    console_log_interval: int = 1
    """How often to log to the console."""

    ddp: Optional[DDPConfig] = None
    """DDP settings."""

    accumulation_steps: int = 8
    """Number of gradient-accumulation steps."""

    @property
    def autocast_precision(self) -> torch.dtype:
        """Map the ``precision`` string to a ``torch.dtype``.

        Raises:
            ValueError: for any value other than "amp_bf16", "amp_fp16", or "fp32"
                (including the default ``None``).
        """
        if self.precision == "amp_bf16":
            return torch.bfloat16
        if self.precision == "amp_fp16":
            return torch.float16
        if self.precision == "fp32":
            return torch.float32
        raise ValueError(f"Unexpected precision type '{self.precision}'")

    @property
    def fsdp_precision(self) -> Optional[MixedPrecision]:
        """Build the FSDP ``MixedPrecision`` policy from the (optional) ``fsdp`` config.

        Bug fix: the ``fsdp`` field was removed (commented out) from this dataclass, so
        the original ``self.fsdp`` attribute access raised ``AttributeError``. Use
        ``getattr`` so the documented ``ValueError`` is raised instead when no FSDP
        config is present.

        Raises:
            ValueError: when no ``fsdp`` config is set on this instance.
            NotImplementedError: for an unrecognized :class:`FSDPPrecision` value.
        """
        fsdp = getattr(self, "fsdp", None)
        if fsdp is None:
            raise ValueError("self.fsdp is None!")
        if fsdp.precision is None:
            return None
        if fsdp.precision == FSDPPrecision.pure:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=self.autocast_precision,
                buffer_dtype=self.autocast_precision,
            )
        if fsdp.precision == FSDPPrecision.mixed:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=torch.float32,
                buffer_dtype=self.autocast_precision,
            )
        raise NotImplementedError(f"{fsdp.precision}")

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        """Normalize settings from older config formats; returns an updated copy."""
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "activation_checkpointing"):
                # Older configs used a boolean here; map it onto the enum / None.
                if new_config.activation_checkpointing is False:
                    new_config.activation_checkpointing = None
                if new_config.activation_checkpointing is True:
                    new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer

            if hasattr(new_config, "optimizer"):
                new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

        return new_config
exceptions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
__all__ = [
    "OLMoError",
    "OLMoConfigurationError",
    "OLMoCliError",
    "OLMoEnvironmentError",
    "OLMoNetworkError",
    "OLMoCheckpointError",
    "OLMoThreadError",  # bug fix: was defined below but missing from __all__
]


class OLMoError(Exception):
    """
    Base class for all custom OLMo exceptions.
    """


class OLMoConfigurationError(OLMoError):
    """
    An error with a configuration file.
    """


class OLMoCliError(OLMoError):
    """
    An error from incorrect CLI usage.
    """


class OLMoEnvironmentError(OLMoError):
    """
    An error from incorrect environment variables.
    """


class OLMoNetworkError(OLMoError):
    """
    An error with a network request.
    """


class OLMoCheckpointError(OLMoError):
    """
    An error occurred reading or writing from a checkpoint.
    """


class OLMoThreadError(Exception):
    """
    Raised when a thread fails.

    NOTE(review): intentionally kept as a direct ``Exception`` subclass (not
    ``OLMoError``), so handlers catching ``OLMoError`` will not catch it — confirm
    this matches callers before changing the base class.
    """
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 50256,
5
+ "pad_token_id": 50256,
6
+ "transformers_version": "4.48.0"
7
+ }
model_minimind.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
2
+ # MiniMind Config
3
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
4
+
5
+ from transformers import PretrainedConfig
6
+ from .configuration_minimind import ModelConfig
7
+
8
class MiniMindConfig(PretrainedConfig):
    """Hugging Face configuration for the MiniMind causal LM (dense or MoE)."""

    model_type = "minimind"

    def __init__(
            self,
            dropout: float = 0.0,
            bos_token_id: int = 1,
            eos_token_id: int = 2,
            hidden_act: str = 'silu',
            hidden_size: int = 512,
            intermediate_size: int = None,
            max_position_embeddings: int = 32768,
            num_attention_heads: int = 8,
            num_hidden_layers: int = 8,
            num_key_value_heads: int = 2,
            # 50257 + 1 special token: pad token
            vocab_size: int = 50257,
            rms_norm_eps: float = 1e-05,
            rope_theta: int = 1000000.0,
            flash_attn: bool = True,
            # ------------------------------------------------------------------
            # MoE-specific settings below; ignored when use_moe is False.
            # ------------------------------------------------------------------
            use_moe: bool = False,
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: int = 1,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.1,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        # Core transformer hyperparameters.
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.flash_attn = flash_attn
        # MoE-specific settings; ignored when use_moe is False.
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # number of shared (always-on) experts
        self.scoring_func = scoring_func  # router scoring function, default 'softmax'
        self.aux_loss_alpha = aux_loss_alpha  # weight of the auxiliary (load-balancing) loss
        self.seq_aux = seq_aux  # compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob  # normalize the top-k routing probabilities
69
+
70
+
71
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
72
+ # MiniMind Model
73
+ # 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
74
+
75
+ import math
76
+ import torch
77
+ from torch import nn
78
+ from transformers.activations import ACT2FN
79
+ from typing import Optional, Tuple, List, Union
80
+ import torch.nn.functional as F
81
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
82
+ from transformers.modeling_outputs import CausalLMOutputWithPast
83
+
84
+
85
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization: scales by 1/RMS(x) over the last dim,
    with a learned per-dimension gain and no bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # added inside the rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS computed over the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Statistics are computed in fp32, then cast back to the input dtype.
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
96
+
97
+
98
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute RoPE cos/sin tables, each of shape (end, dim).

    The first dim//2 columns are duplicated into the second half, matching the
    rotate_half convention used by apply_rotary_pos_emb.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    cos_table = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin_table = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return cos_table, sin_table
105
+
106
+
107
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query/key tensors by the given RoPE cos/sin tables.

    ``position_ids`` is accepted for signature compatibility but unused;
    ``unsqueeze_dim`` is the axis on which the tables are broadcast.
    """

    def rotate_half(x):
        # Swap the two halves of the last dim, negating the (new) first half.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = (q * cos_b) + (rotate_half(q) * sin_b)
    rotated_k = (k * cos_b) + (rotate_half(k) * sin_b)
    return rotated_q, rotated_k
114
+
115
+
116
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)

    Duplicates each KV head n_rep times so grouped-query attention can match the
    key/value heads against the (more numerous) query heads.
    x has shape (batch, seq_len, num_kv_heads, head_dim).
    """
    batch, seq_len, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x[:, :, :, None, :].expand(batch, seq_len, num_kv_heads, n_rep, head_dim)
    return expanded.reshape(batch, seq_len, num_kv_heads * n_rep, head_dim)
126
+
127
+
128
class Attention(nn.Module):
    """Multi-head attention with grouped KV heads (GQA), RoPE, and an optional KV cache.

    Uses PyTorch's scaled_dot_product_attention when available and enabled
    (``args.flash_attn``), with a manual masked-softmax fallback otherwise.
    """

    def __init__(self, args: ModelConfig):
        super().__init__()
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # query heads per KV head
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use F.scaled_dot_product_attention when available (PyTorch >= 2.0) and enabled.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # (cos, sin) RoPE tables
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                attention_mask: Optional[torch.Tensor] = None):
        """Causal attention over ``x`` (plus any cached keys/values).

        Args:
            x: input of shape (batch, seq_len, hidden_size).
            position_embeddings: (cos, sin) RoPE tables; the first seq_len rows are used.
            past_key_value: cached (k, v), each (batch, cache_len, kv_heads, head_dim).
            use_cache: when True, also return the updated (k, v) cache.
            attention_mask: padding mask over key positions; 1 = attend, 0 = masked —
                assumed shape (batch, key_len), TODO confirm against the caller.

        Returns:
            Tuple of (output, past_kv): output is (batch, seq_len, hidden_size);
            past_kv is the updated cache or None.

        NOTE(review): ``cos[:seq_len]`` rotates new tokens starting at position 0, and
        the fallback causal mask is (seq_len, seq_len); both are only consistent with
        a non-empty cache when decoding one token at a time — verify with the caller.
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])

        # KV cache: prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            if attention_mask is not None:
                # Bug fix: SDPA rejects an explicit `attn_mask` combined with
                # `is_causal=True` (the original call raised a RuntimeError). Merge the
                # padding mask with an explicit causal mask and disable `is_causal`.
                key_len = xk.size(-2)
                padding = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
                causal = torch.ones(seq_len, key_len, dtype=torch.bool, device=x.device).tril(diagonal=key_len - seq_len)
                output = F.scaled_dot_product_attention(
                    xq, xk, xv, attn_mask=padding & causal, dropout_p=dropout_p, is_causal=False
                )
            else:
                output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=dropout_p, is_causal=True)
        else:
            # Manual path: scaled scores + additive causal mask, softmax in fp32.
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores = scores + torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
                diagonal=1
            ).unsqueeze(0).unsqueeze(0)  # scores+mask

            if attention_mask is not None:
                # Additive padding mask: 0 where attended, -1e9 where masked.
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                scores = scores + extended_attention_mask

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
201
+
202
+
203
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: down(act(gate(x)) * up(x)), then dropout."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.intermediate_size is None:
            # Default width: 8/3 * hidden, rounded up to the next multiple of 64.
            raw = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((raw + 64 - 1) // 64)
        hidden, inter = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gated activation: element-wise product of activated gate and up projection.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
217
+
218
+
219
class MoEGate(nn.Module):
    """Top-k softmax router for MoE layers.

    Scores every token against each routed expert, selects the top-k experts per
    token, and (during training) computes an auxiliary load-balancing loss.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha      # weight of the balancing loss
        self.seq_aux = config.seq_aux           # per-sequence vs. global balancing

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Router weight matrix: (num_experts, hidden_size).
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Returns (topk_idx, topk_weight, aux_loss) for a (bsz, seq, hidden) input."""
        bsz, seq_len, dim = hidden_states.shape
        flat = hidden_states.view(-1, dim)
        logits = F.linear(flat, self.weight, None)
        if self.scoring_func != 'softmax':
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Renormalize the selected weights so they sum to one per token.
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        aux_loss = 0
        if self.training and self.alpha > 0.0:
            idx_per_seq = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # Sequence-level balance: count expert hits per sequence, scale so a
                # perfectly uniform assignment yields 1, then weight by mean scores.
                seq_scores = scores.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=flat.device)
                ce.scatter_add_(1, idx_per_seq,
                                torch.ones(bsz, seq_len * self.top_k, device=flat.device)).div_(
                    seq_len * self.top_k / self.n_routed_experts)
                aux_loss = (ce * seq_scores.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Global balance: fraction of tokens per expert (fi) times mean
                # routing probability per expert (Pi), as in switch-transformer.
                mask_ce = F.one_hot(idx_per_seq.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        return topk_idx, topk_weight, aux_loss
274
+
275
+
276
class MOEFeedForward(nn.Module):
    """Mixture-of-experts FFN: top-k routed experts (via MoEGate) plus optional
    shared experts that process every token unconditionally.

    After forward(), the router's auxiliary loss is stored on ``self.aux_loss``
    (read by MiniMindModel.forward).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            # Shared experts run on all tokens, independent of routing.
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Route each token to its top-k experts.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Training path: duplicate each token once per selected expert so every
            # expert can run densely over its assigned rows.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # BUGFIX: allocate in the input's dtype instead of hard-coding float16 —
            # the old cast silently lost precision (and mixed dtypes) when the
            # model runs in float32/bfloat16.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                mask = flat_topk_idx == i
                y[mask] = expert(x[mask]).to(y.dtype)
            # Weighted sum over each token's k expert outputs.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference-time dispatch: sort token slots by expert, run each expert once
        over its contiguous group, and scatter the weighted outputs back.

        Example: with tokens_per_expert = [6, 15, 20, 26] (cumulative counts for 4
        experts) and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, ...], the first 6
        entries of token_idxs are the token rows handled by expert 0, the next 9 by
        expert 1, and so on. A token may appear under several experts (top-k > 1).
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        # Cumulative token counts per expert: slice boundaries within `idxs`.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Each token occupies num_experts_per_tok consecutive slots; map back to rows.
        token_idxs = idxs // self.config.num_experts_per_tok
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue  # expert i received no tokens
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # Accumulate (not assign): a token may be routed to several experts.
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
336
+
337
+
338
class MiniMindBlock(nn.Module):
    """One decoder layer: pre-norm self-attention and a pre-norm MLP (or MoE),
    each wrapped in a residual connection."""

    def __init__(self, layer_id: int, config: ModelConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = Attention(config)

        self.layer_id = layer_id
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        # Attention sub-layer (pre-norm + residual).
        attn_out, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-layer (pre-norm + residual).
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, present_key_value
360
+
361
+
362
class MiniMindModel(nn.Module):
    """Decoder-only transformer backbone: token embeddings, a stack of
    MiniMindBlock layers, and a final RMSNorm. Returns hidden states, per-layer
    KV caches, and the summed MoE auxiliary loss."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList(MiniMindBlock(idx, config) for idx in range(self.num_hidden_layers))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Precompute RoPE cos/sin tables for the full context window; forward()
        # slices out the window for the current positions.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            dim=config.hidden_size // config.num_attention_heads,
            end=config.max_position_embeddings, theta=config.rope_theta)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **kwargs):
        _, seq_length = input_ids.shape
        past_key_values = past_key_values or [None] * len(self.layers)
        # Number of tokens already cached determines the RoPE offset.
        start_pos = 0 if past_key_values[0] is None else past_key_values[0][0].shape[1]

        hidden_states = self.dropout(self.embed_tokens(input_ids))

        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        for layer, past in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)

        # Sum the routers' auxiliary losses across MoE layers (0 when dense).
        aux_loss = sum(
            layer.mlp.aux_loss
            for layer in self.layers
            if isinstance(layer.mlp, MOEFeedForward)
        )

        return hidden_states, presents, aux_loss
414
+
415
+
416
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM head over MiniMindModel with tied input/output embeddings."""

    config_class = ModelConfig

    def __init__(self, config: ModelConfig = None):
        self.config = config or ModelConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: the embedding matrix doubles as the output projection.
        self.model.embed_tokens.weight = self.lm_head.weight
        # Kept only for backward compatibility with code that reads `model.OUT`;
        # forward() rebinds it to a fresh output object on every call.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Run the backbone and project to vocabulary logits.

        Args:
            logits_to_keep: int N keeps logits for the last N positions
                (0 keeps all, since slice(-0, None) == slice(0, None));
                a tensor is used directly as an index into the sequence axis.

        Returns:
            CausalLMOutputWithPast carrying logits, past_key_values, plus the
            extra keys 'last_hidden_state' and 'aux_loss'.
        """
        h, past_kvs, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(h[:, slice_indices, :])
        # BUGFIX: build a fresh output object per call. Mutating the single
        # instance created in __init__ retained tensors across steps (blocking
        # GC) and produced aliased results if the module was called reentrantly
        # or shared between callers.
        out = CausalLMOutputWithPast()
        out.__setitem__('last_hidden_state', h)
        out.__setitem__('logits', logits)
        out.__setitem__('aux_loss', aux_loss)
        out.__setitem__('past_key_values', past_kvs)
        self.OUT = out
        return out
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20f4bf36fbb95ac7294fc81bba1dbddc1960eae1dbd9c35523f3ca4bac1d9f43
3
+ size 1796898322
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|im_start|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "3": {
31
+ "content": "<|pad|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ }
38
+ },
39
+ "additional_special_tokens": [],
40
+ "bos_token": "<|im_start|>",
41
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
42
+ "clean_up_tokenization_spaces": false,
43
+ "eos_token": "<|im_end|>",
44
+ "extra_special_tokens": {},
45
+ "legacy": true,
46
+ "model_max_length": 100000,
47
+ "pad_token": "<|pad|>",
48
+ "sp_model_kwargs": {},
49
+ "spaces_between_special_tokens": false,
50
+ "tokenizer_class": "PreTrainedTokenizerFast",
51
+ "unk_token": null
52
+ }