Firworks committed on
Commit
9bedf50
·
verified ·
1 Parent(s): d31ba5b

Add NVFP4 quantized checkpoint

Browse files
README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - HuggingFaceH4/ultrachat_200k
4
+ base_model:
5
+ - allenai/Bolmo-7B
6
+ ---
7
+ # Bolmo-7B-nvfp4
8
+
9
+ **Format:** NVFP4 — weights & activations quantized to FP4 with dual scaling.
10
+ **Base model:** `allenai/Bolmo-7B`
11
+ **How it was made:** One-shot calibration with LLM Compressor (NVFP4 recipe), long-seq calibration with HuggingFaceH4/ultrachat_200k.
12
+
13
+ > Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.
14
+
15
+ Check the original model card for information about this model.
16
+
17
+ # Running the model with vLLM in Docker
18
+ ```sh
19
+ sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Bolmo-7B-nvfp4 --dtype auto --max-model-len 32768
20
+ ```
21
+ This was tested on an RTX Pro 6000 Blackwell cloud instance.
22
+
23
+ If there are other models you're interested in seeing quantized to NVFP4 for use on the DGX Spark or other modern Blackwell (or newer) cards, let me know. I'm trying to make more NVFP4 models available to allow more people to try them out.
config.json ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_expanded_embeddings": true,
3
+ "architectures": [
4
+ "BolmoForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_bolmo.BolmoConfig",
10
+ "AutoModelForCausalLM": "modeling_bolmo.BolmoForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "boundary_predictor_lookahead": 1,
14
+ "boundary_threshold": "sample:0",
15
+ "dtype": "float32",
16
+ "eos_token_id": 1,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 4096,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 11008,
21
+ "layer_types": [
22
+ "sliding_attention",
23
+ "sliding_attention",
24
+ "sliding_attention",
25
+ "full_attention",
26
+ "sliding_attention",
27
+ "sliding_attention",
28
+ "sliding_attention",
29
+ "full_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "sliding_attention",
33
+ "full_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "full_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "sliding_attention",
41
+ "full_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "full_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "full_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "sliding_attention",
53
+ "full_attention"
54
+ ],
55
+ "local_intermediate_size": 5504,
56
+ "local_rms_norm_eps": 1e-05,
57
+ "max_position_embeddings": 65536,
58
+ "model_type": "bolmo",
59
+ "num_attention_heads": 32,
60
+ "num_hidden_layers": 32,
61
+ "num_key_value_heads": 32,
62
+ "num_local_decoder_layers": 4,
63
+ "num_local_encoder_layers": 1,
64
+ "num_local_heads": 16,
65
+ "pad_token_id": 0,
66
+ "quantization_config": {
67
+ "config_groups": {
68
+ "group_0": {
69
+ "format": "nvfp4-pack-quantized",
70
+ "input_activations": {
71
+ "actorder": null,
72
+ "block_structure": null,
73
+ "dynamic": "local",
74
+ "group_size": 16,
75
+ "num_bits": 4,
76
+ "observer": "minmax",
77
+ "observer_kwargs": {},
78
+ "strategy": "tensor_group",
79
+ "symmetric": true,
80
+ "type": "float"
81
+ },
82
+ "output_activations": null,
83
+ "targets": [
84
+ "Linear"
85
+ ],
86
+ "weights": {
87
+ "actorder": null,
88
+ "block_structure": null,
89
+ "dynamic": false,
90
+ "group_size": 16,
91
+ "num_bits": 4,
92
+ "observer": "minmax",
93
+ "observer_kwargs": {},
94
+ "strategy": "tensor_group",
95
+ "symmetric": true,
96
+ "type": "float"
97
+ }
98
+ }
99
+ },
100
+ "format": "nvfp4-pack-quantized",
101
+ "global_compression_ratio": null,
102
+ "ignore": [
103
+ "lm_head"
104
+ ],
105
+ "kv_cache_scheme": null,
106
+ "quant_method": "compressed-tensors",
107
+ "quantization_status": "compressed",
108
+ "sparsity_config": {},
109
+ "transform_config": {},
110
+ "version": "0.12.2"
111
+ },
112
+ "rms_norm_eps": 1e-06,
113
+ "rope_scaling": {
114
+ "attention_factor": 1.2079441541679836,
115
+ "beta_fast": 32,
116
+ "beta_slow": 1,
117
+ "factor": 8.0,
118
+ "original_max_position_embeddings": 8192,
119
+ "rope_type": "yarn"
120
+ },
121
+ "rope_theta": 500000,
122
+ "sliding_window": 4096,
123
+ "subword_vocab_size": 100278,
124
+ "tie_word_embeddings": false,
125
+ "tokenizer_config": {
126
+ "bos_token_id": 1,
127
+ "bpe_token_end_id": 3,
128
+ "eos_token_id": 1,
129
+ "original_identifier": "allenai/dolma2-tokenizer",
130
+ "pad_token_id": 0,
131
+ "special_tokens": [
132
+ "<pad>",
133
+ "<bos>",
134
+ "<eos>",
135
+ "<bpe_token_end>"
136
+ ],
137
+ "special_tokens_first": true,
138
+ "vocab_size": 520
139
+ },
140
+ "transformers_version": "4.57.3",
141
+ "use_cache": true,
142
+ "vocab_size": 520
143
+ }
configuration_bolmo.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict
2
+ from typing import Any
3
+
4
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
5
+ from transformers.modeling_rope_utils import rope_config_validation
6
+ from .tokenization_bolmo import BolmoTokenizerConfig
7
+
8
class BolmoConfig(PretrainedConfig):
    r"""
    Configuration class for Bolmo models (`model_type = "bolmo"`).

    It stores the hyperparameters of the main Transformer stack together with the Bolmo-specific options
    (local encoder/decoder sizes, boundary predictor settings, subword vocabulary size and tokenizer
    configuration). The defaults documented below are the defaults of `__init__`.

    Configuration objects inherit from [`PretrainedConfig`]. Read the documentation of [`PretrainedConfig`]
    for the options shared by all configurations.

    Args:
        vocab_size (`int`, *optional*, defaults to 520):
            Vocabulary size of the model, i.e. the number of different token ids accepted as `input_ids`.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads used to implement Grouped Query Attention. With
            `num_key_value_heads=num_attention_heads` the model uses Multi Head Attention (MHA), with
            `num_key_value_heads=1` it uses Multi Query Attention (MQA), otherwise GQA. For details see
            [this paper](https://huggingface.co/papers/2305.13245). Defaults to `num_attention_heads` when
            not specified.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings; it is checked with
            `rope_config_validation`. NOTE: if you apply a new rope type and expect the model to work on a
            longer `max_position_embeddings`, update that value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. One of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE
                    embeddings. In most scaling types, a `factor` of x enables the model to handle sequences
                    of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used
                    during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor applied on the attention computation.
                    If unspecified, it defaults to the value recommended by the implementation, inferred from
                    `factor`.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for extrapolation (only) in the linear ramp function.
                    If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for interpolation (only) in the linear ramp function.
                    If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to short contexts
                    (< `original_max_position_embeddings`). Must be a list of numbers with length equal to the
                    hidden size divided by the number of attention heads divided by 2.
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to long contexts
                    (> `original_max_position_embeddings`). Must be a list of numbers with length equal to the
                    hidden size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window for sliding window attention.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Defaults to sliding window attention for 3 out of 4 layers and
            full attention for every 4th layer.
        add_expanded_embeddings (`bool`, *optional*, defaults to `True`):
            Bolmo-specific flag; it is stored as-is and interpreted by the modeling code.
        boundary_predictor_lookahead (`int`, *optional*, defaults to 1):
            Lookahead setting of the boundary predictor.
        boundary_threshold (`str`, *optional*, defaults to `"sample:0"`):
            Threshold specification of the boundary predictor; stored as-is and interpreted by the modeling code.
        num_local_encoder_layers (`int`, *optional*, defaults to 1):
            Number of local encoder layers.
        num_local_decoder_layers (`int`, *optional*, defaults to 4):
            Number of local decoder layers.
        num_local_heads (`int`, *optional*, defaults to 16):
            Number of heads used by the local layers.
        local_intermediate_size (`int`, *optional*, defaults to 5504):
            Intermediate size used by the local layers.
        local_rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers of the local layers.
        subword_vocab_size (`int`, *optional*, defaults to 100278):
            Subword vocabulary size of the dolma2 tokenizer.
        tokenizer_config (`BolmoTokenizerConfig` or `dict`, *optional*):
            Tokenizer configuration. A `BolmoTokenizerConfig` is converted to a plain `dict`; a `dict` is stored
            unchanged; `None` selects `BolmoTokenizerConfig.bolmo()`.

    ```python
    >>> # Initializing a Bolmo configuration with the default arguments
    >>> configuration = BolmoConfig()

    >>> # Every 4th layer uses full attention by default
    >>> configuration.layer_types[:4]
    ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
    ```
    """

    model_type = "bolmo"
    keys_to_ignore_at_inference = ["past_key_values"]
    # The attention projections are replicated because q and k are normalized after projection.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_rep",
        "layers.*.self_attn.k_proj": "colwise_rep",
        "layers.*.self_attn.v_proj": "colwise_rep",
        "layers.*.self_attn.o_proj": "rowwise_rep",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=520,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        sliding_window=4096,
        layer_types=None,
        # bolmo config
        add_expanded_embeddings: bool = True,
        boundary_predictor_lookahead: int = 1,
        boundary_threshold: str = "sample:0",
        num_local_encoder_layers: int = 1,
        num_local_decoder_layers: int = 4,
        num_local_heads: int = 16,
        local_intermediate_size: int = 5504,
        local_rms_norm_eps=1e-5,
        subword_vocab_size: int = 100278,  # dolma2_tokenizer subword vocab size
        tokenizer_config: BolmoTokenizerConfig | dict[str, Any] | None = None,
        **kwargs,
    ):
        # Special token ids, weight tying and any remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Sizes of the main Transformer stack.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # For backward compatibility: no explicit KV head count means one KV head per attention head.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Rotary embedding settings; validated once both fields are set.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps
        self.sliding_window = sliding_window

        # Default attention pattern: layers 4, 8, 12, ... use full attention, all others a sliding window.
        if layer_types is None:
            layer_types = [
                "full_attention" if layer_number % 4 == 0 else "sliding_attention"
                for layer_number in range(1, self.num_hidden_layers + 1)
            ]
        self.layer_types = layer_types
        layer_type_validation(self.layer_types)

        # bolmo configuration
        self.add_expanded_embeddings = add_expanded_embeddings
        self.boundary_predictor_lookahead = boundary_predictor_lookahead
        self.boundary_threshold = boundary_threshold
        self.num_local_encoder_layers = num_local_encoder_layers
        self.num_local_decoder_layers = num_local_decoder_layers
        self.num_local_heads = num_local_heads
        self.local_intermediate_size = local_intermediate_size
        self.local_rms_norm_eps = local_rms_norm_eps
        self.subword_vocab_size = subword_vocab_size

        # The tokenizer configuration is kept as a plain dict; dataclass inputs are converted.
        if isinstance(tokenizer_config, BolmoTokenizerConfig):
            self.tokenizer_config = asdict(tokenizer_config)
        elif tokenizer_config is None:
            self.tokenizer_config = asdict(BolmoTokenizerConfig.bolmo())
        else:
            self.tokenizer_config = tokenizer_config

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration by delegating to `rope_config_validation`.
        """
        rope_config_validation(self)
234
+
235
+ __all__ = ["BolmoConfig"]
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "do_sample": true,
4
+ "eos_token_id": 50279,
5
+ "pad_token_id": 1,
6
+ "transformers_version": "4.57.3"
7
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99b1a3015951f3907501cc07ed492c6e60e5ade94b0fe62faf09f15d8ef58c87
3
+ size 4979678176
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff6107aa16c14af67f8169fb536c4c94b70e8d0a48b11bb67f6d6e2ce5d829a2
3
+ size 742727128
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_bolmo.py ADDED
@@ -0,0 +1,1351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Callable, Optional, Union, cast
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+
9
+ from transformers.utils.generic import TransformersKwargs
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache, DynamicCache
13
+ from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
14
+ from transformers.generation.utils import GenerateOutput
15
+ from transformers.integrations import use_kernel_forward_from_hub
16
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
17
+ from transformers.modeling_layers import GradientCheckpointingLayer
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import can_return_tuple
23
+ from transformers.utils.deprecation import deprecate_kwarg
24
+ from transformers.utils.generic import check_model_inputs
25
+
26
+ from .configuration_bolmo import BolmoConfig
27
+ from .tokenization_bolmo import BolmoTokenizerConfig
28
+ from .utils_bolmo import compute_boundary_mask, pad_right, pad_left, MaskState
29
+
30
+ try:
31
+ from xlstm.xlstm_large.model import mLSTMLayer, mLSTMLayerConfig, mLSTMLayerStateType, soft_cap, mLSTMBackendConfig
32
+ except ImportError:
33
+ raise ImportError("The `xlstm` package is required to use Bolmo. Please install it via `pip install xlstm`.")
34
+
35
+
36
@use_kernel_forward_from_hub("RMSNorm")
class BolmoRMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        BolmoRMSNorm is equivalent to T5LayerNorm.

        Args:
            hidden_size: Size of the last (normalized) dimension; also the shape of the scale parameter.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize over the last dimension in float32 and return a tensor in the input dtype."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input precision.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return (self.weight * normalized).to(original_dtype)

    def extra_repr(self):
        """Return the scale shape and epsilon for the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
55
+
56
+
57
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen,
    head_dim). With `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking up front means a non-4D input raises even when no repetition is requested.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then merge it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
67
+
68
+
69
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Args:
        module: Attention module; its `num_key_value_groups` and `training` attributes are read.
        query: Query states.
        key: Key states; heads are repeated `module.num_key_value_groups` times to match the query heads.
        value: Value states; repeated the same way as `key`.
        attention_mask: Optional additive mask. It is cropped along its last dimension to the key length
            and added to the attention scores.
        scaling: Factor the raw query-key scores are multiplied by.
        dropout: Dropout probability applied to the attention weights (active only when `module.training`).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A tuple `(attn_output, attn_weights)`, where `attn_output` has its dimensions 1 and 2 swapped
        relative to the query layout and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may be longer than the current key length, so crop before adding.
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
93
+
94
+
95
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `q` and `k`
            have the shape [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q` and `k` have
            the shape [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, each cast back to its original dtype.
    """
    query_dtype, key_dtype = q.dtype, k.dtype
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard RoPE combination of the tensor with its half-rotated counterpart.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q).to(query_dtype), _rotate(k).to(key_dtype)
121
+
122
+
123
def rotate_half(x):
    """Rotate half the hidden dims of the input: the negated back half is placed before the front half."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
128
+
129
+
130
class BolmoAttention(nn.Module):
    """Multi-headed attention with RMS-normalized queries/keys and rotary position embeddings.

    The query and key projections are normalized over their full projection width before being split
    into heads. Each layer is either a full-attention or a sliding-window layer, as selected by
    `config.layer_types[layer_idx]`.
    """

    def __init__(self, config: BolmoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        # An explicit `config.head_dim` takes precedence over the even split of the hidden size.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // num_heads)
        self.num_key_value_groups = num_heads // num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        use_bias = config.attention_bias
        query_width = num_heads * self.head_dim
        kv_width = num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)
        self.q_norm = BolmoRMSNorm(query_width, config.rms_norm_eps)
        self.k_norm = BolmoRMSNorm(kv_width, config.rms_norm_eps)

        assert config.layer_types is not None
        self.attention_type = config.layer_types[layer_idx]
        # Only sliding-attention layers carry a window size; full-attention layers use None.
        is_sliding = self.attention_type == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input states; the last dimension is the feature dimension.
            position_embeddings: `(cos, sin)` pair used for the rotary position embedding.
            attention_mask: Optional mask forwarded to the attention implementation.
            past_key_values: Optional cache; when given, it is updated with this layer's key/value states
                and the states it returns are used for attention.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has passed through the output projection.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Project (and normalize q/k), then split into heads and move the head axis before the sequence axis.
        query_states = self.q_norm(self.q_proj(hidden_states)).view(per_head_shape).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states)).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Use the eager implementation unless the config selects another registered one.
        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge the heads back into a single feature dimension before the output projection.
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
209
+
210
+
211
class BolmoMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        # The activation is looked up by name from `config.hidden_act`.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to `x` and return a tensor with the hidden size restored."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
225
+
226
+
227
class BolmoDecoderLayer(GradientCheckpointingLayer):
    """Transformer decoder layer: self-attention and MLP, each normalized before its residual addition."""

    def __init__(self, config: BolmoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BolmoAttention(config=config, layer_idx=layer_idx)

        self.mlp = BolmoMLP(config)
        norm_eps = config.rms_norm_eps
        self.post_attention_layernorm = BolmoRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply the attention and feed-forward sub-layers to `hidden_states`.

        All arguments other than `hidden_states` are passed through to the self-attention module.
        `position_embeddings` is declared optional here, but the attention module requires it.

        Returns:
            The updated hidden states.
        """
        # Attention sub-layer: the norm is applied to the attention output, then the input is added back.
        attn_out, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attention = hidden_states + self.post_attention_layernorm(attn_out)

        # Fully connected sub-layer, using the same normalize-then-add pattern.
        mlp_out = self.mlp(after_attention)
        return after_attention + self.post_feedforward_layernorm(mlp_out)
270
+
271
+
272
class BolmoBoundaryPredictor(nn.Module):
    """Scores each position as a patch boundary from query/key cosine similarity.

    The log-probability of a boundary is ``log((1 - cos_sim) / 2)``, where
    ``cos_sim`` compares a projected query at one position with a projected
    key at a later position (offset by ``boundary_predictor_lookahead``, or by
    one when the lookahead is zero).
    """

    def __init__(self, config: BolmoConfig):
        super().__init__()

        self.d_model = config.hidden_size
        self.boundary_threshold = config.boundary_threshold
        self.boundary_predictor_lookahead = config.boundary_predictor_lookahead
        self.q_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        sequence_start_indices: Optional[torch.Tensor] = None,
        epsilon: float = 1e-3,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(boundary_logprobs, boundary_mask)`` for ``hidden_states``.

        Args:
            hidden_states: indexed as ``[batch, position, feature]``.
            sequence_start_indices: optional per-row index of the first real
                position. Earlier positions are masked out and the start
                position is forced to be a boundary.
            epsilon: cosine similarity is clipped to at most ``1 - epsilon``
                before the logarithm.
        """
        lookahead = self.boundary_predictor_lookahead
        no_lookahead = lookahead == 0

        # With no lookahead a slice like ``[:-0]`` would be empty, so compare
        # each position with its successor instead (as in H-Net) and pad on
        # the left afterwards.
        shift = 1 if no_lookahead else lookahead
        queries = F.normalize(self.q_proj_layer(hidden_states[:, :-shift]), dim=-1)
        keys = F.normalize(self.k_proj_layer(hidden_states[:, shift:]), dim=-1)
        cos_sim = torch.einsum("b l d, b l d -> b l", queries, keys)

        if no_lookahead:
            # A similarity of -1 at the first position maps to log-prob 0.
            left_pad = torch.ones(
                (hidden_states.shape[0], 1), device=hidden_states.device, dtype=hidden_states.dtype
            ) * -1
            cos_sim = torch.cat([left_pad, cos_sim], dim=1)

        boundary_logprobs = torch.log1p(-cos_sim.float().clip(max=1.0 - epsilon)) - math.log(2)

        POSITIVE_LOGPROB = 0.0
        NEGATIVE_LOGPROB = -100_000
        if sequence_start_indices is not None:
            positions = torch.arange(boundary_logprobs.shape[1], device=boundary_logprobs.device)
            pad_mask = positions[None, :] < sequence_start_indices[:, None]
            boundary_logprobs = boundary_logprobs.masked_fill(pad_mask, NEGATIVE_LOGPROB)
            rows = torch.arange(len(boundary_logprobs), device=boundary_logprobs.device)
            boundary_logprobs[rows, sequence_start_indices] = POSITIVE_LOGPROB
        else:
            boundary_logprobs[:, 0] = POSITIVE_LOGPROB

        # The trailing ``lookahead`` positions have no key to compare against;
        # fill them with the negative value so the output length matches the input.
        boundary_logprobs = F.pad(boundary_logprobs, (0, lookahead), "constant", NEGATIVE_LOGPROB)
        boundary_mask = compute_boundary_mask(boundary_logprobs, self.boundary_threshold)

        return boundary_logprobs, boundary_mask
318
+
319
+
320
class BolmoXLSTMLayer(mLSTMLayer):
    """mLSTM layer with support for left-padded batches and a masked state cache."""

    def __init__(self, config: BolmoConfig):
        backend_config = mLSTMBackendConfig(
            chunkwise_kernel="chunkwise--triton_limit_chunk",
            sequence_kernel="native_sequence__triton",
            step_kernel="triton",
            mode="train",
            return_last_states=True,
            autocast_kernel_dtype="float32",
        )
        layer_config = mLSTMLayerConfig(
            embedding_dim=config.hidden_size,
            num_heads=config.num_local_heads,
            mlstm_backend=backend_config,
        )
        super().__init__(layer_config)

    # original forward adapted to support sequence_start_indices
    # i.e. set the forget gate to zero at the start of sequence
    def _original_forward(
        self, x: torch.Tensor,
        state: mLSTMLayerStateType | None = None,
        sequence_start_indices: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
        """Run the layer on ``x`` and return ``(output, new_state)``.

        When ``sequence_start_indices`` is given, the forget-gate
        pre-activation at each row's start position is overwritten with a
        large negative value before the backend is called.
        """
        assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
        B, S, _ = x.shape

        if self.config.weight_mode == "single":
            q = self.q(x)
            k = self.k(x)
            v = self.v(x)
            o_preact = self.ogate_preact(x)
            i_preact = soft_cap(
                self.igate_preact(x), cap_value=self.config.gate_soft_cap
            )
            f_preact = soft_cap(
                self.fgate_preact(x), cap_value=self.config.gate_soft_cap
            )
        elif self.config.weight_mode == "fused":
            fused = self.qkv_opreact(x)
            split_points = (
                self.qk_dim,
                2 * self.qk_dim,
                2 * self.qk_dim + self.v_dim,
            )
            q, k, v, o_preact = torch.tensor_split(fused, split_points, dim=-1)

            if_preact = soft_cap(
                self.ifgate_preact(x), cap_value=self.config.gate_soft_cap
            )
            i_preact, f_preact = torch.tensor_split(
                if_preact, (self.config.num_heads,), dim=-1
            )
        else:
            raise ValueError(f"Unknown weight_mode: {self.config.weight_mode}")

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, S, H * d] -> [B, H, S, d]
            return t.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        if sequence_start_indices is not None:
            batch_rows = torch.arange(B, device=f_preact.device)
            f_preact[batch_rows, sequence_start_indices] = -100_000

        i_preact = i_preact.transpose(1, 2)
        f_preact = f_preact.transpose(1, 2)

        c_initial, n_initial, m_initial = (None, None, None) if state is None else state

        h, state = self.mlstm_backend(
            q=q,
            k=k,
            v=v,
            i=i_preact,
            f=f_preact,
            c_initial=c_initial,
            n_initial=n_initial,
            m_initial=m_initial,
        )
        expected_h_shape = (
            B,
            self.config.num_heads,
            S,
            self.v_dim // self.config.num_heads,
        )
        assert (
            h.shape == expected_h_shape
        ), f"Got {h.shape}, expected {expected_h_shape}"

        h_norm = self.multihead_norm(h.transpose(1, 2)).reshape(B, S, -1)
        gated = self.ogate_act_fn(o_preact) * h_norm

        return self.out_proj(gated), state

    def forward(  # type: ignore
        self,
        x: torch.Tensor,
        past_key_values: Optional[dict] = None,
        use_cache: bool = False,
        sequence_start_indices: Optional[torch.Tensor] = None,
        cache_mask: Optional[MaskState] = None
    ):
        """Run the layer, optionally reading and updating ``past_key_values["state"]``.

        With ``use_cache`` the recurrent state is taken from and written back
        to ``past_key_values``. If ``cache_mask`` is also given, only the rows
        it selects (with ``inv=True``) are read and updated.
        """
        self.mlstm_backend.config.mode = "train" if self.training else "inference"

        if not use_cache:
            # NOTE(review): this path does not forward sequence_start_indices
            # to the parent implementation; confirm that is intended.
            h, _ = super().forward(x)
            return h

        assert past_key_values is not None

        prev_mode = self.mlstm_backend.config.mode
        state = past_key_values.get("state", None)

        if cache_mask is None:
            state_for_model = state
        else:
            state_for_model = cast(
                mLSTMLayerStateType,
                tuple(cache_mask.selective_get(part, inv=True) for part in state) if state is not None else None,
            )

        h, new_state = self._original_forward(
            x,
            state=state_for_model,
            sequence_start_indices=sequence_start_indices
        )
        assert new_state is not None

        if state is None or cache_mask is None:
            state = new_state
        else:
            # Write the stepped rows back into the stored state tensors.
            for fresh, stored in zip(new_state, state):
                cache_mask.selective_put(fresh, stored, inv=True)

        past_key_values["state"] = state
        self.mlstm_backend.config.mode = prev_mode

        return h
464
+
465
class BolmoLocalLayer(nn.Module):
    """Pre-norm residual layer built from an xLSTM mixer and a gated MLP.

    The MLP uses ``config.local_intermediate_size`` as its inner width; this
    is done on a deep copy of ``config`` so the caller's config is not mutated.
    """

    def __init__(self, config: BolmoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.act_fn = ACT2FN[config.hidden_act]

        self.xlstm = BolmoXLSTMLayer(config)

        mlp_config = copy.deepcopy(config)
        mlp_config.intermediate_size = config.local_intermediate_size
        self.mlp = BolmoMLP(mlp_config)

        self.pre_xlstm_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        sequence_start_indices: Optional[torch.Tensor] = None,
        past_key_values: Optional[dict] = None,
        use_cache: Optional[bool] = False,
        cache_mask: Optional[MaskState] = None,
    ) -> torch.Tensor:
        """Apply the xLSTM and feed-forward branches; ``past_key_values["xlstm"]`` holds the mixer cache."""
        xlstm_cache = None if past_key_values is None else past_key_values["xlstm"]

        # Sequence-mixing branch.
        mixed = self.xlstm(
            self.pre_xlstm_layernorm(hidden_states),
            sequence_start_indices=sequence_start_indices,
            past_key_values=xlstm_cache,
            use_cache=use_cache,
            cache_mask=cache_mask,
        )
        after_mixer = hidden_states + mixed

        # Fully Connected
        ffn_out = self.mlp(self.pre_feedforward_layernorm(after_mixer))
        return after_mixer + ffn_out
500
+
501
+
502
class BolmoLocalEncoder(nn.Module):
    """Byte-level encoder that produces per-byte states and pooled patch embeddings.

    Bytes are embedded (optionally with an added subword embedding), passed
    through a stack of ``BolmoLocalLayer`` blocks, and the states at boundary
    positions are gathered and projected to form the patch embeddings.
    """

    def __init__(self, config: BolmoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.add_expanded_embeddings = config.add_expanded_embeddings

        self.byte_embedding = nn.Embedding(
            config.vocab_size,
            self.hidden_size,
        )
        self.subword_embedding = (
            nn.Embedding(config.subword_vocab_size, self.hidden_size)
            if self.add_expanded_embeddings
            else None
        )

        self.layers = nn.ModuleList(
            [BolmoLocalLayer(config) for _ in range(config.num_local_encoder_layers)]
        )

        self.post_last_block_norm = BolmoRMSNorm(
            self.hidden_size,
            config.local_rms_norm_eps,
        )
        self.out_projection = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
        )

        self.boundary_predictor_module = BolmoBoundaryPredictor(config)

        self.has_cache = False

    def prepare_inference_cache(self, batch_size: int):
        """Allocate the step-decoding cache for ``batch_size`` rows."""
        device = next(self.parameters()).device
        self.has_cache = True

        self.cache_seqlens = 0
        self.last_h = torch.zeros(
            (batch_size, self.hidden_size), dtype=self.out_projection.weight.dtype, device=device
        )
        self.layer_states = [{"xlstm": {}} for _ in self.layers]

    def free_inference_cache(self):
        """Drop every cache attribute that ``prepare_inference_cache`` created."""
        self.has_cache = False
        for attr in ("cache_seqlens", "last_h", "layer_states"):
            if hasattr(self, attr):
                delattr(self, attr)

    def _embed(self, tokens, expanded_input_ids: Optional[torch.Tensor] = None):
        """Embed byte ids, adding the subword embedding when it is enabled."""
        embeddings = self.byte_embedding(tokens)
        if not self.add_expanded_embeddings:
            return embeddings

        assert expanded_input_ids is not None and self.subword_embedding is not None
        return embeddings + self.subword_embedding(expanded_input_ids)

    def _pool(
        self,
        h: torch.Tensor,
        boundary_mask: torch.Tensor | None,
        n_patches: int,
        boundary_state: Optional[MaskState] = None,
    ):
        """Select the states at boundary positions.

        During cached stepping the single current state is returned when
        ``boundary_state.all()`` is true, otherwise an empty batch. During
        prefill the first ``n_patches`` boundary positions of each row are
        gathered in order.
        """
        if self.has_cache and self.cache_seqlens > 0:
            assert boundary_state is not None
            if not boundary_state.all():
                return h[[], :, :]
            assert h.shape[1] == 1
            return h

        assert boundary_mask is not None

        L = h.shape[1]
        # Non-boundary positions are pushed past L so a sort puts boundary
        # positions first while keeping their original order.
        sort_keys = (
            torch.arange(L, device=h.device)[None, :] + (~boundary_mask).long() * L  # type: ignore
        )
        order = torch.argsort(sort_keys, dim=1)
        gather_index = order[:, :n_patches, None].expand(-1, -1, h.shape[-1])

        return torch.gather(h, dim=1, index=gather_index)

    def forward(
        self,
        input_ids,
        true_boundary_mask: Optional[torch.Tensor] = None,
        boundary_state: Optional[MaskState] = None,
        pad_state: Optional[MaskState] = None,
        expanded_input_ids: Optional[torch.Tensor] = None,
        sequence_start_indices: Optional[torch.Tensor] = None,
    ):
        """Encode ``input_ids``.

        Returns:
            ``(h, patch_embeddings, boundary_logprobs, boundary_mask)``. The
            last two are ``None`` on cached decoding steps unless
            ``true_boundary_mask`` is given, which always overrides the
            predicted mask.
        """
        embeddings = self._embed(input_ids, expanded_input_ids)
        stepping = self.has_cache and self.cache_seqlens > 0

        # pass through encoder layers
        if stepping:
            assert pad_state is not None

            # step those batch positions which are not currently idle (i.e. at a boundary position)
            # if all batch positions are idle, skip the step entirely
            # all positions being idle only happens if fuse_boundaries=False. In this case, the step where we
            # obtain a new representation from the global model will have all positions for the local encoder being idle.
            if not pad_state.all():
                h = pad_state.selective_get(embeddings, inv=True)

                for layer_idx, block in enumerate(self.layers):
                    h = block(h, past_key_values=self.layer_states[layer_idx], use_cache=True, cache_mask=pad_state)

                if self.post_last_block_norm is not None:
                    h = self.post_last_block_norm(h)

                pad_state.selective_put(h[:, -1, :], self.last_h, inv=True)

            h = self.last_h.unsqueeze(1)
        else:
            h = embeddings
            for layer_idx, block in enumerate(self.layers):
                caching = True if self.has_cache else False
                layer_cache = self.layer_states[layer_idx] if caching else None
                h = block(
                    h,
                    past_key_values=layer_cache,
                    use_cache=caching,
                    sequence_start_indices=sequence_start_indices,
                )

            if self.post_last_block_norm is not None:
                h = self.post_last_block_norm(h)

            if self.has_cache:
                self.last_h.copy_(h[:, -1, :])

        if stepping:
            boundary_logprobs = boundary_mask = None
        else:  # only used for prefill
            boundary_logprobs, boundary_mask = self.boundary_predictor_module(
                h,
                sequence_start_indices=sequence_start_indices,
            )
            if boundary_state is not None:
                # can't predict through encoder - must be through prev local decoder step
                boundary_mask[:, -1] = boundary_state.mask

        # overwrite with true boundaries
        if true_boundary_mask is not None:
            boundary_mask = true_boundary_mask

        if boundary_mask is None:
            n_patches = 1
        else:
            n_patches = int(cast(torch.Tensor, boundary_mask).sum(-1).max().item())

        patch_embeddings = self.out_projection(
            self._pool(
                h=h,
                boundary_mask=boundary_mask,
                n_patches=n_patches,
                boundary_state=boundary_state,
            )
        )

        if self.has_cache:
            self.cache_seqlens += input_ids.shape[1]

        return h, patch_embeddings, boundary_logprobs, boundary_mask
674
+
675
+
676
class BolmoLocalDecoder(nn.Module):
    """Byte-level decoder that combines patch outputs with per-byte states.

    Each byte position receives the (normalised) output of the patch it
    belongs to, added to the projected byte state, and the sum is passed
    through a stack of ``BolmoLocalLayer`` blocks.
    """

    def __init__(self, config: BolmoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.initial_norm = BolmoRMSNorm(
            self.hidden_size,
            eps=config.local_rms_norm_eps,
        )

        self.in_projection = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
        )

        self.layers = nn.ModuleList(
            [BolmoLocalLayer(config) for _ in range(config.num_local_decoder_layers)]
        )

        self.has_cache = False

    def prepare_inference_cache(self, batch_size: int):
        """Allocate the step-decoding cache for ``batch_size`` rows."""
        device = next(self.parameters()).device
        self.has_cache = True

        self.cache_seqlens = 0
        self.last_value = torch.zeros(
            (batch_size, self.hidden_size), dtype=self.in_projection.weight.dtype, device=device
        )
        self.layer_states = [{"xlstm": {}} for _ in self.layers]

    def free_inference_cache(self):
        """Drop every cache attribute that ``prepare_inference_cache`` created."""
        self.has_cache = False
        for attr in ("cache_seqlens", "last_value", "layer_states"):
            if hasattr(self, attr):
                delattr(self, attr)

    def _depool(
        self,
        embeds: torch.Tensor,
        patch_embeds: torch.Tensor,
        boundary_mask: Optional[torch.Tensor],
        boundary_state: Optional[MaskState] = None,
        sequence_start_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Add patch representations to byte states and run the decoder layers.

        On cached steps the newest patch value (or the remembered
        ``last_value`` when ``patch_embeds`` is empty) is added to ``embeds``.
        During prefill every byte gathers the patch selected by the running
        count of boundaries up to its position.
        """
        if self.has_cache and self.cache_seqlens > 0:
            assert boundary_state is not None

            got_new_patch = patch_embeds.numel() > 0
            if got_new_patch:
                # we got a new value from the global model, so must be at boundary position
                newest_patch = patch_embeds[:, -1:, :]
                h = embeds + newest_patch

                self.last_value.copy_(newest_patch[:, -1])

                layer_mask = None
            else:
                h = embeds + self.last_value.unsqueeze(1)

                # skip pad positions until we get a new value from the global model
                h = boundary_state.selective_get(h, inv=True)
                layer_mask = boundary_state

            if h.shape[0] > 0:
                for layer_idx, layer in enumerate(self.layers):
                    h = layer(h, past_key_values=self.layer_states[layer_idx], use_cache=True, cache_mask=layer_mask)

            self.cache_seqlens += h.shape[1]

            return h

        assert boundary_mask is not None

        patch_values = patch_embeds

        # TODO(benjaminm): clipping is problematic if it happens too much; track clip %.
        patch_of_byte = (torch.cumsum(boundary_mask, dim=1) - 1).clip(min=0, max=patch_values.shape[1] - 1)
        per_byte_patch = torch.gather(
            patch_values,
            dim=1,
            index=patch_of_byte.unsqueeze(-1).expand(-1, -1, self.hidden_size),
        )

        h = per_byte_patch + embeds

        for layer_idx, layer in enumerate(self.layers):
            caching = True if self.has_cache else False
            layer_cache = self.layer_states[layer_idx] if caching else None
            h = layer(
                h,
                past_key_values=layer_cache,
                use_cache=caching,
                sequence_start_indices=sequence_start_indices,
            )

        if self.has_cache:
            self.last_value.copy_(patch_values[:, -1])
            self.cache_seqlens += h.shape[1]

        return h

    def forward(
        self,
        embeds: torch.Tensor,
        patch_embeds: torch.Tensor,
        boundary_state: Optional[MaskState],
        boundary_mask: torch.Tensor | None,
        sequence_start_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project ``embeds``, normalise ``patch_embeds`` and combine them via ``_depool``."""
        projected_bytes = self.in_projection(embeds)
        normed_patches = self.initial_norm(patch_embeds)

        return self._depool(
            embeds=projected_bytes,
            patch_embeds=normed_patches,
            boundary_mask=boundary_mask,
            boundary_state=boundary_state,
            sequence_start_indices=sequence_start_indices,
        )
800
+
801
+
802
class BolmoRotaryEmbedding(nn.Module):
    """Computes rotary position embedding tables (cos, sin) for given position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: BolmoConfig, device=None, rope_type: Optional[str] = None):
        super().__init__()

        # An explicit ``rope_type`` wins; otherwise fall back to the config's
        # ``rope_scaling`` dict, and finally to "default".
        if rope_type is None:
            if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
                # BC: "rope_type" was originally "type"
                rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                rope_type = "default"
        self.rope_type = rope_type
        assert self.rope_type is not None

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``; ``x`` only supplies the target device."""
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            # The angle table is duplicated along the last dimension.
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling
        return cos, sin
839
+
840
+
841
class BolmoPreTrainedModel(PreTrainedModel):
    """Base class holding the ``PreTrainedModel`` capability flags shared by Bolmo models."""

    config: BolmoConfig
    base_model_prefix = "model"

    # Module / device placement hints.
    _no_split_modules = ["BolmoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Declared feature support.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Module classes whose outputs are recorded under each output name.
    _can_record_outputs = {
        "hidden_states": BolmoDecoderLayer,
        "attentions": BolmoAttention,
    }
857
+
858
+
859
class BolmoModel(BolmoPreTrainedModel):
    """Bolmo backbone: local byte encoder, global transformer stack, local byte decoder.

    Bytes are encoded and pooled into patches by ``local_encoder``, the patch
    sequence is processed by ``layers``, and ``local_decoder`` maps the result
    back onto byte positions before a final RMS norm.
    """

    def __init__(self, config: BolmoConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.local_encoder = BolmoLocalEncoder(config)
        self.local_decoder = BolmoLocalDecoder(config)

        self.layers = nn.ModuleList(
            [BolmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.rotary_embs = nn.ModuleDict(
            {
                "sliding_attention": BolmoRotaryEmbedding(config=config, rope_type="default"),
                "full_attention": BolmoRotaryEmbedding(config=config),
            }
        )

        self.tokenizer_config = BolmoTokenizerConfig(**config.tokenizer_config)
        self._tokenizer = None  # built lazily by the `tokenizer` property

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the byte embedding table of the local encoder."""
        return self.local_encoder.byte_embedding

    def set_input_embeddings(self, value: nn.Embedding):  # type: ignore
        """Replace the byte embedding table of the local encoder."""
        self.local_encoder.byte_embedding = value

    @property
    def tokenizer(self):
        """Tokenizer built from ``tokenizer_config`` on first access and then reused."""
        if self._tokenizer is None:
            self._tokenizer = self.tokenizer_config.build()

        return self._tokenizer

    @staticmethod
    def _shift_patch_rows_(patches: torch.Tensor, n_valid: torch.Tensor, to_left_pad: bool) -> None:
        """Move each row's valid patches between the left and right end, in place.

        Args:
            patches: tensor indexed as ``[row, patch, ...]``; modified in place.
            n_valid: per-row count of valid patches.
            to_left_pad: if true, copy the first ``n`` patches of a row onto
                its last ``n`` slots; otherwise copy the last ``n`` back onto
                the first ``n`` slots.
        """
        for row, count in enumerate(n_valid.tolist()):
            if count == 0:
                # ``-0:`` would select the whole row rather than nothing and
                # the assignment would fail on a shape mismatch.
                continue
            if to_left_pad:
                patches[row, -count:] = patches[row, :count].clone()
            else:
                patches[row, :count] = patches[row, -count:].clone()

    def prefill_boundary_prediction_forward(
        self,
        input_ids: torch.Tensor,
        expanded_input_ids: Optional[torch.Tensor] = None,
        sequence_start_indices: Optional[torch.Tensor] = None,
        last_token_is_boundary: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Run only the local encoder and return its boundary mask.

        The mask's last column is set from ``last_token_is_boundary`` (via the
        ``boundary_state`` passed to the encoder). Extra ``kwargs`` are
        accepted and ignored.
        """
        batch = input_ids.shape[0]
        boundary_state = MaskState(
            torch.full((batch,), fill_value=last_token_is_boundary, device=input_ids.device, dtype=torch.bool)
        )
        pad_state = MaskState(torch.zeros((batch,), device=input_ids.device, dtype=torch.bool))

        _, _, _, boundary_mask = self.local_encoder.forward(  # type: ignore
            input_ids,
            expanded_input_ids=expanded_input_ids,
            boundary_state=boundary_state,
            pad_state=pad_state,
            sequence_start_indices=sequence_start_indices,
        )

        return cast(torch.Tensor, boundary_mask)

    @check_model_inputs()
    def forward(
        self,
        input_ids: torch.Tensor,
        expanded_input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        boundary_mask: Optional[torch.Tensor] = None,
        boundary_state: Optional[MaskState] = None,
        pad_state: Optional[MaskState] = None,
        sequence_start_indices: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Encode bytes, run the global stack over patches, and decode back to bytes.

        Args:
            input_ids: byte ids, one row per batch element.
            expanded_input_ids: optional precomputed subword ids; computed
                from ``input_ids`` with the tokenizer when the encoder needs
                them and they are not given.
            boundary_mask: optional boundary mask that overrides the
                encoder's own prediction.
            boundary_state, pad_state: per-row mask states used by the local
                encoder/decoder during cached decoding.
            sequence_start_indices: optional per-row start index, forwarded
                to the local decoder.

        Returns:
            ``BaseModelOutputWithPast`` with the normalised byte-level hidden
            states and the (possibly newly created) cache.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        if self.local_encoder.add_expanded_embeddings and expanded_input_ids is None and input_ids is not None:
            # not optimized
            expanded_input_ids_list: list[torch.Tensor] = [
                torch.tensor(
                    self.tokenizer.expand_byte_ids(input_ids[example_idx].tolist()),
                    dtype=torch.long,
                    device=device,
                )
                for example_idx in range(batch_size)
            ]
            expanded_input_ids = pad_right(expanded_input_ids_list, value=self.tokenizer.pad_token_id, multiple_of=1)  # type: ignore

        # NOTE(review): sequence_start_indices is not forwarded to the local
        # encoder here, only to the local decoder below; confirm this is intended.
        h_byte, h_patch, _, boundary_mask = self.local_encoder(
            input_ids=input_ids,
            expanded_input_ids=expanded_input_ids,
            true_boundary_mask=boundary_mask,
            boundary_state=boundary_state,
            pad_state=pad_state,
        )

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + h_patch.shape[1], device=device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)  # type: ignore

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": h_patch,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        position_embeddings_mapping = {
            "sliding_attention": self.rotary_embs["sliding_attention"](h_byte, position_ids),
            "full_attention": self.rotary_embs["full_attention"](h_byte, position_ids),
        }

        if h_patch.numel() > 0:
            # we need to convert from right-pad to left-pad and back for prefill
            # since flash attention expects left-pad and local/enc dec expect right-pad global tokens
            # should add better left-pad support but this only affects prefill so OK for now
            # although super inefficient!
            n_boundaries = boundary_mask.sum(-1) if boundary_mask is not None else None

            if n_boundaries is not None:  # prefill
                self._shift_patch_rows_(h_patch, n_boundaries, to_left_pad=True)

            h_patch_after_global = h_patch

            for decoder_layer in self.layers[: self.config.num_hidden_layers]:
                attention_type = decoder_layer.self_attn.attention_type
                h_patch_after_global = decoder_layer(
                    h_patch_after_global,
                    attention_mask=causal_mask_mapping[attention_type],
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings_mapping[attention_type],
                    **kwargs,
                )

            if n_boundaries is not None:  # prefill
                self._shift_patch_rows_(h_patch_after_global, n_boundaries, to_left_pad=False)
        else:
            # No new patch on this step: the global stack is skipped.
            h_patch_after_global = h_patch

        h_out = self.local_decoder.forward(  # type: ignore
            embeds=h_byte,
            patch_embeds=h_patch_after_global,
            boundary_mask=boundary_mask,
            boundary_state=boundary_state,
            sequence_start_indices=sequence_start_indices,
        )
        h_out = self.norm(h_out)

        return BaseModelOutputWithPast(
            last_hidden_state=h_out,
            past_key_values=past_key_values,
        )
1030
+
1031
+
1032
+ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1033
+ _tied_weights_keys = ["lm_head.weight"]
1034
+ _tp_plan = {"lm_head": "colwise_rep"}
1035
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1036
+
1037
+ def __init__(self, config):
1038
+ super().__init__(config)
1039
+ self.model = BolmoModel(config)
1040
+ self.vocab_size = config.vocab_size
1041
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1042
+
1043
+ # Initialize weights and apply final processing
1044
+ self.post_init()
1045
+
1046
+ def get_output_embeddings(self):
1047
+ return self.lm_head
1048
+
1049
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1050
+ self.lm_head = new_embeddings
1051
+
1052
+ @can_return_tuple
1053
+ def forward(
1054
+ self,
1055
+ input_ids: torch.Tensor,
1056
+ expanded_input_ids: Optional[torch.Tensor] = None,
1057
+ attention_mask: Optional[torch.Tensor] = None,
1058
+ position_ids: Optional[torch.Tensor] = None,
1059
+ past_key_values: Optional[Cache] = None,
1060
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1061
+ cache_position: Optional[torch.Tensor] = None,
1062
+ use_cache: Optional[bool] = None,
1063
+ boundary_mask: Optional[torch.Tensor] = None,
1064
+ boundary_state: Optional[MaskState] = None,
1065
+ pad_state: Optional[MaskState] = None,
1066
+ sequence_start_indices: Optional[torch.Tensor] = None,
1067
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1068
+ **kwargs: Unpack[TransformersKwargs],
1069
+ ) -> CausalLMOutputWithPast:
1070
+ r"""
1071
+ Example:
1072
+
1073
+ ```python
1074
+ >>> from transformers import AutoTokenizer, BolmoForCausalLM
1075
+
1076
+ >>> model = BolmoForCausalLM.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
1077
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
1078
+
1079
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1080
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1081
+
1082
+ >>> # Generate
1083
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1084
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1085
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1086
+ ```"""
1087
+ outputs: BaseModelOutputWithPast = self.model(
1088
+ input_ids=input_ids,
1089
+ expanded_input_ids=expanded_input_ids,
1090
+ attention_mask=attention_mask,
1091
+ position_ids=position_ids,
1092
+ past_key_values=past_key_values,
1093
+ inputs_embeds=inputs_embeds,
1094
+ cache_position=cache_position,
1095
+ use_cache=use_cache,
1096
+ boundary_mask=boundary_mask,
1097
+ boundary_state=boundary_state,
1098
+ pad_state=pad_state,
1099
+ sequence_start_indices=sequence_start_indices,
1100
+ **kwargs,
1101
+ )
1102
+
1103
+ hidden_states = cast(torch.Tensor, outputs.last_hidden_state)
1104
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1105
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1106
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1107
+
1108
+ return CausalLMOutputWithPast(
1109
+ logits=logits,
1110
+ past_key_values=outputs.past_key_values,
1111
+ hidden_states=outputs.hidden_states,
1112
+ attentions=outputs.attentions,
1113
+ )
1114
+
1115
+ @torch.no_grad()
1116
+ def generate( # type: ignore
1117
+ self,
1118
+ inputs: torch.Tensor,
1119
+ generation_config: Optional[GenerationConfig] = None,
1120
+ logits_processor: Optional[LogitsProcessorList] = None,
1121
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1122
+ use_model_defaults: Optional[bool] = None,
1123
+ **kwargs,
1124
+ ) -> Union[GenerateOutput, torch.Tensor]:
1125
+ # generic preprocessing
1126
+
1127
+ generation_config, model_kwargs = self._prepare_generation_config(
1128
+ generation_config, use_model_defaults, **kwargs
1129
+ )
1130
+ self._prepare_special_tokens(generation_config, device=self.model.device)
1131
+
1132
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1133
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1134
+
1135
+ # start of custom generate
1136
+
1137
+ expand_input_ids = self.model.local_encoder.add_expanded_embeddings
1138
+ batch_size = len(inputs)
1139
+
1140
+ if expand_input_ids:
1141
+ expanded_input_ids = []
1142
+
1143
+ for i in range(len(inputs)):
1144
+ expanded_input_ids.append(torch.tensor(self.model.tokenizer.expand_byte_ids(inputs[i].tolist()), device=self.device, dtype=torch.long))
1145
+
1146
+ expanded_input_ids = pad_left(expanded_input_ids, value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
1147
+ else:
1148
+ expanded_input_ids = None
1149
+
1150
+ byte_input_ids = inputs
1151
+ sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1152
+ batch_size, prompt_len = byte_input_ids.shape
1153
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
1154
+
1155
+ boundary_offset = self.model.tokenizer.offset + 256
1156
+ eos = self.model.tokenizer.eos_token_id
1157
+
1158
+ self.model.local_encoder.free_inference_cache()
1159
+ self.model.local_decoder.free_inference_cache()
1160
+
1161
+ boundary_mask = self.model.prefill_boundary_prediction_forward( # type: ignore
1162
+ byte_input_ids,
1163
+ expanded_input_ids=expanded_input_ids,
1164
+ sequence_start_indices=sequence_start_indices,
1165
+ )
1166
+
1167
+ self.model.local_encoder.prepare_inference_cache(batch_size)
1168
+ self.model.local_decoder.prepare_inference_cache(batch_size)
1169
+
1170
+ # roll back by one and force decoding to account for lookahead
1171
+ boundary_mask = boundary_mask[:, :-1]
1172
+ # need to roll one byte back and force decoding to detect whether the last byte is a boundary
1173
+ forced_decoding_ids = byte_input_ids[:, -1].cpu().tolist()
1174
+ byte_input_ids = byte_input_ids[:, :-1]
1175
+ expanded_input_ids = expanded_input_ids[:, :-1] if expanded_input_ids is not None else None
1176
+ # stays the same unless last token is pad.
1177
+ sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1178
+
1179
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1180
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
1181
+ generation_config = self._prepare_generated_length(
1182
+ generation_config=generation_config,
1183
+ has_default_max_length=has_default_max_length,
1184
+ has_default_min_length=has_default_min_length,
1185
+ model_input_name="input_ids",
1186
+ inputs_tensor=byte_input_ids,
1187
+ input_ids_length=byte_input_ids.shape[1],
1188
+ )
1189
+
1190
+ logits_processor = self._get_logits_processor(
1191
+ generation_config=generation_config, # type: ignore
1192
+ input_ids_seq_length=byte_input_ids.shape[1],
1193
+ encoder_input_ids=byte_input_ids, # type: ignore
1194
+ logits_processor=logits_processor,
1195
+ device=byte_input_ids.device, # type: ignore
1196
+ model_kwargs=model_kwargs,
1197
+ )
1198
+ stopping_criteria = self._get_stopping_criteria(
1199
+ generation_config=generation_config, # type: ignore
1200
+ stopping_criteria=stopping_criteria,
1201
+ tokenizer=self.model.tokenizer,
1202
+ )
1203
+
1204
+ # output container
1205
+ generated = byte_input_ids
1206
+
1207
+ max_n_prefill_patches = boundary_mask.sum(-1).max().item()
1208
+ tokens_generated_plus_prefilled = max_n_prefill_patches
1209
+ bytes_generated = 0
1210
+
1211
+ # generation state
1212
+ boundary_state = MaskState(boundary_mask[:, -1].clone())
1213
+ pad_state = MaskState(torch.zeros(batch_size, dtype=torch.bool, device=self.device))
1214
+ next_tokens = torch.full((batch_size,), self.model.tokenizer.bpe_token_end_id, device=self.device, dtype=torch.long) # type: ignore
1215
+ non_boundary_generated_tokens = [[byte_input_ids[example_idx, -1].item()] for example_idx in range(batch_size)]
1216
+ bytes_since_boundary = (boundary_mask.flip(1).cumsum(-1) == 0).sum(-1)
1217
+ is_first_forward = True
1218
+ global_past_key_values = None
1219
+
1220
+ while not finished.all():
1221
+ input_ids_for_model = (
1222
+ generated
1223
+ if is_first_forward
1224
+ else torch.tensor([x[-1] for x in non_boundary_generated_tokens], device=generated.device, dtype=generated.dtype).unsqueeze(1)
1225
+ )
1226
+ assert not (
1227
+ (input_ids_for_model == self.model.tokenizer.bpe_token_end_id) |
1228
+ (input_ids_for_model >= boundary_offset)
1229
+ ).any().item() # type: ignore
1230
+ if expand_input_ids:
1231
+ expanded_input_ids_for_model = torch.zeros_like(input_ids_for_model)
1232
+ for i in range(input_ids_for_model.shape[0]):
1233
+ expanded_input_ids_for_model[i, :] = torch.tensor(self.model.tokenizer.expand_byte_ids(
1234
+ generated[i, :].tolist(),
1235
+ n_last=input_ids_for_model.shape[1],
1236
+ ), device=expanded_input_ids_for_model.device, dtype=expanded_input_ids_for_model.dtype)
1237
+ else:
1238
+ expanded_input_ids_for_model = None
1239
+
1240
+ out = self.forward( # type: ignore
1241
+ input_ids_for_model,
1242
+ expanded_input_ids=expanded_input_ids_for_model,
1243
+ boundary_mask=boundary_mask if is_first_forward else None,
1244
+ boundary_state=boundary_state,
1245
+ pad_state=pad_state,
1246
+ sequence_start_indices=sequence_start_indices,
1247
+ logits_to_keep=1,
1248
+ use_cache=True,
1249
+ past_key_values=global_past_key_values,
1250
+ )
1251
+ next_token_logits = cast(torch.Tensor, out.logits)
1252
+ global_past_key_values = out.past_key_values
1253
+
1254
+ if boundary_state.all():
1255
+ # new token, must not be boundary
1256
+ bytes_since_boundary[:] = 0
1257
+ else:
1258
+ boundary_state.selective_add(1, bytes_since_boundary, inv=True)
1259
+
1260
+ if any(x is not None for x in forced_decoding_ids):
1261
+ # only supported for the first token atm, so len(next_token_logits) == batch_size
1262
+ assert len(next_token_logits) == batch_size and is_first_forward
1263
+ for example_idx in range(batch_size):
1264
+ forced_decoding_id = forced_decoding_ids[example_idx]
1265
+
1266
+ if forced_decoding_id is not None:
1267
+ no_boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id].item()
1268
+ boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset].item()
1269
+
1270
+ next_token_logits[example_idx, 0, :] = -100_000
1271
+ next_token_logits[example_idx, 0, forced_decoding_id] = no_boundary_logit
1272
+ next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset] = boundary_logit
1273
+
1274
+ forced_decoding_ids[example_idx] = None # only force once
1275
+
1276
+ # passing input_ids to logit processor not implemented
1277
+ next_token_scores = logits_processor(None, next_token_logits[:, -1]) # type: ignore
1278
+
1279
+ if generation_config is not None and generation_config.do_sample:
1280
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1281
+ new_next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1282
+ else:
1283
+ new_next_tokens = torch.argmax(next_token_scores, dim=-1)
1284
+
1285
+ if boundary_state.all() or is_first_forward:
1286
+ tokens_generated_plus_prefilled += 1
1287
+
1288
+ next_tokens = new_next_tokens
1289
+ next_tokens_cpu = next_tokens.cpu()
1290
+ for example_idx in range(batch_size):
1291
+ if finished[example_idx].item():
1292
+ continue
1293
+
1294
+ next_token_cpu = next_tokens_cpu[example_idx].item()
1295
+
1296
+ if next_token_cpu >= boundary_offset:
1297
+ next_token_cpu -= boundary_offset
1298
+
1299
+ non_boundary_generated_tokens[example_idx].append(next_token_cpu)
1300
+ else:
1301
+ next_tokens[:] = self.model.tokenizer.bpe_token_end_id # type: ignore
1302
+ boundary_state.selective_put(new_next_tokens, next_tokens, inv=True)
1303
+ next_tokens_cpu = next_tokens.cpu()
1304
+
1305
+ for example_idx in range(batch_size):
1306
+ if finished[example_idx].item():
1307
+ continue
1308
+
1309
+ next_token_cpu = next_tokens_cpu[example_idx].item()
1310
+
1311
+ if not boundary_state.cpu_mask[example_idx].item():
1312
+ if next_token_cpu >= boundary_offset:
1313
+ next_token_cpu -= boundary_offset
1314
+
1315
+ non_boundary_generated_tokens[example_idx].append(next_token_cpu)
1316
+
1317
+ is_first_forward = False
1318
+
1319
+ boundary_state = MaskState(
1320
+ (next_tokens == self.model.tokenizer.bpe_token_end_id) |
1321
+ (next_tokens >= boundary_offset) |
1322
+ finished
1323
+ ) # type: ignore
1324
+ pad_state = MaskState(
1325
+ (next_tokens == self.model.tokenizer.bpe_token_end_id) |
1326
+ finished
1327
+ )
1328
+
1329
+ # Force EOS for (previously) finished sequences
1330
+ next_tokens = torch.where(finished, torch.full_like(next_tokens, eos), next_tokens)
1331
+
1332
+ # Append next tokens
1333
+ generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)
1334
+
1335
+ # Handle finished sequences
1336
+ stop_hit = next_tokens.eq(eos) | next_tokens.eq(eos + boundary_offset)
1337
+
1338
+ for i in range(batch_size):
1339
+ # passing `scores` to stopping criteria not implemented
1340
+ if stopping_criteria(torch.tensor(non_boundary_generated_tokens[i], dtype=torch.long).unsqueeze(0), None).squeeze(0).item(): # type: ignore
1341
+ stop_hit[i] = True
1342
+
1343
+ finished |= stop_hit
1344
+ bytes_generated += 1
1345
+
1346
+ return pad_left([
1347
+ torch.cat([byte_input_ids[i, :-1], torch.tensor(x, dtype=torch.long, device=byte_input_ids.device)])
1348
+ for i, x in enumerate(non_boundary_generated_tokens)
1349
+ ], value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
1350
+
1351
# Names exported by a star-import of this module.
__all__ = ["BolmoForCausalLM", "BolmoModel", "BolmoPreTrainedModel"]
recipe.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ default_stage:
2
+ default_modifiers:
3
+ QuantizationModifier:
4
+ targets: [Linear]
5
+ ignore: [lm_head, 're:visual.*', 're:.*vision_tower.*', 're:.*video_tower.*', 're:.*audio_tower.*',
6
+ 're:.*multi_modal_projector.*']
7
+ scheme: NVFP4
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<bos>",
3
+ "eos_token": "<bos>",
4
+ "pad_token": "<pad>"
5
+ }
tokenization_bolmo.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from functools import lru_cache
3
+ from typing import Optional, Union
4
+ from transformers import AutoTokenizer
5
+ from transformers.tokenization_utils import PreTrainedTokenizer
6
+
7
+ # Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
8
+ # Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
9
def _build_chars_to_bytes() -> dict[str, int]:
    """Build the byte-level alphabet: exactly one character per byte value.

    Follows the GPT-2 `bytes_to_unicode` rule referenced above: bytes that are
    already printable, non-space characters map to themselves, and every other
    byte (0-32, 127-160 and 173) is assigned the next unused code point starting
    at U+0100, in increasing byte order.

    Building the table from the rule instead of a hand-typed literal guarantees
    that every key is a single character. A literal is fragile here: the keys
    for bytes 144 and 145 are the ligatures U+0132 and U+0133, which text
    tooling readily rewrites into the two-letter strings "IJ" and "ij", and
    those can never match a single character of a vocabulary entry.

    Returns:
        Mapping from alphabet character to byte value, inserted in byte order.
    """
    self_mapped = {
        *range(0x21, 0x7F),   # "!" .. "~"
        *range(0xA1, 0xAD),   # U+00A1 .. U+00AC
        *range(0xAE, 0x100),  # U+00AE .. U+00FF
    }

    table: dict[str, int] = {}
    shifted = 0  # how many non-printable bytes have been relocated so far
    for byte in range(256):
        if byte in self_mapped:
            table[chr(byte)] = byte
        else:
            table[chr(0x100 + shifted)] = byte
            shifted += 1
    return table


_CHARS_TO_BYTES = _build_chars_to_bytes()
_BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}
47
+
48
def _bytes_to_chars(byte_sequence: bytes) -> str:
    """Spell a byte string in the byte-level alphabet, one character per byte."""
    return "".join(map(_BYTES_TO_CHARS.__getitem__, byte_sequence))
50
+
51
def _chars_to_bytes(char_sequence: str) -> list:
    """Translate byte-level alphabet characters back into a list of byte values.

    Raises `KeyError` for a character that is not part of the alphabet.
    """
    # Going through `bytes` keeps the original's 0..255 range validation.
    encoded = bytes(map(_CHARS_TO_BYTES.__getitem__, char_sequence))
    return list(encoded)
53
+
54
@dataclass
class BolmoTokenizerConfig:
    """Static description of the byte-level vocabulary used by `BolmoTokenizer`."""

    vocab_size: int
    bos_token_id: int
    pad_token_id: int
    eos_token_id: int
    bpe_token_end_id: int
    # Special-token strings; a token's id is derived from its position in this list.
    special_tokens: list[str] = field(default_factory=list)
    # When true the special tokens occupy the lowest ids, ahead of the 256 bytes.
    special_tokens_first: bool = True
    # Identifier handed to `AutoTokenizer.from_pretrained` by `BolmoTokenizer`.
    original_identifier: Optional[str] = None

    @classmethod
    def bolmo(cls) -> "BolmoTokenizerConfig":
        """Return the stock Bolmo configuration.

        Four special tokens plus 256 byte values, doubled so that every id also
        has a fused-boundary twin. "<bos>" serves as the EOS token as well.
        """
        specials = ["<pad>", "<bos>", "<eos>", "<bpe_token_end>"]
        position = {name: idx for idx, name in enumerate(specials)}
        bos = position["<bos>"]

        return cls(
            # *2 to accommodate fused boundary tokens
            vocab_size=2 * (len(specials) + 256),
            bos_token_id=bos,
            pad_token_id=position["<pad>"],
            eos_token_id=bos,
            bpe_token_end_id=position["<bpe_token_end>"],
            special_tokens=specials,
            original_identifier="allenai/dolma2-tokenizer",
        )

    def build(self):
        """Instantiate a `BolmoTokenizer` driven by this configuration."""
        return BolmoTokenizer(tokenizer_config=self)
88
+
89
+
90
class BolmoTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer layered on top of an existing subword tokenizer.

    Text is first encoded with the wrapped Hugging Face tokenizer and every
    subword id is then spelled out as byte ids. The id space is laid out as
    (for `special_tokens_first=True`, with `offset = len(special_tokens)`):

    * `[0, offset)`                         special tokens
    * `[offset, offset + 256)`              raw bytes
    * `[offset + 256, 2*offset + 256)`      special tokens with a fused boundary
    * `[2*offset + 256, 2*offset + 512)`    raw bytes with a fused boundary

    i.e. adding `offset + 256` to any plain id yields its fused-boundary twin.
    As token strings, the fused variants carry a trailing "b".
    """

    # Key under which a trie node stores the subword id that ends at that node.
    TOKEN_ID_KEY = -1

    def __init__(self, **kwargs):
        """Load the wrapped tokenizer and precompute byte spellings and the lookup trie.

        Accepts an optional `tokenizer_config` keyword (defaults to
        `BolmoTokenizerConfig.bolmo()`).
        """
        tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())

        self.config = tokenizer_config
        self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
        if self.config.special_tokens_first:
            self.offset = len(tokenizer_config.special_tokens)
            self.special_tokens_offset = 0
        else:
            self.offset = 0
            self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)

        # subword id -> list of byte ids spelling that subword
        self.byte_sequences = {}

        for key, value in self.hf_tokenizer.get_vocab().items():
            if key in self.config.special_tokens:
                byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
            elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
                byte_sequence = [self.eos_token_id]
            elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
                byte_sequence = [self.bos_token_id]
            elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
                byte_sequence = [self.pad_token_id]
            else:
                byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]

            assert self.byte_sequences.get(value) is None
            self.byte_sequences[value] = byte_sequence

        # Trie over reversed byte spellings, used by `expand_byte_ids`.
        self.byte_trie = {}

        for token_id, byte_sequence in self.byte_sequences.items():
            current_dict = self.byte_trie
            for byte in byte_sequence[::-1]:  # retrieved from the back so store in reverse order
                if byte not in current_dict:
                    current_dict[byte] = {}
                current_dict = current_dict[byte]
            current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id

        self.add_bos_token = True
        self.add_eos_token = False
        self.padding_side = "left"  # for generate

        super().__init__(
            bos_token=self.config.special_tokens[self.config.bos_token_id],
            eos_token=self.config.special_tokens[self.config.eos_token_id],
            pad_token=self.config.special_tokens[self.config.pad_token_id],
            extra_ids=0,
        )

    @property
    def bos_token_id(self):
        """Id of the beginning-of-sequence token, taken from the config."""
        return self.config.bos_token_id

    @property
    def eos_token_id(self):
        """Id of the end-of-sequence token, taken from the config."""
        return self.config.eos_token_id

    @property
    def pad_token_id(self):
        """Id of the padding token, taken from the config."""
        return self.config.pad_token_id

    @property
    def bpe_token_end_id(self):
        """Id of the `<bpe_token_end>` marker, taken from the config."""
        return self.config.bpe_token_end_id

    @property
    def vocab_size(self):
        """Total number of ids, including the fused-boundary half."""
        return self.config.vocab_size

    def _convert_id_to_token(self, index):
        """Map an id to its token string; fused-boundary ids get a trailing "b"."""
        fused_start = self.offset + 256

        if index < self.offset:
            return self.config.special_tokens[index - self.special_tokens_offset]

        if fused_start <= index < fused_start + self.offset:
            # special token with fused boundary
            return self.config.special_tokens[index - fused_start] + "b"

        if index >= fused_start:
            # byte with fused boundary
            return _BYTES_TO_CHARS[index - fused_start - self.offset] + "b"

        return _BYTES_TO_CHARS[index - self.offset]

    def _convert_token_to_id(self, token):
        """Map a token string to its id; exact inverse of `_convert_id_to_token`.

        The fused-boundary twin of an id is that id plus `offset + 256`. The
        previous implementation dropped one `offset` in both fused branches, so
        fused tokens did not round-trip and decoded to the wrong byte.
        """
        fused_shift = self.offset + 256

        if token in self.config.special_tokens:
            return self.config.special_tokens.index(token)

        if token in [x + "b" for x in self.config.special_tokens]:
            # special token with fused boundary
            return fused_shift + self.config.special_tokens.index(token[:-1])

        if len(token) > 1 and token[-1] == "b":
            # byte with fused boundary: plain byte id (offset + byte) plus the shift
            return fused_shift + self.offset + _CHARS_TO_BYTES[token[0]]
        else:
            return self.offset + _CHARS_TO_BYTES[token]

    def get_vocab(self):
        """Return the full token-string -> id mapping."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        return vocab

    def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
        """Find, for each byte position, the longest subword that ends there.

        Walks the reversed-spelling trie backwards from every position. Marker
        ids equal to `bpe_token_end_id` are skipped and fused-boundary ids are
        reduced to their plain form before the lookup.

        Args:
            byte_ids: Byte ids to scan.
            n_last: When given, only the final `n_last` positions are expanded
                (earlier bytes still serve as look-back context).

        Returns:
            One subword id per expanded position. An entry can be `None` if the
            scan reaches the start of the sequence without any match.
        """
        # search in the byte tree for the longest matching token at every byte position
        first = 0 if n_last is None else max(0, len(byte_ids) - n_last)

        expanded_ids = []
        for end in range(first, len(byte_ids)):
            current_dict = self.byte_trie
            current_expansion = None

            for pos in range(end, -1, -1):
                byte = byte_ids[pos]

                if byte == self.bpe_token_end_id:
                    # skip bpe token end markers, needed for generation
                    continue

                if byte >= self.offset + 256:
                    # ignore fused boundary
                    byte -= self.offset + 256

                try:
                    current_dict = current_dict[byte]
                    if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
                        current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
                except KeyError:
                    # No longer spelling exists; the best match so far wins.
                    assert current_expansion is not None
                    break

            expanded_ids.append(current_expansion)

        return expanded_ids

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one or two id sequences with BOS/EOS according to the `add_*_token` flags."""
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```
        if token_ids_1 is None, only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str, **kwargs) -> list[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
        return tokens

    def _patch_ids_to_byte_ids(self, input_ids: list[int]):
        """Flatten subword ids into the concatenation of their byte spellings."""
        return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]

    def _bolmo_encode(self, string: str, add_special_tokens=False):
        """Encode text with the wrapped tokenizer, then spell the result as byte ids."""
        input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
        return self._patch_ids_to_byte_ids(input_ids)

    def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
        """Decode byte ids to text; invalid UTF-8 is replaced rather than raised."""
        return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")

    def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
        """Turn byte ids into raw bytes.

        Fused-boundary ids are first reduced to their plain twin. Special tokens
        are either dropped or written out as their literal UTF-8 text.
        """
        tokens_without_boundary = []
        for token in tokens:
            if token >= (self.offset + 256):
                token -= self.offset + 256

            tokens_without_boundary.append(token)

        utf8_bytes = []

        for token in tokens_without_boundary:
            if token < self.offset:
                if skip_special_tokens:
                    continue
                else:
                    utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
            else:
                # Clamp so an out-of-range id cannot make `bytes()` raise.
                utf8_bytes.append(min(token - self.offset, 255))

        return bytes(utf8_bytes)

    def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
        """Spell subword ids as bytes and report how many bytes each subword produced.

        Args:
            original_input_ids: Ids from the wrapped subword tokenizer.
            add_bos: Prepend the BOS id as its own one-byte patch.
            strip_pad: Drop subwords whose spelling consists only of the pad id.
            skip_last: Ignore the final input id.

        Returns:
            `(byte_tokens, patch_lengths)`; the lengths sum to `len(byte_tokens)`.
        """
        if add_bos and self.bos_token_id is not None:
            byte_tokens = [self.bos_token_id]
            patch_lengths = [1]
        else:
            byte_tokens = []
            patch_lengths = []

        for idx, token in enumerate(original_input_ids):
            # optionally skip last token to keep the length the same if add_bos=True
            if skip_last and idx == len(original_input_ids) - 1:
                break

            token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])

            if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
                # skip padding tokens
                continue

            patch_lengths.append(len(token_byte_tokens))
            byte_tokens.extend(token_byte_tokens)

        return byte_tokens, patch_lengths

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token strings back into text by way of their ids."""
        return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False)  # type: ignore

    def _decode(
        self,
        token_ids: Union[int, list[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        """Decode one id or a list of ids; only `skip_special_tokens` is honoured."""
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write nothing and return an empty tuple."""
        return ()  # type: ignore
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ }
19
+ },
20
+ "auto_map": {
21
+ "AutoTokenizer": [
22
+ "tokenization_bolmo.BolmoTokenizer",
23
+ null
24
+ ]
25
+ },
26
+ "bos_token": "<bos>",
27
+ "clean_up_tokenization_spaces": false,
28
+ "eos_token": "<bos>",
29
+ "extra_ids": 0,
30
+ "extra_special_tokens": {},
31
+ "model_max_length": 1000000000000000019884624838656,
32
+ "pad_token": "<pad>",
33
+ "tokenizer_class": "BolmoTokenizer"
34
+ }
utils_bolmo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
    """Turn boundary log-probabilities into a boolean boundary mask.

    `boundary_threshold` is a "<mode>:<number>" spec:

    * "sample:0"  -- True where the probability exceeds 0.5.
    * "sample:1"  -- Bernoulli draw using the probability at each position.
    * "topk:K"    -- per row, keep values at or above the `1 - K/row_length` quantile.
    * "topk_percent:P" -- per row, keep values at or above the `1 - P` quantile
      (P must lie in [0, 1]).

    Raises:
        NotImplementedError: for a "sample:" temperature other than 0 or 1.
        ValueError: for an unrecognised mode.
    """
    if boundary_threshold.startswith("sample:"):
        _, raw_temperature = boundary_threshold.split(":")
        temperature = float(raw_temperature)

        if temperature == 0:
            return boundary_logprobs > math.log(0.5)
        if temperature == 1:
            return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
        raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")

    # Both top-k modes reduce to "keep this fraction of each row".
    if boundary_threshold.startswith("topk:"):
        _, raw_count = boundary_threshold.split(":")
        keep_fraction = int(raw_count) / boundary_logprobs.shape[1]
    elif boundary_threshold.startswith("topk_percent:"):
        _, raw_fraction = boundary_threshold.split(":")
        keep_fraction = float(raw_fraction)
        assert 0 <= keep_fraction <= 1
    else:
        raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")

    cutoffs = torch.quantile(boundary_logprobs, dim=1, q=1 - keep_fraction)
    return boundary_logprobs >= cutoffs.unsqueeze(-1)
31
+
32
+
33
def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
    """Pad tensors to a common length and stack them along a new first axis.

    The target length is the largest `size(0)` among `tensors`, rounded up to a
    multiple of `multiple_of` when that is greater than 1. Padding is applied
    with `F.pad` on the "left" or "right" side using `value` as filler.

    Raises:
        ValueError: if `direction` is neither "left" nor "right".
    """
    target_len = max(t.size(0) for t in tensors)
    if multiple_of > 1:
        # Ceiling division, then back up to the next multiple.
        target_len = -(-target_len // multiple_of) * multiple_of

    if direction not in ("left", "right"):
        raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
    pad_on_left = direction == "left"

    rows = []
    for t in tensors:
        missing = target_len - t.size(0)
        widths = (missing, 0) if pad_on_left else (0, missing)
        rows.append(F.pad(t, widths, value=value))
    return torch.stack(rows, dim=0)
48
+
49
def pad_right(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    """Stack `tensors` after padding each one at its end with `value`.

    The common length is rounded up to a multiple of `multiple_of`.
    """
    return _pad(tensors, multiple_of=multiple_of, direction="right", value=value)
55
+
56
def pad_left(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    """Stack `tensors` after padding each one at its start with `value`.

    The common length is rounded up to a multiple of `multiple_of`.
    """
    return _pad(tensors, multiple_of=multiple_of, direction="left", value=value)
62
+
63
class MaskState:
    """A boolean mask bundled with cached `all`/`any` results.

    The constructor copies the mask to the CPU once and records whether it is
    entirely True and whether it has any True entry. The `selective_*` helpers
    consult those cached flags first and only index with the mask tensor when
    the mask is genuinely mixed.
    """

    def __init__(self, mask):
        self.cpu_mask = mask.cpu()

        self.mask = mask
        self.inv_mask = ~mask
        self._all = self.cpu_mask.all().item()
        self._any = self.cpu_mask.any().item()

    def any(self):
        """Cached result: does the mask contain at least one True?"""
        return self._any

    def all(self):
        """Cached result: is every entry of the mask True?"""
        return self._all

    def selective_get(self, x, inv=False):
        """Return the entries of `x` where the mask (or its inverse, if `inv`) is True."""
        # try to avoid sync through nonzero on index
        if self.all():
            return x[[]] if inv else x
        if not self.any():
            return x if inv else x[[]]
        return x[self.inv_mask if inv else self.mask]

    def selective_put(self, x, out, inv=False):
        """Write `x` into `out` where the mask (or its inverse, if `inv`) is True."""
        # try to avoid sync through nonzero on index
        if self.all():
            if not inv:
                out.copy_(x)
            return
        if not self.any():
            if inv:
                out.copy_(x)
            return
        out[self.inv_mask if inv else self.mask] = x

    def selective_add(self, x, out, inv=False):
        """Add `x` in place to `out` where the mask (or its inverse, if `inv`) is True."""
        # try to avoid sync through nonzero on index
        if self.all():
            if not inv:
                out.add_(x)
            return
        if not self.any():
            if inv:
                out.add_(x)
            return
        selector = self.inv_mask if inv else self.mask
        out[selector] += x