zaydzuhri commited on
Commit
d9de648
·
verified ·
1 Parent(s): 6a052b3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/layers/__init__.py +44 -0
  2. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  3. fla/models/abc/configuration_abc.py +91 -0
  4. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  5. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc +0 -0
  6. fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
  7. fla/models/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  8. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  9. fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc +0 -0
  10. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc +0 -0
  11. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  12. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  13. fla/models/linear_attn/__init__.py +12 -0
  14. fla/models/mamba/__pycache__/__init__.cpython-312.pyc +0 -0
  15. fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc +0 -0
  16. fla/models/mamba/modeling_mamba.py +843 -0
  17. fla/models/nsa/__init__.py +15 -0
  18. fla/models/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  19. fla/models/retnet/__pycache__/__init__.cpython-312.pyc +0 -0
  20. fla/models/rwkv6/__init__.py +13 -0
  21. fla/models/rwkv6/configuration_rwkv6.py +82 -0
  22. fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc +0 -0
  23. fla/models/rwkv7/modeling_rwkv7.py +505 -0
  24. fla/models/samba/__init__.py +13 -0
  25. fla/models/samba/configuration_samba.py +92 -0
  26. fla/models/samba/modeling_samba.py +413 -0
  27. fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  29. fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  30. fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  31. fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc +0 -0
  32. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  33. flame/models/__init__.py +0 -0
  34. flame/models/parallelize_fla.py +550 -0
  35. flame/tools/__init__.py +0 -0
  36. flame/tools/utils.py +41 -0
  37. flame/utils/checkpoint.py +50 -0
  38. flame/utils/convert_hf_to_dcp.py +34 -0
  39. flame/utils/hf_utils.py +77 -0
  40. torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
  41. torchtitan/components/__pycache__/ft.cpython-312.pyc +0 -0
  42. torchtitan/components/__pycache__/loss.cpython-312.pyc +0 -0
  43. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  44. torchtitan/components/__pycache__/tokenizer.cpython-312.pyc +0 -0
  45. torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc +0 -0
  46. torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc +0 -0
  47. torchtitan/datasets/tokenizer/tiktoken.py +190 -0
  48. torchtitan/distributed/__pycache__/__init__.cpython-312.pyc +0 -0
  49. torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc +0 -0
  50. torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc +0 -0
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""Public entry point of ``fla.layers``: re-exports one token-mixing layer class per submodule."""

from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
from .bitattn import BitAttention
from .delta_net import DeltaNet
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .lightnet import LightNetAttention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention

# NOTE: keep this list in sync with the imports above (19 entries each).
__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'ForgettingAttention',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
]
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.6 kB). View file
 
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class ABCConfig(PretrainedConfig):
    """Configuration class for ABC models.

    Stores the hyper-parameters consumed by the ABC model implementation.
    Any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = 'abc'
    # Cache entries are runtime state, not part of the serialized config.
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        # NOTE(review): `exapnd_k` / `exapnd_v` are misspelled in the public
        # signature. They are kept as-is for backward compatibility with
        # keyword callers, but are stored under the correctly spelled
        # attributes `expand_k` / `expand_v` below.
        exapnd_k: float = 0.5,
        exapnd_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_rope: bool = True,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # Correctly spelled attribute names (see NOTE on the parameters above).
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_rope = use_rope
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Optional hybrid-attention spec: a dict that must name the layer
        # indices and head count; remaining keys are filled with defaults
        # (mutates the caller-supplied dict in place).
        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (697 Bytes). View file
 
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (3.37 kB). View file
 
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (653 Bytes). View file
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB). View file
 
fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
fla/models/linear_attn/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel

# Register the LinearAttention classes with transformers' Auto* factories so
# that `AutoModel.from_pretrained(...)` et al. resolve the 'linear_attn'
# model type to these implementations (import-time side effect).
AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)

__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
fla/models/mamba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (713 Bytes). View file
 
fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc ADDED
Binary file (41.5 kB). View file
 
fla/models/mamba/modeling_mamba.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MAMBA model."""
16
+
17
+ import math
18
+ import warnings
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.configuration_utils import PretrainedConfig
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.utils import ModelOutput, logging
30
+ from transformers.utils.deprecation import deprecate_kwarg
31
+
32
+ from fla.models.mamba.configuration_mamba import MambaConfig
33
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter('ignore')
40
+ try:
41
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
42
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
43
+ except ImportError:
44
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
45
+
46
+ try:
47
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
48
+ except ImportError:
49
+ causal_conv1d_update, causal_conv1d_fn = None, None
50
+ is_fast_path_available = all((
51
+ selective_state_update,
52
+ selective_scan_fn,
53
+ causal_conv1d_fn,
54
+ causal_conv1d_update,
55
+ mamba_inner_fn
56
+ ))
57
+
58
+
59
class MambaCache:
    """
    Cache for mamba model which does not have attention mechanism and key value states.

    Arguments:
        config (`PretrainedConfig):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initializing the cache.
        intermediate_size: (`int`):
            Model's intermediate_size taken from config.
        ssm_state_size: (`int`):
            Model's state_size taken from config.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config
        conv_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states

    Example:

        ```python
        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values
        MambaCache()
        ```
    """

    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        dtype: torch.dtype = torch.float16,
        device: Optional[Union[torch.device, str]] = None,
        max_batch_size: Optional[int] = None,
    ):
        if max_batch_size is not None:
            logger.warning_once(
                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.46. Use the more precisely named 'batch_size' argument instead."
            )
        self.dtype = dtype
        # `max_batch_size` is the deprecated alias; `batch_size` wins when both are given.
        self.batch_size = batch_size or max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        # One state tensor per layer, indexed by layer_idx along dim 0.
        self.conv_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.batch_size,
            self.intermediate_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.batch_size,
            self.intermediate_size,
            self.ssm_state_size,
            device=device,
            dtype=dtype,
        )

        # Pin the storage addresses so torch.compile / CUDA graphs can treat
        # the cache buffers as static across calls.
        torch._dynamo.mark_static_address(self.conv_states)
        torch._dynamo.mark_static_address(self.ssm_states)

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        """Shift the rolling convolution window left and write the newest column(s) at `cache_position`."""
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        # Write back in place (zero_ then +=) rather than rebinding, so the
        # cached tensor keeps the static address marked in __init__.
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        """Overwrite the recurrent SSM state for the given layer and return it."""
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
        return self.ssm_states[layer_idx]

    def reset(self):
        """Zero all cached states (e.g. to reuse the cache for a new sequence)."""
        self.conv_states.zero_()
        self.ssm_states.zero_()
165
+
166
+
167
class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: MambaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        # Depthwise (groups == channels) conv over the sequence dimension;
        # left padding of kernel-1 keeps it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependant
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        # A is stored in log space; the forward passes use -exp(A_log) so the
        # continuous-time A stays negative (stable decay).
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because on of "
                "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. Falling back to the naive implementation. "
                "To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """Fused-kernel forward path (mamba_ssm + causal_conv1d CUDA ops)."""
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            # Fully fused train-time kernel: conv + selective scan + out proj in one call.
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )

        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
                # Zero out padded positions before the convolution.
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_position[0] > 0:
                # Decoding step: update the rolling conv state with a single token.
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                # Prefill: seed the conv cache with the (left-padded) prompt, then run the full conv.
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )
            # Bias is applied inside the kernels via `time_proj_bias`, so only the weight is used here.
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_position[0] > 0:
                # Single-token decode: update the cached SSM state in place.
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                # Full-sequence scan; also returns the final state for the cache.
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(self.layer_idx, ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    def slow_forward(
        self,
        input_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ):
        """Pure-PyTorch reference path (used when the CUDA kernels are unavailable)."""
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        # [batch, 2 * intermediate_size, seq_len]
        projected_states = self.in_proj(input_states).transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            ssm_state = ssm_state.to(hidden_states.device)
            # use `cache_position.shape[0]` to check whether we are in prefill
            # stage, it's equivalent to check `cache_position[0] == 0`, which
            # breaks dynamo fullgraph constraints
            if cache_position.shape[0] == self.conv_kernel_size:
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )

                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
                # [batch, intermediate_size, seq_len]
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
            else:
                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
                # Manual depthwise conv at the current position: dot the cached
                # window with the kernel per channel.
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                # [batch, intermediate_size, 1] : decoding
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            # [batch, intermediate_size, seq_len]
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )
        # [batch, seq_len, intermediate_size]
        discrete_time_step = self.dt_proj(time_step)
        # [batch, intermediate_size, seq_len]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        # [intermediate_size, ssm_state_size]
        A = -torch.exp(self.A_log.float())
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        # Sequential scan over time (the slow part this path is named for).
        scan_outputs = []
        for i in range(seq_len):
            # [batch, intermediade_size, ssm_state]
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
            # [batch, intermediade_size, 1]
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
            scan_outputs.append(scan_output[:, :, 0])
        # [batch, seq_len, intermediade_size]
        scan_output = torch.stack(scan_outputs, dim=-1)
        # Skip connection through D, then apply the SiLU-style gate.
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        # [batch, seq_len, hidden_size]
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # Use the fused CUDA kernels only when all of them imported and the
        # weights actually live on a CUDA device.
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
427
+
428
+
429
class MambaBlock(nn.Module):
    """A single Mamba decoder layer: pre-norm, mixer, then residual add.

    When ``config.residual_in_fp32`` is set, the residual stream is carried in
    float32 and cast back to the parameter dtype on the way out.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # Keep the skip connection; optionally promote it to fp32 for a more
        # accurate residual accumulation.
        shortcut = hidden_states
        if self.residual_in_fp32:
            shortcut = shortcut.to(torch.float32)

        mixed = self.mixer(
            self.norm(hidden_states),
            cache_params=cache_params,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        out = shortcut + mixed

        if self.residual_in_fp32:
            # Cast back to the parameter dtype before the next layer.
            out = out.to(dtype=self.norm.weight.dtype)
        return out
457
+
458
+
459
class MambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["MambaBlock", "MambaMixer"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # Biases flagged `_no_reinit` (e.g. dt_proj.bias below) keep
                # their special initialization.
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, MambaMixer):
            # S4D-style parameters are excluded from weight decay.
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max],
            # floored at time_step_floor.
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device))
            module.dt_proj.bias._no_reinit = True
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)
519
+
520
+
521
@dataclass
class MambaOutput(ModelOutput):
    """
    Class for the MAMBA model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    # None when `use_cache=False`; otherwise the updated MambaCache after this step.
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
545
+
546
+
547
@dataclass
class MambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    # May be None during training when the fused linear+cross-entropy path is used
    # (logits are never materialized in that case).
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
574
+
575
+
576
class MambaModel(MambaPreTrainedModel):
    """Bare MAMBA backbone: token embeddings, a stack of `MambaBlock`s, and a final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        """Rename legacy checkpoint key prefix "embedding." to "embeddings." on load."""
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                # Only one embedding key is expected; breaking also avoids mutating
                # the dict while continuing iteration.
                break

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MambaOutput]:
        """Run the backbone.

        Exactly one of `input_ids` / `inputs_embeds` must be provided. When
        `use_cache=True` and no `cache_params` is given, a fresh `MambaCache` is
        created (prefill); when `cache_params` is given, `cache_position` is
        mandatory. Returns a `MambaOutput` unless `return_dict=False`.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching is disabled by default while training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # Gradient checkpointing re-runs forward passes, which is incompatible
        # with an incrementally updated cache.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache:
            if cache_params is None:
                # Prefill stage: build a fresh cache and a full-width cache_position
                # covering the convolution window.
                cache_params = MambaCache(
                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
                cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
            elif cache_position is None:
                # cases when we do manual forward instead of using `model.generate` which will initiate
                # `cache_position` and makes sure it is not None, throw error here instead of doing some
                # hack to conjecture the current cache position
                raise ValueError(
                    "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
                    "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
                    "be initialized for you automatically"
                )
        else:
            cache_params = None

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return MambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
678
+
679
+
680
class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
    """MAMBA backbone with a tied language-modeling head for causal LM / generation."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = MambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily selected in `forward` when None (fused vs. plain cross-entropy).
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        num_new_tokens: int = 1,
        **kwargs
    ) -> Dict[str, Any]:
        """Carry the Mamba cache, cache position and attention mask over to the next decode step."""
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        if (
            model_kwargs.get("use_cache", True)
            and "cache_position" in model_kwargs
            and model_kwargs["cache_position"] is not None
        ):
            # Advance the (single-element) cache position by the number of new tokens.
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

        if "attention_mask" in model_kwargs:
            # Extend the mask with ones for the newly generated token.
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        return model_kwargs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        """Build the model input dict for one generation step (prefill or decode)."""
        if use_cache:
            # `cache_position` should have been initialized in `generate`
            if cache_position is None:
                raise ValueError(
                    "`cache_position` should not be None as it should have been initialized in "
                    "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                    "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
                )
            if cache_position[0] > 0:
                # Decode step: feed only the last token; the mask is no longer needed.
                input_ids = input_ids[:, -1].unsqueeze(-1)

                if attention_mask is not None:
                    attention_mask = None

            else:
                # we initialize the `cache_position` to full size of `conv_states` at prefill stage
                # considering padding will be applied when input length is shorter, and truncation
                # will be applied when it is longer, so it will be equivalent to always have it match
                # the length of `cache_params.conv_states`, which is `config.conv_kernel`
                cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        # NOTE(review): this conditional assignment is redundant — the `update`
        # below always sets 'logits_to_keep' (possibly to None), overwriting it.
        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'cache_params': cache_params,
            'use_cache': use_cache,
            'cache_position': cache_position,
            'attention_mask': attention_mask,
            'logits_to_keep': logits_to_keep,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        mamba_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = mamba_outputs[0]
        # During training with fused CE, logits are never materialized: the loss is
        # computed directly from hidden states and the lm_head weights.
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # Shift labels left by one and pad the tail with ignore_index so
            # position t predicts token t+1.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + mamba_outputs[1:]
            return (loss,) + output if loss is not None else output

        return MambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=mamba_outputs.cache_params,
            hidden_states=mamba_outputs.hidden_states,
        )
fla/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.nsa.configuration_nsa import NSAConfig
from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel

# Make the NSA architecture discoverable through the transformers Auto* factories.
AutoConfig.register(NSAConfig.model_type, NSAConfig)
for auto_cls, model_cls in ((AutoModel, NSAModel), (AutoModelForCausalLM, NSAForCausalLM)):
    auto_cls.register(NSAConfig, model_cls)


__all__ = [
    'NSAConfig', 'NSAModel', 'NSAForCausalLM',
]
fla/models/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (653 Bytes). View file
 
fla/models/retnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (678 Bytes). View file
 
fla/models/rwkv6/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model

# Register RWKV6 with the transformers Auto* factories; the trailing True
# (exist_ok) tolerates re-registration on repeated imports.
AutoConfig.register(RWKV6Config.model_type, RWKV6Config, True)
for auto_cls, model_cls in ((AutoModel, RWKV6Model), (AutoModelForCausalLM, RWKV6ForCausalLM)):
    auto_cls.register(RWKV6Config, model_cls, True)


__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model']
fla/models/rwkv6/configuration_rwkv6.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class RWKV6Config(PretrainedConfig):
    """Configuration for RWKV-6 models, optionally hybridized with softmax-attention layers via `attn`."""

    model_type = 'rwkv6'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        # Expansion ratios for key/value projections relative to hidden_size.
        expand_k: float = 0.5,
        expand_v: float = 1,
        hidden_ratio: Optional[float] = 3.5,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        hidden_act: str = "sqrelu",
        max_position_embeddings: int = 2048,
        norm_first: bool = True,
        norm_bias: bool = True,
        norm_eps: float = 1e-5,
        # Optional hybrid-attention spec; must contain 'layers' and 'num_heads'.
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.norm_bias = norm_bias
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            # Validate the hybrid-attention spec and fill in defaults.
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc ADDED
Binary file (4.24 kB). View file
 
fla/models/rwkv7/modeling_rwkv7.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv7 import RWKV7Attention
20
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class RWKV7FeedForward(nn.Module):
    """RWKV-7 channel-mixing block: token-shift interpolation, key/value projections and `sqrelu` activation."""

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'sqrelu',
        layer_idx: Optional[int] = None
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio)
            # Round up to a multiple of 32 for kernel friendliness.
            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        # Shifts the sequence right by one position (pads one row at the top,
        # drops the last row), giving each token access to its predecessor.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        # Per-channel mixing coefficient between the current and shifted token.
        self.x_k = nn.Parameter(torch.zeros(hidden_size))

        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        state: Optional[Cache] = None
    ) -> Tuple[torch.Tensor, Optional[Cache]]:
        """Apply channel mixing; returns the output tensor and the (possibly updated) cache."""
        if attention_mask is not None:
            # Zero out padded positions before mixing.
            x = x.mul(attention_mask[:, -x.shape[-2]:, None])
        if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
            # Single-token decode: the "shifted" token is the cached last token.
            shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(x)
            if state is not None and state[self.layer_idx]['ffn_state'] is not None:
                # Continue from the cached last token instead of a zero pad.
                shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
        if state is not None:
            # no need to update the offset twice
            state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
        # addcmul computes x + (shifted - x) * x_k, i.e. a per-channel lerp.
        return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state
80
+
81
+
82
class RWKV7Block(nn.Module):
    """One RWKV-7 layer: time-mixing attention (or hybrid softmax attention) followed by channel mixing."""

    def __init__(
        self,
        config: RWKV7Config,
        layer_idx: int
    ) -> None:
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Only the very first layer gets an extra pre-norm when norm_first is set.
        if config.norm_first and layer_idx == 0:
            self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
                config.hidden_size,
                bias=config.norm_bias,
                eps=config.norm_eps
            )
        self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        # Hybrid model: selected layers use regular softmax attention instead of RWKV-7 time mixing.
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = RWKV7Attention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                decay_low_rank_dim=config.decay_low_rank_dim,
                gate_low_rank_dim=config.gate_low_rank_dim,
                a_low_rank_dim=config.a_low_rank_dim,
                v_low_rank_dim=config.v_low_rank_dim,
                norm_eps=config.norm_eps,
                fuse_norm=config.fuse_norm,
                layer_idx=layer_idx,
                value_dim=config.value_dim[layer_idx]
            )
        self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )
        self.ffn = RWKV7FeedForward(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            layer_idx=layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        v_first: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one block; `v_first` threads the first layer's value through all layers (RWKV-7 value residual)."""
        residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
        hidden_states = self.attn_norm(residual)
        hidden_states, attentions, past_key_values, v_first = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            v_first=v_first,
            **kwargs
        )
        if self.config.fuse_norm:
            # Fused LayerNorm also performs the residual add, returning both the
            # normalized states and the updated residual.
            hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.ffn_norm(hidden_states)
        hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values, v_first)

        return outputs
177
+
178
+
179
class RWKV7PreTrainedModel(PreTrainedModel):
    """Base class wiring RWKV-7 models into the transformers loading/initialization machinery."""

    config_class = RWKV7Config
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['RWKV7Block']
    _supports_cache_class = True
    _skip_keys_device_placement = ["past_key_values"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        """Generic fallback initialization; the official RWKV-7 init is NOT reproduced here (see warning)."""
        warnings.warn(
            "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. "
            "The detailed initialization scheme is currently not implemented here but can be found in the "
            "official code repository. We emphasize that using the recommended initialization is essential "
            "for replicating the results in RWKV-7 paper. Deviations from the prescribed initialization "
            "may lead to performance degradation.\n"
            "Alternatively, please generate initial weights from the official RWKV code repository, and "
            "convert the PyTorch checkpoint into FLA supported format."
        )
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Parameter):
            # NOTE(review): nn.Parameter is not an nn.Module, so this branch looks
            # unreachable when `_init_weights` is applied via Module.apply — confirm intent.
            nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
239
+
240
+
241
class RWKV7Model(RWKV7PreTrainedModel):
    """Bare RWKV7 decoder: token embeddings, a stack of ``RWKV7Block``s, and a final norm.

    Attention weights are never produced by this architecture, so any request for
    ``output_attentions`` is warned about and disabled.
    """

    def __init__(self, config: RWKV7Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        # Fused LayerNorm kernel when enabled, otherwise PyTorch's implementation.
        self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
            config.hidden_size,
            bias=config.norm_bias,
            eps=config.norm_eps
        )

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Returns either a plain tuple (``return_dict=False``) or a
        ``BaseModelOutputWithPast``; ``attentions`` is always ``None``/absent.
        """
        # Resolve the config-level default *first* so that a config-level request
        # for attentions is also caught by the unsupported-feature guard below.
        # (The previous ordering forced `False` before the config fallback, which
        # let `config.output_attentions=True` slip through without a warning.)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if output_attentions:
            warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching is pointless during training; default it off there.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        # `v_first` threads the value stream of the first layer through all later
        # layers (RWKV7 value-residual mixing); each block may update it.
        v_first = torch.zeros_like(hidden_states)
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    v_first,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values, v_first = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    v_first=v_first,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
350
+
351
+
352
class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
    """RWKV7 decoder with a tied language-modeling head for next-token prediction."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RWKV7Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily resolved in `forward` (fused vs. plain CE) unless a caller injects one.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to `GenerationMixin.generate`, translating cache-manipulation
        failures into an actionable error message."""
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step.
        # Guard against `past_key_values is None` — `len(None)` would raise TypeError.
        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        shift_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when labels are given, the LM loss.

        When `config.fuse_cross_entropy` is on and the model is training, the
        linear head and cross-entropy are fused and `logits` stays ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        has_labels = (labels is not None) or (shift_labels is not None)
        # With the fused linear+CE path the logits are never materialized.
        if not (fuse_linear_and_cross_entropy and has_labels):
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if has_labels:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion

            # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files.
            if shift_labels is None:
                # Shift labels left by one and pad the tail with ignore_index.
                shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            shift_labels = shift_labels.to(hidden_states.device)

            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/samba/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.samba.configuration_samba import SambaConfig
from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel

# Register Samba with the transformers Auto* factories so that
# `AutoModel.from_pretrained(...)` resolves `model_type="samba"`.
# The third positional argument (`exist_ok=True`) keeps re-imports from
# failing when the type is already registered.
AutoConfig.register(SambaConfig.model_type, SambaConfig, True)
AutoModel.register(SambaConfig, SambaModel, True)
AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True)


__all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock']
fla/models/samba/configuration_samba.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ from typing import Dict, Optional
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
class SambaConfig(PretrainedConfig):
    """Configuration for Samba: a hybrid of Mamba (SSM) layers and sliding-window
    attention layers.

    Notable fields:
        attn: per-model attention spec; `attn['layers']` lists the layer indices
            that use attention instead of a Mamba mixer. The default dict is
            treated as read-only — it is never mutated, so sharing one default
            object across instances is safe here.
        time_step_rank: rank of the dt projection; "auto" resolves to
            `ceil(hidden_size / 16)`.
        intermediate_size: derived as `expand * hidden_size`.
    """

    model_type = "samba"

    def __init__(
        self,
        hidden_size: int = 2304,
        state_size: int = 16,
        num_hidden_layers: int = 18,
        norm_eps: float = 1e-5,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        expand: int = 2,
        conv_kernel: int = 4,
        use_bias: bool = False,
        use_conv_bias: bool = True,
        hidden_act: str = "swish",
        # NOTE: was annotated `str`; the value is a float stddev for weight init.
        initializer_range: float = 0.02,
        residual_in_fp32: bool = False,
        time_step_rank: str = "auto",
        time_step_scale: float = 1.0,
        time_step_min: float = 0.001,
        time_step_max: float = 0.1,
        time_step_init_scheme: str = "random",
        time_step_floor: float = 1e-4,
        max_position_embeddings: int = 2048,
        attn: Optional[Dict] = {
            'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17),
            'num_heads': 18,
            'num_kv_heads': 18,
            'qkv_bias': False,
            'window_size': 2048,
            'rope_theta': 10000.
        },
        hidden_ratio: Optional[int] = 4,
        rescale_prenorm_residual: bool = False,
        use_cache: bool = True,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.conv_kernel = conv_kernel
        self.expand = expand
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.max_position_embeddings = max_position_embeddings
        self.attn = attn
        self.hidden_ratio = hidden_ratio
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
fla/models/samba/modeling_samba.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.utils import ModelOutput, logging
15
+ from transformers.utils.deprecation import deprecate_kwarg
16
+
17
+ from fla.layers.attn import Attention
18
+ from fla.models.mamba.modeling_mamba import MambaCache, MambaMixer
19
+ from fla.models.samba.configuration_samba import SambaConfig
20
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
21
+ from fla.modules import GatedMLP as SambaMLP
22
+ from fla.modules import RMSNorm
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.processing_utils import Unpack
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class SambaBlock(nn.Module):
    """One Samba layer: a pre-norm mixer (attention or Mamba, by layer index)
    followed by a pre-norm gated MLP, each with a residual connection."""

    def __init__(self, config, layer_idx):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Keep the norm implementation consistent with `mlp_norm`:
        # the fused kernel when `fuse_norm` is on, plain PyTorch otherwise.
        # (Previously `mixer_norm` always used the fused RMSNorm regardless of the flag.)
        self.mixer_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.mixer = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.mixer = MambaMixer(config, layer_idx=layer_idx)
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = SambaMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Tuple[torch.Tensor]] = None,
        **kwargs: Unpack[Dict]
    ) -> torch.Tensor:
        # NOTE: the annotation previously claimed a tuple, but a single tensor
        # is (and was) returned.

        residual = hidden_states
        hidden_states = self.mixer_norm(hidden_states)
        # The two mixer kinds have different calling conventions.
        if isinstance(self.mixer, MambaMixer):
            hidden_states = self.mixer(hidden_states, cache_params=cache_params, **kwargs)
        else:
            hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params, **kwargs)
        if self.config.fuse_norm:
            # Fused path: norm adds the residual in-kernel and returns the new residual.
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states
        return hidden_states
81
+
82
+
83
class SambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["SambaBlock"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # Biases marked `_no_reinit` (e.g. dt_proj.bias below) keep their values.
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, MambaMixer):
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max], floored.
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                # Copy in-place instead of assigning an `nn.Parameter` to `.data`.
                module.dt_proj.bias.copy_(inv_dt.to(module.dt_proj.bias.device))
            module.dt_proj.bias._no_reinit = True
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        # BUGFIX: SambaConfig defines `num_hidden_layers`;
                        # `num_layers` does not exist and raised AttributeError.
                        p /= math.sqrt(self.config.num_hidden_layers)
142
+
143
+
144
@dataclass
class SambaOutput(ModelOutput):
    """
    Class for the Samba model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
168
+
169
+
170
@dataclass
class SambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
197
+
198
+
199
class SambaModel(SambaPreTrainedModel):
    """Bare Samba backbone: embeddings, a stack of `SambaBlock`s, and a final RMSNorm.

    Cache handling mirrors HF's Mamba: a single `MambaCache` object is shared
    by all layers and mutated in place; `seqlen_offset` is advanced once per
    forward, after all layers have run.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, SambaOutput]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching defaults off during training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # Gradient checkpointing recomputes activations; incompatible with caching.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if cache_params is None and use_cache:
            # Fresh cache sized from the batch; mutated in place by the layers.
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__,
                    hidden_states,
                    cache_params,
                    **kwargs
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    **kwargs
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            # Advance the shared cache position once for the whole stack.
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return SambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
285
+
286
+
287
class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin):
    """Samba backbone with a language-modeling head for next-token prediction."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = SambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily resolved in `forward` (fused vs. plain CE) unless a caller injects one.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        # Carry the SSM cache between generation steps (replaces `past_key_values`).
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        return model_kwargs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids,
        cache_params: Optional[MambaCache] = None,
        inputs_embeds=None,
        attention_mask=None,
        use_cache: Optional[bool] = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        # only last token for inputs_ids if the state is passed along.
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Only forward `logits_to_keep` when explicitly set; the previous code
        # also put it (possibly as None) in the unconditional update below,
        # which overwrote this guard and made it pointless.
        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'cache_params': cache_params,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, SambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            **kwargs
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        # With the fused linear+CE path the logits are never materialized.
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift labels left by one and pad the tail with ignore_index.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return SambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )
fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (757 Bytes). View file
 
fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.61 kB). View file
 
fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.69 kB). View file
 
fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.8 kB). View file
 
fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc ADDED
Binary file (23.9 kB). View file
 
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (152 Bytes). View file
 
flame/models/__init__.py ADDED
File without changes
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    The order matters: TP first, then AC wrapping, then per-block compile,
    then FSDP/HSDP (or DDP as the fallback data-parallel scheme).

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        # Float8 linear layers need float8-aware TP styles (see TPPlan).
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        # HSDP replicates over one mesh dim and shards over the other;
        # plain FSDP shards over the single dp_shard_cp dim.
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # DDP fallback: only valid when no other parallelism dims are in use.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
109
+
110
+
111
class TPPlan:
    """Builds the `parallelize_module` tensor-parallel plans for a model.

    Selects float8-aware parallel styles when float8 is enabled and the
    torchao float8 TP classes are importable; otherwise falls back to the
    standard DTensor parallel styles. Subclasses provide the token-mixing
    (`attn_plan`) portion per architecture.
    """

    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        # HF-style models expose the inner base model under `base_model_prefix`
        # (defaults to "model"); plan keys are prefixed with it.
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            # torchao not installed (or too old): float8 TP styles unavailable.
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        """Plan for the root modules: embeddings, final norm, and lm_head."""
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    # Keep lm_head output sharded on the vocab dim so the loss
                    # can be computed in parallel without an all-gather.
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            plans.update(
                {
                    # Without loss parallelism, lm_head weights stay replicated
                    # and the fused linear+loss criterion handles parallelism.
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        """Per-block plan: norms run sequence-parallel; mixing layers per subclass."""
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        """Token-mixing plan; must be provided by an architecture-specific subclass."""
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        """Channel-mixing (SwiGLU MLP) plan shared by all subclasses."""
        return {
            # Un-shard the sequence dim before the MLP projections.
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }
202
+
203
+
204
class TransformerTPPlan(TPPlan):
    """TP plan for standard softmax-attention transformer blocks."""

    @property
    def attn_plan(self):
        """Shard q/k/v projections column-wise, output projection row-wise."""
        plan = {
            # Un-shard the sequence dim before the attention projections.
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
218
+
219
+
220
class GLATPPlan(TPPlan):
    """TP plan for GLA (gated linear attention) token-mixing layers."""

    @property
    def attn_plan(self):
        """Shard q/k/v/g projections column-wise; gate-key MLP and output likewise."""
        plan = {
            # Un-shard the sequence dim before the attention projections.
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj", "g_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        # gk_proj is a two-layer stack: replicate the first, shard the second.
        plan["attn.gk_proj.0"] = PrepareModuleWeight(layouts=Replicate())
        plan["attn.gk_proj.1"] = self.colwise_parallel()
        plan["attn.g_norm"] = SequenceParallel(sequence_dim=-1)
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
238
+
239
+
240
# Maps `model.config.model_type` to the TP plan class used in apply_tp.
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    # The plan class is chosen by model architecture (see TP_PLAN_MAP).
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        # The same per-layer plan applies to every transformer block.
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        # Async TP micro-pipelines TP collectives with compute via inductor.
        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
281
+
282
+
283
# for selective op activation checkpointing:
# outputs of these ops are always saved (never recomputed), since they are
# expensive to recompute relative to their memory footprint.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}
294
+
295
+
296
def _apply_ac_to_block(module: nn.Module, ac_config):
    """Wrap a single transformer block with activation checkpointing.

    Returns the checkpoint-wrapped module for "full" mode, an op-selective or
    layer-frequency wrapped module for "selective" mode, or the module
    unchanged when layer-frequency selective AC skips it.
    """
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            # `meta` counts matmuls separately for forward and recompute so the
            # "save every other mm" decision stays consistent across both passes.
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            # Fresh counters per checkpointed region.
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        # Counter is stored on the wrapper function itself so it persists
        # across successive calls for different blocks.
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module
353
+
354
+
355
def apply_ac(model: nn.Module, ac_config):
    """Wrap every transformer block of `model` with activation checkpointing."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    # Re-register each block so the (possibly wrapped) module replaces it in place.
    for layer_id, block in blocks.named_children():
        blocks.register_module(layer_id, _apply_ac_to_block(block, ac_config))

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        # Re-register each compiled block so it replaces the original in place.
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    # Component names vary across architectures; resolve via get_components_name.
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    # lm_head lives on the top-level wrapper, not the inner base model.
    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    # NOTE(review): this only rebinds the local name `model`; the compiled
    # wrapper is discarded and the caller's model is unaffected. Confirm
    # whether whole-model compilation was intended (it would require
    # returning the wrapper to the caller).
    model = torch.compile(model)
404
+
405
+
406
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    Raises:
        ValueError: if `reshard_after_forward_policy` is not one of the three options.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        # Shard each transformer block as its own FSDP unit.
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    # Root FSDP unit covers everything outside the per-block units.
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """Replicate the model across `dp_mesh` with composable DDP."""
    if enable_compile:
        # Select the dynamo/DDP interaction mode that matches the
        # compiled-autograd setting.
        if enable_compiled_autograd:
            ddp_mode = "python_reducer_without_compiled_forward"
        else:
            ddp_mode = "ddp_optimizer"
        torch._dynamo.config.optimize_ddp = ddp_mode

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
487
+
488
+
489
def get_model(model):
    """Return the inner base model (e.g. `model.model`), or None if absent.

    The attribute name is taken from `model.base_model_prefix` when present,
    defaulting to "model" (the HF convention).
    """
    prefix = getattr(model, "base_model_prefix", "model")
    if hasattr(model, prefix):
        return getattr(model, prefix)
    return None
495
+
496
+
497
def get_blocks(model):
    """Return the model's stack of transformer blocks, or None when not found."""
    # TODO[flame]: adapt for network not using 'layers' attribute
    base = get_model(model)
    if hasattr(base, "layers"):
        return base.layers
    logger.warning('no "layers" in model can be found')
    return None
504
+
505
+
506
def get_components_name(model, component_name):
    """
    Resolve the actual attribute name used by `model` for a logical component.

    Supported logical names are "tok_embeddings", "norm" and "lm_head"; each is
    looked up against a list of known aliases in priority order. Blocks are not
    handled here — see `get_blocks`.

    Expected structure (e.g. LlamaForCausalLM):
        Model:
            embed_tokens,
            layers,
            norm,
        lm_head
    so for 'tok_embeddings' and 'norm' pass `get_model(model)`, and for
    'lm_head' pass the top-level `model`.

    Returns the attribute name as a string, or None when no alias matches
    (or the logical name is unknown).
    """
    aliases = {
        "tok_embeddings": ("tok_embeddings", "embed_tokens", "embeddings"),
        "norm": ("norm", "norms", "layernorm"),
        "lm_head": ("lm_head",),
    }
    if component_name not in aliases:
        # Unknown logical names fall through silently, as before.
        return None
    for attr in aliases[component_name]:
        if hasattr(model, attr):
            return attr
    logger.warning(f"No {component_name} found in model")
    return None
flame/tools/__init__.py ADDED
File without changes
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Count model parameters and estimate training FLOPs per token.

    Args:
        model: the model whose parameters are counted.
        model_config: config exposing `num_hidden_layers`, `hidden_size`, and
            either `num_heads` or `num_attention_heads`.
        seq_len: training sequence length (attention FLOPs scale with it).

    Returns:
        (nparams, num_flops_per_token).
    """
    nparams = sum(p.numel() for p in model.parameters())
    # Embedding parameters are excluded from the 6N matmul term below.
    # Search all submodules rather than only direct children: HF-style causal
    # LMs nest the embedding inside the inner base model (e.g. `model.model`),
    # where `model.children()` would never see it.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.modules()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    # but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token
flame/utils/checkpoint.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import re
4
+ import shutil
5
+ from torchtitan.tools.logging import logger
6
+
7
+
8
def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
    """Removes older checkpoint directories locally, keeping only the latest k
    for both DCP (`step-<N>`) and HF (`step-<N>-hf`) formats.

    Args:
        checkpoint_dir: directory containing the `step-*` checkpoint folders.
        keep_latest_k: how many of each format to keep; <= 0 keeps everything.
    """
    if keep_latest_k <= 0:
        return  # Keep all checkpoints

    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")

    # DCP checkpoints: step-<N>, excluding the HF-format step-<N>-hf siblings
    # that the same glob also matches.
    dcp_checkpoints = [
        d for d in glob.glob(os.path.join(checkpoint_dir, "step-*"))
        if not d.endswith("-hf")
    ]
    dcp_checkpoints.sort(key=lambda p: _checkpoint_step(p, r"step-(\d+)"), reverse=True)
    _remove_stale_checkpoints(dcp_checkpoints[keep_latest_k:], "DCP")

    # HF checkpoints: step-<N>-hf
    hf_checkpoints = sorted(
        glob.glob(os.path.join(checkpoint_dir, "step-*-hf")),
        key=lambda p: _checkpoint_step(p, r"step-(\d+)-hf"),
        reverse=True,
    )
    _remove_stale_checkpoints(hf_checkpoints[keep_latest_k:], "HF")


def _checkpoint_step(path: str, pattern: str) -> int:
    """Extract the step index from a checkpoint directory name; -1 if unmatched."""
    match = re.search(pattern, os.path.basename(path))
    return int(match.group(1)) if match else -1


def _remove_stale_checkpoints(checkpoints_to_delete, kind: str):
    """Best-effort removal of the given checkpoint directories (`kind` is a log label)."""
    if not checkpoints_to_delete:
        return
    logger.info(f"Deleting {len(checkpoints_to_delete)} old {kind} checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
    for ckpt_path in checkpoints_to_delete:
        if os.path.isdir(ckpt_path):  # Ensure it's a directory
            try:
                shutil.rmtree(ckpt_path)
            except OSError as e:
                logger.error(f"Error removing directory {ckpt_path}: {e}")
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Convert a HuggingFace checkpoint into DCP format.

    Args:
        model: HF model name or local path accepted by `from_pretrained`.
        checkpoint: output directory for the DCP shards (created if missing).
            A plain string is also accepted and coerced to `Path`.
    """
    logger.info(f"Loading model from {model}")
    # Keep the `model` parameter intact (it is a name/path string); bind the
    # loaded module to its own name instead of shadowing the parameter.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    # `.mkdir` requires a Path; the original annotation (`str`) was wrong.
    checkpoint = Path(checkpoint)
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
if __name__ == "__main__":
    # CLI entry point: convert an HF checkpoint into DCP format.
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--checkpoint", type=Path, required=True)
    args = parser.parse_args()

    convert_hf_weights(args.model, args.checkpoint)
flame/utils/hf_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo
4
+ from torchtitan.tools.logging import logger
5
+
6
def upload_checkpoint_to_hf(
    local_path: str,
    step: int,
    hf_repo_id_for_run: str,
    hf_keep_latest_k: int,
    upload_format: str
):
    """Uploads a checkpoint directory to HF Hub and manages retention.

    Args:
        local_path: local checkpoint directory to upload.
        step: training step; the checkpoint lands under `step-<step>` in the repo.
        hf_repo_id_for_run: target Hub repo id (created if it does not exist).
        hf_keep_latest_k: number of `step-*` folders to keep on the Hub; <= 0 keeps all.
        upload_format: label used only in the commit message (e.g. "dcp" or "hf").
    """
    if not os.path.isdir(local_path):
        logger.error(f"Local path for upload does not exist or is not a directory: {local_path}")
        return

    api = HfApi()
    token = HfFolder.get_token()
    if not token:
        logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.")
        return

    # --- Ensure the specific repository for this run exists ---
    try:
        logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...")
        # Use create_repo which handles creation only if it doesn't exist
        create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True)
        logger.info(f"Repository {hf_repo_id_for_run} ensured.")
    except Exception as e:
        logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True)
        return  # Stop if repo interaction fails

    commit_message = f"Upload {upload_format.upper()} checkpoint step {step}"
    path_in_repo = f"step-{step}"

    logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...")
    try:
        api.upload_folder(
            folder_path=local_path,
            path_in_repo=path_in_repo,
            repo_id=hf_repo_id_for_run,
            repo_type="model",
            commit_message=commit_message,
            token=token,
        )
        logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.")
    except Exception as e:
        logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True)
    # NOTE(review): the retention pass below runs even when the upload above
    # failed — confirm this best-effort behavior is intended.
    if hf_keep_latest_k > 0:
        logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}")
        try:
            repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False)
            # Only top-level folders named exactly step-<digits> are candidates.
            step_folders = [
                item.path for item in repo_files
                if item.path.startswith("step-") and item.path[5:].isdigit()
            ]

            # Newest first; everything past the first k gets deleted.
            step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True)

            if len(step_folders) > hf_keep_latest_k:
                folders_to_delete = step_folders[hf_keep_latest_k:]
                logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}")
                for folder in folders_to_delete:
                    # Deleting requires repo_id, path_in_repo, and token
                    api.delete_folder(
                        repo_id=hf_repo_id_for_run,
                        path_in_repo=folder,
                        repo_type="model",
                        commit_message=f"Delete old checkpoint {folder}",
                        token=token
                    )
                logger.info("Hub cleanup complete.")
            else:
                logger.info("No old checkpoints found on Hub to delete.")
        except Exception as e:
            logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True)
torchtitan/components/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
torchtitan/components/__pycache__/ft.cpython-312.pyc ADDED
Binary file (6.75 kB). View file
 
torchtitan/components/__pycache__/loss.cpython-312.pyc ADDED
Binary file (1.5 kB). View file
 
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB). View file
 
torchtitan/components/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (1.09 kB). View file
 
torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc ADDED
Binary file (7.03 kB). View file
 
torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc ADDED
Binary file (7.73 kB). View file
 
torchtitan/datasets/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+ import os
11
+ from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
12
+ from pathlib import Path
13
+ from typing import cast, Literal
14
+
15
+ import tiktoken
16
+ from tiktoken.load import load_tiktoken_bpe
17
+
18
+ from torchtitan.components.tokenizer import Tokenizer
19
+ from torchtitan.config_manager import JobConfig
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Wraps a ``tiktoken.Encoding`` built from a Llama-3-style BPE model file,
    registering the Llama 3 special tokens on top of the base vocabulary.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    # Mapping from special-token string (e.g. "<|eot_id|>") to its token ID.
    special_tokens: dict[str, int]

    # Llama 3 reserves a fixed budget of special-token slots on top of the
    # base BPE vocabulary; unused slots become <|reserved_special_token_N|>.
    num_reserved_special_tokens = 256

    # Pre-tokenization split pattern (contractions, words, digit groups,
    # punctuation runs, whitespace) — must match the model's training regex.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__()
        # Validate eagerly with real exceptions rather than `assert`, which is
        # silently stripped when Python runs with the -O flag.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"The tokenizer path does not exist: {model_path}"
            )
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"The tokenizer path is not a file: {model_path}"
            )

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # 10 named special tokens + 246 reserved ones = 256 total
        # (num_reserved_special_tokens).
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special-token IDs are appended directly after the base vocabulary.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # Total vocabulary size (base + special); presumably surfaced by the
        # base Tokenizer class as `self.n_words` — confirm against its definition.
        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        # No dedicated padding token in the vocabulary; -1 marks "unset".
        self.pad_id: int = -1
        # Tokens at which generation should stop (end of text / end of turn).
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] | None = None,
    ) -> list[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        # Default both to the permissive "ignore special tokens" behavior.
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily slice the input into chunks small enough for tiktoken,
        # further split so no chunk has an overly long uniform run.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: list[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(list[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # XOR: the whitespace/non-whitespace run type flipped, so the
            # consecutive-run counter restarts at this character.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run exceeded the cap: emit everything before this char
                    # and start a new slice here.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        # Emit the final (possibly empty-input) remainder.
        yield s[slice_start:]
187
+
188
+
189
def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    """Construct a :class:`TikTokenizer` from the tokenizer path in *job_config*."""
    tokenizer_path = job_config.model.tokenizer_path
    return TikTokenizer(tokenizer_path)
torchtitan/distributed/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (251 Bytes). View file
 
torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc ADDED
Binary file (6.11 kB). View file
 
torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.82 kB). View file