Add NVFP4 quantized checkpoint
Browse files- README.md +23 -0
- config.json +143 -0
- configuration_bolmo.py +235 -0
- generation_config.json +7 -0
- model-00001-of-00002.safetensors +3 -0
- model-00002-of-00002.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_bolmo.py +1351 -0
- recipe.yaml +7 -0
- special_tokens_map.json +5 -0
- tokenization_bolmo.py +378 -0
- tokenizer_config.json +34 -0
- utils_bolmo.py +127 -0
README.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
datasets:
|
| 3 |
+
- HuggingFaceH4/ultrachat_200k
|
| 4 |
+
base_model:
|
| 5 |
+
- allenai/Bolmo-7B
|
| 6 |
+
---
|
| 7 |
+
# Bolmo-7B-nvfp4
|
| 8 |
+
|
| 9 |
+
**Format:** NVFP4 — weights & activations quantized to FP4 with dual scaling.
|
| 10 |
+
**Base model:** `allenai/Bolmo-7B`
|
| 11 |
+
**How it was made:** One-shot calibration with LLM Compressor (NVFP4 recipe), long-seq calibration with HuggingFaceH4/ultrachat_200k.
|
| 12 |
+
|
| 13 |
+
> Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.
|
| 14 |
+
|
| 15 |
+
Check the original model card for information about this model.
|
| 16 |
+
|
| 17 |
+
# Running the model with VLLM in Docker
|
| 18 |
+
```sh
|
| 19 |
+
sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Bolmo-7B-nvfp4 --dtype auto --max-model-len 32768
|
| 20 |
+
```
|
| 21 |
+
This was tested on an RTX Pro 6000 Blackwell cloud instance.
|
| 22 |
+
|
| 23 |
+
If there are other models you're interested in seeing quantized to NVFP4 for use on the DGX Spark, or other modern Blackwell (or newer) cards let me know. I'm trying to make more NVFP4 models available to allow more people to try them out.
|
config.json
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_expanded_embeddings": true,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BolmoForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_bias": false,
|
| 7 |
+
"attention_dropout": 0.0,
|
| 8 |
+
"auto_map": {
|
| 9 |
+
"AutoConfig": "configuration_bolmo.BolmoConfig",
|
| 10 |
+
"AutoModelForCausalLM": "modeling_bolmo.BolmoForCausalLM"
|
| 11 |
+
},
|
| 12 |
+
"bos_token_id": 1,
|
| 13 |
+
"boundary_predictor_lookahead": 1,
|
| 14 |
+
"boundary_threshold": "sample:0",
|
| 15 |
+
"dtype": "float32",
|
| 16 |
+
"eos_token_id": 1,
|
| 17 |
+
"hidden_act": "silu",
|
| 18 |
+
"hidden_size": 4096,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 11008,
|
| 21 |
+
"layer_types": [
|
| 22 |
+
"sliding_attention",
|
| 23 |
+
"sliding_attention",
|
| 24 |
+
"sliding_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"sliding_attention",
|
| 27 |
+
"sliding_attention",
|
| 28 |
+
"sliding_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"sliding_attention",
|
| 31 |
+
"sliding_attention",
|
| 32 |
+
"sliding_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"sliding_attention",
|
| 35 |
+
"sliding_attention",
|
| 36 |
+
"sliding_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"sliding_attention",
|
| 39 |
+
"sliding_attention",
|
| 40 |
+
"sliding_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"sliding_attention",
|
| 43 |
+
"sliding_attention",
|
| 44 |
+
"sliding_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"sliding_attention",
|
| 47 |
+
"sliding_attention",
|
| 48 |
+
"sliding_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"sliding_attention",
|
| 51 |
+
"sliding_attention",
|
| 52 |
+
"sliding_attention",
|
| 53 |
+
"full_attention"
|
| 54 |
+
],
|
| 55 |
+
"local_intermediate_size": 5504,
|
| 56 |
+
"local_rms_norm_eps": 1e-05,
|
| 57 |
+
"max_position_embeddings": 65536,
|
| 58 |
+
"model_type": "bolmo",
|
| 59 |
+
"num_attention_heads": 32,
|
| 60 |
+
"num_hidden_layers": 32,
|
| 61 |
+
"num_key_value_heads": 32,
|
| 62 |
+
"num_local_decoder_layers": 4,
|
| 63 |
+
"num_local_encoder_layers": 1,
|
| 64 |
+
"num_local_heads": 16,
|
| 65 |
+
"pad_token_id": 0,
|
| 66 |
+
"quantization_config": {
|
| 67 |
+
"config_groups": {
|
| 68 |
+
"group_0": {
|
| 69 |
+
"format": "nvfp4-pack-quantized",
|
| 70 |
+
"input_activations": {
|
| 71 |
+
"actorder": null,
|
| 72 |
+
"block_structure": null,
|
| 73 |
+
"dynamic": "local",
|
| 74 |
+
"group_size": 16,
|
| 75 |
+
"num_bits": 4,
|
| 76 |
+
"observer": "minmax",
|
| 77 |
+
"observer_kwargs": {},
|
| 78 |
+
"strategy": "tensor_group",
|
| 79 |
+
"symmetric": true,
|
| 80 |
+
"type": "float"
|
| 81 |
+
},
|
| 82 |
+
"output_activations": null,
|
| 83 |
+
"targets": [
|
| 84 |
+
"Linear"
|
| 85 |
+
],
|
| 86 |
+
"weights": {
|
| 87 |
+
"actorder": null,
|
| 88 |
+
"block_structure": null,
|
| 89 |
+
"dynamic": false,
|
| 90 |
+
"group_size": 16,
|
| 91 |
+
"num_bits": 4,
|
| 92 |
+
"observer": "minmax",
|
| 93 |
+
"observer_kwargs": {},
|
| 94 |
+
"strategy": "tensor_group",
|
| 95 |
+
"symmetric": true,
|
| 96 |
+
"type": "float"
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
},
|
| 100 |
+
"format": "nvfp4-pack-quantized",
|
| 101 |
+
"global_compression_ratio": null,
|
| 102 |
+
"ignore": [
|
| 103 |
+
"lm_head"
|
| 104 |
+
],
|
| 105 |
+
"kv_cache_scheme": null,
|
| 106 |
+
"quant_method": "compressed-tensors",
|
| 107 |
+
"quantization_status": "compressed",
|
| 108 |
+
"sparsity_config": {},
|
| 109 |
+
"transform_config": {},
|
| 110 |
+
"version": "0.12.2"
|
| 111 |
+
},
|
| 112 |
+
"rms_norm_eps": 1e-06,
|
| 113 |
+
"rope_scaling": {
|
| 114 |
+
"attention_factor": 1.2079441541679836,
|
| 115 |
+
"beta_fast": 32,
|
| 116 |
+
"beta_slow": 1,
|
| 117 |
+
"factor": 8.0,
|
| 118 |
+
"original_max_position_embeddings": 8192,
|
| 119 |
+
"rope_type": "yarn"
|
| 120 |
+
},
|
| 121 |
+
"rope_theta": 500000,
|
| 122 |
+
"sliding_window": 4096,
|
| 123 |
+
"subword_vocab_size": 100278,
|
| 124 |
+
"tie_word_embeddings": false,
|
| 125 |
+
"tokenizer_config": {
|
| 126 |
+
"bos_token_id": 1,
|
| 127 |
+
"bpe_token_end_id": 3,
|
| 128 |
+
"eos_token_id": 1,
|
| 129 |
+
"original_identifier": "allenai/dolma2-tokenizer",
|
| 130 |
+
"pad_token_id": 0,
|
| 131 |
+
"special_tokens": [
|
| 132 |
+
"<pad>",
|
| 133 |
+
"<bos>",
|
| 134 |
+
"<eos>",
|
| 135 |
+
"<bpe_token_end>"
|
| 136 |
+
],
|
| 137 |
+
"special_tokens_first": true,
|
| 138 |
+
"vocab_size": 520
|
| 139 |
+
},
|
| 140 |
+
"transformers_version": "4.57.3",
|
| 141 |
+
"use_cache": true,
|
| 142 |
+
"vocab_size": 520
|
| 143 |
+
}
|
configuration_bolmo.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import asdict
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
| 5 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 6 |
+
from .tokenization_bolmo import BolmoTokenizerConfig
|
| 7 |
+
|
| 8 |
+
class BolmoConfig(PretrainedConfig):
|
| 9 |
+
r"""
|
| 10 |
+
This is the configuration class to store the configuration of a [`Olmo3Model`]. It is used to instantiate an OLMo3
|
| 11 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 12 |
+
defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).
|
| 13 |
+
|
| 14 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 15 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
vocab_size (`int`, *optional*, defaults to 50304):
|
| 20 |
+
Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
|
| 21 |
+
`inputs_ids` passed when calling [`Olmo3Model`]
|
| 22 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 23 |
+
Dimension of the hidden representations.
|
| 24 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 25 |
+
Dimension of the MLP representations.
|
| 26 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 27 |
+
Number of hidden layers in the Transformer decoder.
|
| 28 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 29 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 30 |
+
num_key_value_heads (`int`, *optional*):
|
| 31 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 32 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 33 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 34 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 35 |
+
by meanpooling all the original heads within that group. For more details, check out [this
|
| 36 |
+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
| 37 |
+
`num_attention_heads`.
|
| 38 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 39 |
+
The non-linear activation function (function or string) in the decoder.
|
| 40 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 41 |
+
The maximum sequence length that this model might ever be used with.
|
| 42 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 43 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 44 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 45 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 46 |
+
relevant if `config.is_decoder=True`.
|
| 47 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 48 |
+
Padding token id.
|
| 49 |
+
bos_token_id (`int`, *optional*):
|
| 50 |
+
Beginning of stream token id.
|
| 51 |
+
eos_token_id (`int`, *optional*, defaults to 50279):
|
| 52 |
+
End of stream token id.
|
| 53 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 54 |
+
Whether to tie weight embeddings
|
| 55 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 56 |
+
The base period of the RoPE embeddings.
|
| 57 |
+
rope_scaling (`Dict`, *optional*):
|
| 58 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 59 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 60 |
+
accordingly.
|
| 61 |
+
Expected contents:
|
| 62 |
+
`rope_type` (`str`):
|
| 63 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 64 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 65 |
+
`factor` (`float`, *optional*):
|
| 66 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 67 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 68 |
+
original maximum pre-trained length.
|
| 69 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 70 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 71 |
+
pretraining.
|
| 72 |
+
`attention_factor` (`float`, *optional*):
|
| 73 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 74 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 75 |
+
`factor` field to infer the suggested value.
|
| 76 |
+
`beta_fast` (`float`, *optional*):
|
| 77 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 78 |
+
ramp function. If unspecified, it defaults to 32.
|
| 79 |
+
`beta_slow` (`float`, *optional*):
|
| 80 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 81 |
+
ramp function. If unspecified, it defaults to 1.
|
| 82 |
+
`short_factor` (`list[float]`, *optional*):
|
| 83 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 84 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 85 |
+
size divided by the number of attention heads divided by 2
|
| 86 |
+
`long_factor` (`list[float]`, *optional*):
|
| 87 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 88 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 89 |
+
size divided by the number of attention heads divided by 2
|
| 90 |
+
`low_freq_factor` (`float`, *optional*):
|
| 91 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 92 |
+
`high_freq_factor` (`float`, *optional*):
|
| 93 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 94 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 95 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 96 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 97 |
+
The dropout ratio for the attention probabilities.
|
| 98 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 99 |
+
The epsilon used by the rms normalization layers.
|
| 100 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 101 |
+
Size of the sliding window for sliding window attention.
|
| 102 |
+
layer_types (`list`, *optional*):
|
| 103 |
+
Attention pattern for each layer. Defaults to sliding window attention
|
| 104 |
+
for 3 out of 4 layers, and full attention for every 4th layer.
|
| 105 |
+
|
| 106 |
+
```python
|
| 107 |
+
>>> from transformers import Olmo3Model, Olmo3Config
|
| 108 |
+
|
| 109 |
+
>>> # Initializing a Olmo3 7B style configuration
|
| 110 |
+
>>> configuration = Olmo3Config()
|
| 111 |
+
|
| 112 |
+
>>> # Initializing a model from the Olmo3 7B style configuration
|
| 113 |
+
>>> model = Olmo3Model(configuration)
|
| 114 |
+
|
| 115 |
+
>>> # Accessing the model configuration
|
| 116 |
+
>>> configuration = model.config
|
| 117 |
+
```
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
model_type = "bolmo"
|
| 121 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 122 |
+
base_model_tp_plan = {
|
| 123 |
+
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 124 |
+
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 125 |
+
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 126 |
+
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
| 127 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 128 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 129 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 130 |
+
}
|
| 131 |
+
base_model_pp_plan = {
|
| 132 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 133 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 134 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
vocab_size=520,
|
| 140 |
+
hidden_size=4096,
|
| 141 |
+
intermediate_size=11008,
|
| 142 |
+
num_hidden_layers=32,
|
| 143 |
+
num_attention_heads=32,
|
| 144 |
+
num_key_value_heads=None,
|
| 145 |
+
hidden_act="silu",
|
| 146 |
+
max_position_embeddings=2048,
|
| 147 |
+
initializer_range=0.02,
|
| 148 |
+
use_cache=True,
|
| 149 |
+
pad_token_id=1,
|
| 150 |
+
bos_token_id=None,
|
| 151 |
+
eos_token_id=50279,
|
| 152 |
+
tie_word_embeddings=False,
|
| 153 |
+
rope_theta=10000.0,
|
| 154 |
+
rope_scaling=None,
|
| 155 |
+
attention_bias=False,
|
| 156 |
+
attention_dropout=0.0,
|
| 157 |
+
rms_norm_eps=1e-5,
|
| 158 |
+
sliding_window=4096,
|
| 159 |
+
layer_types=None,
|
| 160 |
+
# bolmo config
|
| 161 |
+
add_expanded_embeddings: bool = True,
|
| 162 |
+
boundary_predictor_lookahead: int = 1,
|
| 163 |
+
boundary_threshold: str = "sample:0",
|
| 164 |
+
num_local_encoder_layers: int = 1,
|
| 165 |
+
num_local_decoder_layers: int = 4,
|
| 166 |
+
num_local_heads: int = 16,
|
| 167 |
+
local_intermediate_size: int = 5504,
|
| 168 |
+
local_rms_norm_eps=1e-5,
|
| 169 |
+
subword_vocab_size: int = 100278, # dolma2_tokenizer subword vocab size
|
| 170 |
+
tokenizer_config: BolmoTokenizerConfig | dict[str, Any] | None = None,
|
| 171 |
+
**kwargs,
|
| 172 |
+
):
|
| 173 |
+
super().__init__(
|
| 174 |
+
pad_token_id=pad_token_id,
|
| 175 |
+
bos_token_id=bos_token_id,
|
| 176 |
+
eos_token_id=eos_token_id,
|
| 177 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 178 |
+
**kwargs,
|
| 179 |
+
)
|
| 180 |
+
self.vocab_size = vocab_size
|
| 181 |
+
self.max_position_embeddings = max_position_embeddings
|
| 182 |
+
self.hidden_size = hidden_size
|
| 183 |
+
self.intermediate_size = intermediate_size
|
| 184 |
+
self.num_hidden_layers = num_hidden_layers
|
| 185 |
+
self.num_attention_heads = num_attention_heads
|
| 186 |
+
|
| 187 |
+
# for backward compatibility
|
| 188 |
+
if num_key_value_heads is None:
|
| 189 |
+
num_key_value_heads = num_attention_heads
|
| 190 |
+
|
| 191 |
+
self.num_key_value_heads = num_key_value_heads
|
| 192 |
+
self.hidden_act = hidden_act
|
| 193 |
+
self.initializer_range = initializer_range
|
| 194 |
+
self.use_cache = use_cache
|
| 195 |
+
self.rope_theta = rope_theta
|
| 196 |
+
self.rope_scaling = rope_scaling
|
| 197 |
+
self._rope_scaling_validation()
|
| 198 |
+
self.attention_bias = attention_bias
|
| 199 |
+
self.attention_dropout = attention_dropout
|
| 200 |
+
|
| 201 |
+
self.rms_norm_eps = rms_norm_eps
|
| 202 |
+
|
| 203 |
+
self.sliding_window = sliding_window
|
| 204 |
+
self.layer_types = layer_types
|
| 205 |
+
if self.layer_types is None:
|
| 206 |
+
self.layer_types = [
|
| 207 |
+
"sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
|
| 208 |
+
]
|
| 209 |
+
layer_type_validation(self.layer_types)
|
| 210 |
+
|
| 211 |
+
# bolmo configuration
|
| 212 |
+
self.add_expanded_embeddings = add_expanded_embeddings
|
| 213 |
+
self.boundary_predictor_lookahead = boundary_predictor_lookahead
|
| 214 |
+
self.boundary_threshold = boundary_threshold
|
| 215 |
+
self.num_local_encoder_layers = num_local_encoder_layers
|
| 216 |
+
self.num_local_decoder_layers = num_local_decoder_layers
|
| 217 |
+
self.num_local_heads = num_local_heads
|
| 218 |
+
self.local_intermediate_size = local_intermediate_size
|
| 219 |
+
self.local_rms_norm_eps = local_rms_norm_eps
|
| 220 |
+
self.subword_vocab_size = subword_vocab_size
|
| 221 |
+
|
| 222 |
+
if tokenizer_config is None:
|
| 223 |
+
self.tokenizer_config = asdict(BolmoTokenizerConfig.bolmo())
|
| 224 |
+
elif isinstance(tokenizer_config, BolmoTokenizerConfig):
|
| 225 |
+
self.tokenizer_config = asdict(tokenizer_config)
|
| 226 |
+
else:
|
| 227 |
+
self.tokenizer_config = tokenizer_config
|
| 228 |
+
|
| 229 |
+
def _rope_scaling_validation(self):
|
| 230 |
+
"""
|
| 231 |
+
Validate the `rope_scaling` configuration.
|
| 232 |
+
"""
|
| 233 |
+
rope_config_validation(self)
|
| 234 |
+
|
| 235 |
+
__all__ = ["BolmoConfig"]
|
generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": 50279,
|
| 5 |
+
"pad_token_id": 1,
|
| 6 |
+
"transformers_version": "4.57.3"
|
| 7 |
+
}
|
model-00001-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:99b1a3015951f3907501cc07ed492c6e60e5ade94b0fe62faf09f15d8ef58c87
|
| 3 |
+
size 4979678176
|
model-00002-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff6107aa16c14af67f8169fb536c4c94b70e8d0a48b11bb67f6d6e2ce5d829a2
|
| 3 |
+
size 742727128
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_bolmo.py
ADDED
|
@@ -0,0 +1,1351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Callable, Optional, Union, cast
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from transformers.utils.generic import TransformersKwargs
|
| 10 |
+
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 13 |
+
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
| 14 |
+
from transformers.generation.utils import GenerateOutput
|
| 15 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 16 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 17 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 18 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 19 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 20 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 21 |
+
from transformers.processing_utils import Unpack
|
| 22 |
+
from transformers.utils import can_return_tuple
|
| 23 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 24 |
+
from transformers.utils.generic import check_model_inputs
|
| 25 |
+
|
| 26 |
+
from .configuration_bolmo import BolmoConfig
|
| 27 |
+
from .tokenization_bolmo import BolmoTokenizerConfig
|
| 28 |
+
from .utils_bolmo import compute_boundary_mask, pad_right, pad_left, MaskState
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from xlstm.xlstm_large.model import mLSTMLayer, mLSTMLayerConfig, mLSTMLayerStateType, soft_cap, mLSTMBackendConfig
|
| 32 |
+
except ImportError:
|
| 33 |
+
raise ImportError("The `xlstm` package is required to use Bolmo. Please install it via `pip install xlstm`.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 37 |
+
class BolmoRMSNorm(nn.Module):
|
| 38 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 39 |
+
"""
|
| 40 |
+
BolmoRMSNorm is equivalent to T5LayerNorm
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 44 |
+
self.variance_epsilon = eps
|
| 45 |
+
|
| 46 |
+
def forward(self, hidden_states):
|
| 47 |
+
input_dtype = hidden_states.dtype
|
| 48 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 49 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 50 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 51 |
+
return (self.weight * hidden_states).to(input_dtype)
|
| 52 |
+
|
| 53 |
+
def extra_repr(self):
|
| 54 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 58 |
+
"""
|
| 59 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 60 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 61 |
+
"""
|
| 62 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 63 |
+
if n_rep == 1:
|
| 64 |
+
return hidden_states
|
| 65 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 66 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def eager_attention_forward(
|
| 70 |
+
module: nn.Module,
|
| 71 |
+
query: torch.Tensor,
|
| 72 |
+
key: torch.Tensor,
|
| 73 |
+
value: torch.Tensor,
|
| 74 |
+
attention_mask: Optional[torch.Tensor],
|
| 75 |
+
scaling: float,
|
| 76 |
+
dropout: float = 0.0,
|
| 77 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 78 |
+
):
|
| 79 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 80 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 81 |
+
|
| 82 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 83 |
+
if attention_mask is not None:
|
| 84 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 85 |
+
attn_weights = attn_weights + causal_mask
|
| 86 |
+
|
| 87 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 88 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 89 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 90 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 91 |
+
|
| 92 |
+
return attn_output, attn_weights
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 96 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
q (`torch.Tensor`): The query tensor.
|
| 100 |
+
k (`torch.Tensor`): The key tensor.
|
| 101 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 102 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 103 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 104 |
+
Deprecated and unused.
|
| 105 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 106 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 107 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 108 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 109 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 110 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 111 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 112 |
+
Returns:
|
| 113 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 114 |
+
"""
|
| 115 |
+
q_type, k_type = q.dtype, k.dtype
|
| 116 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 117 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 118 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 119 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 120 |
+
return q_embed.to(q_type), k_embed.to(k_type)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def rotate_half(x):
|
| 124 |
+
"""Rotates half the hidden dims of the input."""
|
| 125 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 126 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 127 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BolmoAttention(nn.Module):
|
| 131 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 132 |
+
|
| 133 |
+
def __init__(self, config: BolmoConfig, layer_idx: int):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.config = config
|
| 136 |
+
self.layer_idx = layer_idx
|
| 137 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 138 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 139 |
+
self.scaling = self.head_dim**-0.5
|
| 140 |
+
self.attention_dropout = config.attention_dropout
|
| 141 |
+
self.is_causal = True
|
| 142 |
+
|
| 143 |
+
self.q_proj = nn.Linear(
|
| 144 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 145 |
+
)
|
| 146 |
+
self.k_proj = nn.Linear(
|
| 147 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 148 |
+
)
|
| 149 |
+
self.v_proj = nn.Linear(
|
| 150 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 151 |
+
)
|
| 152 |
+
self.o_proj = nn.Linear(
|
| 153 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 154 |
+
)
|
| 155 |
+
self.q_norm = BolmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
|
| 156 |
+
self.k_norm = BolmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
|
| 157 |
+
assert config.layer_types is not None
|
| 158 |
+
self.attention_type = config.layer_types[layer_idx]
|
| 159 |
+
self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None
|
| 160 |
+
|
| 161 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 162 |
+
def forward(
|
| 163 |
+
self,
|
| 164 |
+
hidden_states: torch.Tensor,
|
| 165 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 166 |
+
attention_mask: Optional[torch.Tensor],
|
| 167 |
+
past_key_values: Optional[Cache] = None,
|
| 168 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 169 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 170 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 171 |
+
input_shape = hidden_states.shape[:-1]
|
| 172 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 173 |
+
|
| 174 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 175 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 176 |
+
value_states = self.v_proj(hidden_states)
|
| 177 |
+
|
| 178 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 179 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 180 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 181 |
+
|
| 182 |
+
cos, sin = position_embeddings
|
| 183 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 184 |
+
|
| 185 |
+
if past_key_values is not None:
|
| 186 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 187 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 188 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 189 |
+
|
| 190 |
+
attention_interface: Callable = eager_attention_forward
|
| 191 |
+
if self.config._attn_implementation != "eager":
|
| 192 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 193 |
+
|
| 194 |
+
attn_output, attn_weights = attention_interface(
|
| 195 |
+
self,
|
| 196 |
+
query_states,
|
| 197 |
+
key_states,
|
| 198 |
+
value_states,
|
| 199 |
+
attention_mask,
|
| 200 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 201 |
+
scaling=self.scaling,
|
| 202 |
+
sliding_window=self.sliding_window,
|
| 203 |
+
**kwargs,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 207 |
+
attn_output = self.o_proj(attn_output)
|
| 208 |
+
return attn_output, attn_weights
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class BolmoMLP(nn.Module):
|
| 212 |
+
def __init__(self, config):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.config = config
|
| 215 |
+
self.hidden_size = config.hidden_size
|
| 216 |
+
self.intermediate_size = config.intermediate_size
|
| 217 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 218 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 219 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 220 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 221 |
+
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 224 |
+
return down_proj
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class BolmoDecoderLayer(GradientCheckpointingLayer):
|
| 228 |
+
def __init__(self, config: BolmoConfig, layer_idx: int):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.hidden_size = config.hidden_size
|
| 231 |
+
self.self_attn = BolmoAttention(config=config, layer_idx=layer_idx)
|
| 232 |
+
|
| 233 |
+
self.mlp = BolmoMLP(config)
|
| 234 |
+
self.post_attention_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 235 |
+
self.post_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 236 |
+
|
| 237 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
hidden_states: torch.Tensor,
|
| 241 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 242 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 243 |
+
past_key_values: Optional[Cache] = None,
|
| 244 |
+
use_cache: Optional[bool] = False,
|
| 245 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 246 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 247 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 248 |
+
) -> torch.Tensor:
|
| 249 |
+
residual = hidden_states
|
| 250 |
+
attn_out, _ = self.self_attn(
|
| 251 |
+
hidden_states=hidden_states,
|
| 252 |
+
attention_mask=attention_mask,
|
| 253 |
+
position_ids=position_ids,
|
| 254 |
+
past_key_values=past_key_values,
|
| 255 |
+
use_cache=use_cache,
|
| 256 |
+
cache_position=cache_position,
|
| 257 |
+
position_embeddings=position_embeddings,
|
| 258 |
+
**kwargs,
|
| 259 |
+
)
|
| 260 |
+
hidden_states = self.post_attention_layernorm(attn_out)
|
| 261 |
+
hidden_states = residual + hidden_states
|
| 262 |
+
|
| 263 |
+
# Fully Connected
|
| 264 |
+
residual = hidden_states
|
| 265 |
+
mlp_out = self.mlp(hidden_states)
|
| 266 |
+
hidden_states = self.post_feedforward_layernorm(mlp_out)
|
| 267 |
+
hidden_states = residual + hidden_states
|
| 268 |
+
|
| 269 |
+
return hidden_states
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class BolmoBoundaryPredictor(nn.Module):
|
| 273 |
+
def __init__(self, config: BolmoConfig):
|
| 274 |
+
super().__init__()
|
| 275 |
+
|
| 276 |
+
self.d_model = config.hidden_size
|
| 277 |
+
self.boundary_threshold = config.boundary_threshold
|
| 278 |
+
self.boundary_predictor_lookahead = config.boundary_predictor_lookahead
|
| 279 |
+
self.q_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)
|
| 280 |
+
self.k_proj_layer = nn.Linear(self.d_model, self.d_model, bias=False)
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
hidden_states: torch.Tensor,
|
| 285 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 286 |
+
epsilon: float = 1e-3,
|
| 287 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 288 |
+
if self.boundary_predictor_lookahead == 0:
|
| 289 |
+
# do not use the same rep for k and v, use current and one before as in H-Net + pad with negative to the left
|
| 290 |
+
cos_sim = torch.cat([
|
| 291 |
+
torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=hidden_states.dtype) * -1,
|
| 292 |
+
torch.einsum(
|
| 293 |
+
"b l d, b l d -> b l",
|
| 294 |
+
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
|
| 295 |
+
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
|
| 296 |
+
)
|
| 297 |
+
], dim=1)
|
| 298 |
+
else:
|
| 299 |
+
cos_sim = torch.einsum(
|
| 300 |
+
"b l d, b l d -> b l",
|
| 301 |
+
F.normalize(self.q_proj_layer(hidden_states[:, :-self.boundary_predictor_lookahead]), dim=-1),
|
| 302 |
+
F.normalize(self.k_proj_layer(hidden_states[:, self.boundary_predictor_lookahead:]), dim=-1),
|
| 303 |
+
)
|
| 304 |
+
boundary_logprobs = torch.log1p(-cos_sim.float().clip(max=1.0 - epsilon)) - math.log(2)
|
| 305 |
+
POSITIVE_LOGPROB = 0.0
|
| 306 |
+
NEGATIVE_LOGPROB = -100_000
|
| 307 |
+
if sequence_start_indices is None:
|
| 308 |
+
boundary_logprobs[:, 0] = POSITIVE_LOGPROB
|
| 309 |
+
else:
|
| 310 |
+
pad_mask = torch.arange(boundary_logprobs.shape[1], device=boundary_logprobs.device)[None, :] < sequence_start_indices[:, None]
|
| 311 |
+
boundary_logprobs = boundary_logprobs.masked_fill(pad_mask, NEGATIVE_LOGPROB)
|
| 312 |
+
boundary_logprobs[torch.arange(len(boundary_logprobs), device=boundary_logprobs.device), sequence_start_indices] = POSITIVE_LOGPROB
|
| 313 |
+
|
| 314 |
+
boundary_logprobs = F.pad(boundary_logprobs, (0, self.boundary_predictor_lookahead), "constant", NEGATIVE_LOGPROB)
|
| 315 |
+
boundary_mask = compute_boundary_mask(boundary_logprobs, self.boundary_threshold)
|
| 316 |
+
|
| 317 |
+
return boundary_logprobs, boundary_mask
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class BolmoXLSTMLayer(mLSTMLayer):
|
| 321 |
+
def __init__(self, config: BolmoConfig):
|
| 322 |
+
super().__init__(mLSTMLayerConfig(
|
| 323 |
+
embedding_dim=config.hidden_size,
|
| 324 |
+
num_heads=config.num_local_heads,
|
| 325 |
+
mlstm_backend=mLSTMBackendConfig(
|
| 326 |
+
chunkwise_kernel="chunkwise--triton_limit_chunk",
|
| 327 |
+
sequence_kernel="native_sequence__triton",
|
| 328 |
+
step_kernel="triton",
|
| 329 |
+
mode="train",
|
| 330 |
+
return_last_states=True,
|
| 331 |
+
autocast_kernel_dtype="float32",
|
| 332 |
+
)
|
| 333 |
+
))
|
| 334 |
+
|
| 335 |
+
# original forward adapted to support sequence_start_indices
|
| 336 |
+
# i.e. set the forget gate to zero at the start of sequence
|
| 337 |
+
def _original_forward(
|
| 338 |
+
self, x: torch.Tensor,
|
| 339 |
+
state: mLSTMLayerStateType | None = None,
|
| 340 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 341 |
+
) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
|
| 342 |
+
assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
|
| 343 |
+
B, S, _ = x.shape
|
| 344 |
+
if self.config.weight_mode == "single":
|
| 345 |
+
q = self.q(x)
|
| 346 |
+
k = self.k(x)
|
| 347 |
+
v = self.v(x)
|
| 348 |
+
o_preact = self.ogate_preact(x)
|
| 349 |
+
i_preact = soft_cap(
|
| 350 |
+
self.igate_preact(x), cap_value=self.config.gate_soft_cap
|
| 351 |
+
)
|
| 352 |
+
f_preact = soft_cap(
|
| 353 |
+
self.fgate_preact(x), cap_value=self.config.gate_soft_cap
|
| 354 |
+
)
|
| 355 |
+
elif self.config.weight_mode == "fused":
|
| 356 |
+
qkv_opreact = self.qkv_opreact(x)
|
| 357 |
+
q, k, v, o_preact = torch.tensor_split(
|
| 358 |
+
qkv_opreact,
|
| 359 |
+
(
|
| 360 |
+
self.qk_dim,
|
| 361 |
+
2 * self.qk_dim,
|
| 362 |
+
2 * self.qk_dim + self.v_dim,
|
| 363 |
+
),
|
| 364 |
+
dim=-1,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if_preact = soft_cap(
|
| 368 |
+
self.ifgate_preact(x), cap_value=self.config.gate_soft_cap
|
| 369 |
+
)
|
| 370 |
+
i_preact, f_preact = torch.tensor_split(
|
| 371 |
+
if_preact, (self.config.num_heads,), dim=-1
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
raise ValueError(f"Unknown weight_mode: {self.config.weight_mode}")
|
| 375 |
+
|
| 376 |
+
q = q.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 377 |
+
k = k.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 378 |
+
v = v.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
|
| 379 |
+
|
| 380 |
+
if sequence_start_indices is not None:
|
| 381 |
+
f_preact[torch.arange(B, device=f_preact.device), sequence_start_indices] = -100_000
|
| 382 |
+
|
| 383 |
+
i_preact = i_preact.transpose(1, 2)
|
| 384 |
+
f_preact = f_preact.transpose(1, 2)
|
| 385 |
+
if state is None:
|
| 386 |
+
c_initial, n_initial, m_initial = None, None, None
|
| 387 |
+
else:
|
| 388 |
+
c_initial, n_initial, m_initial = state
|
| 389 |
+
|
| 390 |
+
h, state = self.mlstm_backend(
|
| 391 |
+
q=q,
|
| 392 |
+
k=k,
|
| 393 |
+
v=v,
|
| 394 |
+
i=i_preact,
|
| 395 |
+
f=f_preact,
|
| 396 |
+
c_initial=c_initial,
|
| 397 |
+
n_initial=n_initial,
|
| 398 |
+
m_initial=m_initial,
|
| 399 |
+
)
|
| 400 |
+
expected_h_shape = (
|
| 401 |
+
B,
|
| 402 |
+
self.config.num_heads,
|
| 403 |
+
S,
|
| 404 |
+
self.v_dim // self.config.num_heads,
|
| 405 |
+
)
|
| 406 |
+
assert (
|
| 407 |
+
h.shape == expected_h_shape
|
| 408 |
+
), f"Got {h.shape}, expected {expected_h_shape}"
|
| 409 |
+
|
| 410 |
+
h = h.transpose(1, 2)
|
| 411 |
+
h_norm = self.multihead_norm(h)
|
| 412 |
+
h_norm = h_norm.reshape(B, S, -1)
|
| 413 |
+
|
| 414 |
+
h_out = self.ogate_act_fn(o_preact) * h_norm
|
| 415 |
+
|
| 416 |
+
y = self.out_proj(h_out)
|
| 417 |
+
return y, state
|
| 418 |
+
|
| 419 |
+
def forward( # type: ignore
|
| 420 |
+
self,
|
| 421 |
+
x: torch.Tensor,
|
| 422 |
+
past_key_values: Optional[dict] = None,
|
| 423 |
+
use_cache: bool = False,
|
| 424 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 425 |
+
cache_mask: Optional[MaskState] = None
|
| 426 |
+
):
|
| 427 |
+
if self.training:
|
| 428 |
+
self.mlstm_backend.config.mode = "train"
|
| 429 |
+
else:
|
| 430 |
+
self.mlstm_backend.config.mode = "inference"
|
| 431 |
+
|
| 432 |
+
if use_cache:
|
| 433 |
+
assert past_key_values is not None
|
| 434 |
+
|
| 435 |
+
prev_mode = self.mlstm_backend.config.mode
|
| 436 |
+
state = past_key_values.get("state", None)
|
| 437 |
+
|
| 438 |
+
if cache_mask is not None:
|
| 439 |
+
state_for_model = cast(mLSTMLayerStateType, tuple(cache_mask.selective_get(x, inv=True) for x in state) if state is not None else None)
|
| 440 |
+
else:
|
| 441 |
+
state_for_model = state
|
| 442 |
+
|
| 443 |
+
h, new_state = self._original_forward(
|
| 444 |
+
x,
|
| 445 |
+
state=state_for_model,
|
| 446 |
+
sequence_start_indices=sequence_start_indices
|
| 447 |
+
)
|
| 448 |
+
assert new_state is not None
|
| 449 |
+
|
| 450 |
+
if state is None or cache_mask is None:
|
| 451 |
+
state = new_state
|
| 452 |
+
else:
|
| 453 |
+
if cache_mask is not None:
|
| 454 |
+
for i in range(len(state)):
|
| 455 |
+
cache_mask.selective_put(new_state[i], state[i], inv=True)
|
| 456 |
+
|
| 457 |
+
past_key_values["state"] = state
|
| 458 |
+
self.mlstm_backend.config.mode = prev_mode
|
| 459 |
+
|
| 460 |
+
return h
|
| 461 |
+
else:
|
| 462 |
+
h, _ = super().forward(x)
|
| 463 |
+
return h
|
| 464 |
+
|
| 465 |
+
class BolmoLocalLayer(nn.Module):
|
| 466 |
+
def __init__(self, config: BolmoConfig):
|
| 467 |
+
super().__init__()
|
| 468 |
+
self.config = config
|
| 469 |
+
self.hidden_size = config.hidden_size
|
| 470 |
+
|
| 471 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 472 |
+
|
| 473 |
+
self.xlstm = BolmoXLSTMLayer(config)
|
| 474 |
+
|
| 475 |
+
local_mlp_config = copy.deepcopy(config)
|
| 476 |
+
local_mlp_config.intermediate_size = config.local_intermediate_size
|
| 477 |
+
self.mlp = BolmoMLP(local_mlp_config)
|
| 478 |
+
|
| 479 |
+
self.pre_xlstm_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 480 |
+
self.pre_feedforward_layernorm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 481 |
+
|
| 482 |
+
def forward(
|
| 483 |
+
self,
|
| 484 |
+
hidden_states: torch.Tensor,
|
| 485 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 486 |
+
past_key_values: Optional[dict] = None,
|
| 487 |
+
use_cache: Optional[bool] = False,
|
| 488 |
+
cache_mask: Optional[MaskState] = None,
|
| 489 |
+
) -> torch.Tensor:
|
| 490 |
+
residual = hidden_states
|
| 491 |
+
xlstm_out = self.xlstm(self.pre_xlstm_layernorm(hidden_states), sequence_start_indices=sequence_start_indices, past_key_values=past_key_values["xlstm"] if past_key_values is not None else None, use_cache=use_cache, cache_mask=cache_mask)
|
| 492 |
+
hidden_states = residual + xlstm_out
|
| 493 |
+
|
| 494 |
+
# Fully Connected
|
| 495 |
+
residual = hidden_states
|
| 496 |
+
ffn_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
|
| 497 |
+
hidden_states = residual + ffn_out
|
| 498 |
+
|
| 499 |
+
return hidden_states
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class BolmoLocalEncoder(nn.Module):
|
| 503 |
+
def __init__(self, config: BolmoConfig):
|
| 504 |
+
super().__init__()
|
| 505 |
+
self.config = config
|
| 506 |
+
self.hidden_size = config.hidden_size
|
| 507 |
+
self.add_expanded_embeddings = config.add_expanded_embeddings
|
| 508 |
+
|
| 509 |
+
self.byte_embedding = nn.Embedding(
|
| 510 |
+
config.vocab_size,
|
| 511 |
+
self.hidden_size,
|
| 512 |
+
)
|
| 513 |
+
if self.add_expanded_embeddings:
|
| 514 |
+
self.subword_embedding = nn.Embedding(
|
| 515 |
+
config.subword_vocab_size,
|
| 516 |
+
self.hidden_size,
|
| 517 |
+
)
|
| 518 |
+
else:
|
| 519 |
+
self.subword_embedding = None
|
| 520 |
+
|
| 521 |
+
self.layers = nn.ModuleList(
|
| 522 |
+
[BolmoLocalLayer(config) for _ in range(config.num_local_encoder_layers)]
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
self.post_last_block_norm = BolmoRMSNorm(
|
| 526 |
+
self.hidden_size,
|
| 527 |
+
config.local_rms_norm_eps,
|
| 528 |
+
)
|
| 529 |
+
self.out_projection = nn.Linear(
|
| 530 |
+
self.hidden_size,
|
| 531 |
+
self.hidden_size,
|
| 532 |
+
bias=True,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.boundary_predictor_module = BolmoBoundaryPredictor(config)
|
| 536 |
+
|
| 537 |
+
self.has_cache = False
|
| 538 |
+
|
| 539 |
+
def prepare_inference_cache(self, batch_size: int):
|
| 540 |
+
device = next(self.parameters()).device
|
| 541 |
+
self.has_cache = True
|
| 542 |
+
|
| 543 |
+
self.cache_seqlens = 0
|
| 544 |
+
self.last_h = torch.zeros((batch_size, self.hidden_size), dtype=self.out_projection.weight.dtype, device=device)
|
| 545 |
+
self.layer_states = [{"xlstm": {}} for _ in range(len(self.layers))]
|
| 546 |
+
|
| 547 |
+
def free_inference_cache(self):
|
| 548 |
+
self.has_cache = False
|
| 549 |
+
if hasattr(self, "cache_seqlens"):
|
| 550 |
+
del self.cache_seqlens
|
| 551 |
+
if hasattr(self, "last_h"):
|
| 552 |
+
del self.last_h
|
| 553 |
+
if hasattr(self, "layer_states"):
|
| 554 |
+
del self.layer_states
|
| 555 |
+
|
| 556 |
+
def _embed(self, tokens, expanded_input_ids: Optional[torch.Tensor] = None):
|
| 557 |
+
embeddings = self.byte_embedding(tokens)
|
| 558 |
+
if self.add_expanded_embeddings:
|
| 559 |
+
assert expanded_input_ids is not None and self.subword_embedding is not None
|
| 560 |
+
embeddings = embeddings + self.subword_embedding(expanded_input_ids)
|
| 561 |
+
|
| 562 |
+
return embeddings
|
| 563 |
+
|
| 564 |
+
def _pool(
|
| 565 |
+
self,
|
| 566 |
+
h: torch.Tensor,
|
| 567 |
+
boundary_mask: torch.Tensor | None,
|
| 568 |
+
n_patches: int,
|
| 569 |
+
boundary_state: Optional[MaskState] = None,
|
| 570 |
+
):
|
| 571 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 572 |
+
assert boundary_state is not None
|
| 573 |
+
if boundary_state.all():
|
| 574 |
+
assert h.shape[1] == 1
|
| 575 |
+
reduced_h = h
|
| 576 |
+
else:
|
| 577 |
+
reduced_h = h[[], :, :]
|
| 578 |
+
else:
|
| 579 |
+
assert boundary_mask is not None
|
| 580 |
+
|
| 581 |
+
L = h.shape[1]
|
| 582 |
+
token_idx = (
|
| 583 |
+
torch.arange(L, device=h.device)[None, :] + (~boundary_mask).long() * L # type: ignore
|
| 584 |
+
)
|
| 585 |
+
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
| 586 |
+
index = seq_sorted_indices[:, :n_patches, None].expand(
|
| 587 |
+
-1, -1, h.shape[-1]
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
reduced_h = torch.gather(
|
| 591 |
+
h,
|
| 592 |
+
dim=1,
|
| 593 |
+
index=index,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
return reduced_h
|
| 597 |
+
|
| 598 |
+
def forward(
|
| 599 |
+
self,
|
| 600 |
+
input_ids,
|
| 601 |
+
true_boundary_mask: Optional[torch.Tensor] = None,
|
| 602 |
+
boundary_state: Optional[MaskState] = None,
|
| 603 |
+
pad_state: Optional[MaskState] = None,
|
| 604 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 605 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 606 |
+
):
|
| 607 |
+
embeddings = self._embed(input_ids, expanded_input_ids)
|
| 608 |
+
|
| 609 |
+
# pass through encoder layers
|
| 610 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 611 |
+
assert pad_state is not None
|
| 612 |
+
|
| 613 |
+
# step those batch positions which are not currently idle (i.e. at a boundary position)
|
| 614 |
+
# if all batch positions are idle, skip the step entirely
|
| 615 |
+
# all positions being idle only happens if fuse_boundaries=False. In this case, the step where we
|
| 616 |
+
# obtain a new representation from the global model will have all positions for the local encoder being idle.
|
| 617 |
+
if not pad_state.all():
|
| 618 |
+
h = pad_state.selective_get(embeddings, inv=True)
|
| 619 |
+
|
| 620 |
+
for i, block in enumerate(self.layers):
|
| 621 |
+
h = block(h, past_key_values=self.layer_states[i], use_cache=True, cache_mask=pad_state)
|
| 622 |
+
|
| 623 |
+
if self.post_last_block_norm is not None:
|
| 624 |
+
h = self.post_last_block_norm(h)
|
| 625 |
+
|
| 626 |
+
pad_state.selective_put(h[:, -1, :], self.last_h, inv=True)
|
| 627 |
+
|
| 628 |
+
h = self.last_h.unsqueeze(1)
|
| 629 |
+
else:
|
| 630 |
+
h = embeddings
|
| 631 |
+
for i, block in enumerate(self.layers):
|
| 632 |
+
if self.has_cache:
|
| 633 |
+
use_cache = True
|
| 634 |
+
past_key_values = self.layer_states[i]
|
| 635 |
+
else:
|
| 636 |
+
use_cache = False
|
| 637 |
+
past_key_values = None
|
| 638 |
+
|
| 639 |
+
h = block(h, past_key_values=past_key_values, use_cache=use_cache, sequence_start_indices=sequence_start_indices)
|
| 640 |
+
|
| 641 |
+
if self.post_last_block_norm is not None:
|
| 642 |
+
h = self.post_last_block_norm(h)
|
| 643 |
+
|
| 644 |
+
if self.has_cache:
|
| 645 |
+
self.last_h.copy_(h[:, -1, :])
|
| 646 |
+
|
| 647 |
+
if not self.has_cache or self.cache_seqlens == 0: # only used for prefill
|
| 648 |
+
boundary_logprobs, boundary_mask = self.boundary_predictor_module(
|
| 649 |
+
h,
|
| 650 |
+
sequence_start_indices=sequence_start_indices,
|
| 651 |
+
)
|
| 652 |
+
if boundary_state is not None:
|
| 653 |
+
# can't predict through encoder - must be through prev local decoder step
|
| 654 |
+
boundary_mask[:, -1] = boundary_state.mask
|
| 655 |
+
else:
|
| 656 |
+
boundary_logprobs = boundary_mask = None
|
| 657 |
+
|
| 658 |
+
# overwrite with true boundaries
|
| 659 |
+
if true_boundary_mask is not None:
|
| 660 |
+
boundary_mask = true_boundary_mask
|
| 661 |
+
|
| 662 |
+
patch_embeddings = self._pool(
|
| 663 |
+
h=h,
|
| 664 |
+
boundary_mask=boundary_mask,
|
| 665 |
+
n_patches=int(cast(torch.Tensor, boundary_mask).sum(-1).max().item()) if boundary_mask is not None else 1,
|
| 666 |
+
boundary_state=boundary_state,
|
| 667 |
+
)
|
| 668 |
+
patch_embeddings = self.out_projection(patch_embeddings)
|
| 669 |
+
|
| 670 |
+
if self.has_cache:
|
| 671 |
+
self.cache_seqlens += input_ids.shape[1]
|
| 672 |
+
|
| 673 |
+
return h, patch_embeddings, boundary_logprobs, boundary_mask
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class BolmoLocalDecoder(nn.Module):
|
| 677 |
+
def __init__(self, config: BolmoConfig):
|
| 678 |
+
super().__init__()
|
| 679 |
+
self.config = config
|
| 680 |
+
self.hidden_size = config.hidden_size
|
| 681 |
+
|
| 682 |
+
self.initial_norm = BolmoRMSNorm(
|
| 683 |
+
self.hidden_size,
|
| 684 |
+
eps=config.local_rms_norm_eps,
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
self.in_projection = nn.Linear(
|
| 688 |
+
self.hidden_size,
|
| 689 |
+
self.hidden_size,
|
| 690 |
+
bias=True,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
self.layers = nn.ModuleList(
|
| 694 |
+
[BolmoLocalLayer(config) for _ in range(config.num_local_decoder_layers)]
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
self.has_cache = False
|
| 698 |
+
|
| 699 |
+
def prepare_inference_cache(self, batch_size: int):
|
| 700 |
+
device = next(self.parameters()).device
|
| 701 |
+
self.has_cache = True
|
| 702 |
+
|
| 703 |
+
self.cache_seqlens = 0
|
| 704 |
+
self.last_value = torch.zeros((batch_size, self.hidden_size), dtype=self.in_projection.weight.dtype, device=device)
|
| 705 |
+
self.layer_states = [{"xlstm": {}} for _ in range(len(self.layers))]
|
| 706 |
+
|
| 707 |
+
def free_inference_cache(self):
|
| 708 |
+
self.has_cache = False
|
| 709 |
+
if hasattr(self, "cache_seqlens"):
|
| 710 |
+
del self.cache_seqlens
|
| 711 |
+
if hasattr(self, "last_value"):
|
| 712 |
+
del self.last_value
|
| 713 |
+
if hasattr(self, "layer_states"):
|
| 714 |
+
del self.layer_states
|
| 715 |
+
|
| 716 |
+
def _depool(
|
| 717 |
+
self,
|
| 718 |
+
embeds: torch.Tensor,
|
| 719 |
+
patch_embeds: torch.Tensor,
|
| 720 |
+
boundary_mask: Optional[torch.Tensor],
|
| 721 |
+
boundary_state: Optional[MaskState] = None,
|
| 722 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 723 |
+
) -> torch.Tensor:
|
| 724 |
+
if self.has_cache and self.cache_seqlens > 0:
|
| 725 |
+
assert boundary_state is not None
|
| 726 |
+
|
| 727 |
+
if patch_embeds.numel() > 0:
|
| 728 |
+
# we got a new value from the global model, so must be at boundary position
|
| 729 |
+
h_patch = patch_embeds[:, -1:, :]
|
| 730 |
+
h = embeds + h_patch
|
| 731 |
+
|
| 732 |
+
self.last_value.copy_(h_patch[:, -1])
|
| 733 |
+
else:
|
| 734 |
+
h = embeds + self.last_value.unsqueeze(1)
|
| 735 |
+
|
| 736 |
+
# skip pad positions until we get a new value from the global model
|
| 737 |
+
if patch_embeds.numel() == 0:
|
| 738 |
+
h = boundary_state.selective_get(h, inv=True)
|
| 739 |
+
else:
|
| 740 |
+
boundary_state = None
|
| 741 |
+
|
| 742 |
+
if h.shape[0] > 0:
|
| 743 |
+
for i, layer in enumerate(self.layers):
|
| 744 |
+
h = layer(h, past_key_values=self.layer_states[i], use_cache=True, cache_mask=boundary_state)
|
| 745 |
+
|
| 746 |
+
self.cache_seqlens += h.shape[1]
|
| 747 |
+
|
| 748 |
+
return h
|
| 749 |
+
else:
|
| 750 |
+
assert boundary_mask is not None
|
| 751 |
+
|
| 752 |
+
h_patch = patch_embeds
|
| 753 |
+
prepool_out = h_patch
|
| 754 |
+
|
| 755 |
+
# TODO(benjaminm): clipping is problematic if it happens too much; track clip %.
|
| 756 |
+
plug_back_idx = (torch.cumsum(boundary_mask, dim=1) - 1).clip(min=0, max=prepool_out.shape[1] - 1)
|
| 757 |
+
depool_out = torch.gather(
|
| 758 |
+
prepool_out,
|
| 759 |
+
dim=1,
|
| 760 |
+
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.hidden_size),
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
depool_out_modulated = depool_out
|
| 764 |
+
h = depool_out_modulated + embeds
|
| 765 |
+
|
| 766 |
+
for i, layer in enumerate(self.layers):
|
| 767 |
+
if self.has_cache:
|
| 768 |
+
use_cache = True
|
| 769 |
+
past_key_values = self.layer_states[i]
|
| 770 |
+
else:
|
| 771 |
+
use_cache = False
|
| 772 |
+
past_key_values = None
|
| 773 |
+
|
| 774 |
+
h = layer(h, past_key_values=past_key_values, use_cache=use_cache, sequence_start_indices=sequence_start_indices)
|
| 775 |
+
|
| 776 |
+
if self.has_cache:
|
| 777 |
+
self.last_value.copy_(prepool_out[:, -1])
|
| 778 |
+
self.cache_seqlens += h.shape[1]
|
| 779 |
+
|
| 780 |
+
return h
|
| 781 |
+
|
| 782 |
+
def forward(
|
| 783 |
+
self,
|
| 784 |
+
embeds: torch.Tensor,
|
| 785 |
+
patch_embeds: torch.Tensor,
|
| 786 |
+
boundary_state: Optional[MaskState],
|
| 787 |
+
boundary_mask: torch.Tensor | None,
|
| 788 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 789 |
+
) -> torch.Tensor:
|
| 790 |
+
h = self.in_projection(embeds)
|
| 791 |
+
h_patch = self.initial_norm(patch_embeds)
|
| 792 |
+
|
| 793 |
+
return self._depool(
|
| 794 |
+
embeds=h,
|
| 795 |
+
patch_embeds=h_patch,
|
| 796 |
+
boundary_mask=boundary_mask,
|
| 797 |
+
boundary_state=boundary_state,
|
| 798 |
+
sequence_start_indices=sequence_start_indices,
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
class BolmoRotaryEmbedding(nn.Module):
|
| 803 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 804 |
+
|
| 805 |
+
def __init__(self, config: BolmoConfig, device=None, rope_type: Optional[str] = None):
|
| 806 |
+
super().__init__()
|
| 807 |
+
if rope_type is not None:
|
| 808 |
+
self.rope_type = rope_type
|
| 809 |
+
elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 810 |
+
# BC: "rope_type" was originally "type"
|
| 811 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 812 |
+
else:
|
| 813 |
+
self.rope_type = "default"
|
| 814 |
+
assert self.rope_type is not None
|
| 815 |
+
|
| 816 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 817 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 818 |
+
|
| 819 |
+
self.config = config
|
| 820 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 821 |
+
|
| 822 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 823 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 824 |
+
self.original_inv_freq = self.inv_freq
|
| 825 |
+
|
| 826 |
+
@torch.no_grad()
|
| 827 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 828 |
+
def forward(self, x, position_ids):
|
| 829 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 830 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 831 |
+
|
| 832 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 833 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 834 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 835 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 836 |
+
cos = emb.cos() * self.attention_scaling
|
| 837 |
+
sin = emb.sin() * self.attention_scaling
|
| 838 |
+
return cos, sin
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class BolmoPreTrainedModel(PreTrainedModel):
|
| 842 |
+
config: BolmoConfig
|
| 843 |
+
base_model_prefix = "model"
|
| 844 |
+
supports_gradient_checkpointing = True
|
| 845 |
+
_no_split_modules = ["BolmoDecoderLayer"]
|
| 846 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 847 |
+
_supports_flash_attn = True
|
| 848 |
+
_supports_sdpa = True
|
| 849 |
+
_supports_flex_attn = True
|
| 850 |
+
|
| 851 |
+
_can_compile_fullgraph = True
|
| 852 |
+
_supports_attention_backend = True
|
| 853 |
+
_can_record_outputs = {
|
| 854 |
+
"hidden_states": BolmoDecoderLayer,
|
| 855 |
+
"attentions": BolmoAttention,
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class BolmoModel(BolmoPreTrainedModel):
|
| 860 |
+
def __init__(self, config: BolmoConfig):
|
| 861 |
+
super().__init__(config)
|
| 862 |
+
self.padding_idx = config.pad_token_id
|
| 863 |
+
self.vocab_size = config.vocab_size
|
| 864 |
+
|
| 865 |
+
self.local_encoder = BolmoLocalEncoder(config)
|
| 866 |
+
self.local_decoder = BolmoLocalDecoder(config)
|
| 867 |
+
|
| 868 |
+
self.layers = nn.ModuleList(
|
| 869 |
+
[BolmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 870 |
+
)
|
| 871 |
+
self.norm = BolmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 872 |
+
self.gradient_checkpointing = False
|
| 873 |
+
self.rotary_embs = nn.ModuleDict(
|
| 874 |
+
{
|
| 875 |
+
"sliding_attention": BolmoRotaryEmbedding(config=config, rope_type="default"),
|
| 876 |
+
"full_attention": BolmoRotaryEmbedding(config=config),
|
| 877 |
+
}
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
self.tokenizer_config = BolmoTokenizerConfig(**config.tokenizer_config)
|
| 881 |
+
self._tokenizer = None
|
| 882 |
+
|
| 883 |
+
# Initialize weights and apply final processing
|
| 884 |
+
self.post_init()
|
| 885 |
+
|
| 886 |
+
def get_input_embeddings(self):
|
| 887 |
+
return self.local_encoder.byte_embedding
|
| 888 |
+
|
| 889 |
+
def set_input_embeddings(self, value: nn.Embedding): # type: ignore
|
| 890 |
+
self.local_encoder.byte_embedding = value
|
| 891 |
+
|
| 892 |
+
@property
|
| 893 |
+
def tokenizer(self):
|
| 894 |
+
if self._tokenizer is None:
|
| 895 |
+
self._tokenizer = self.tokenizer_config.build()
|
| 896 |
+
|
| 897 |
+
return self._tokenizer
|
| 898 |
+
|
| 899 |
+
def prefill_boundary_prediction_forward(
|
| 900 |
+
self,
|
| 901 |
+
input_ids: torch.Tensor,
|
| 902 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 903 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 904 |
+
last_token_is_boundary: bool = False,
|
| 905 |
+
**kwargs,
|
| 906 |
+
) -> torch.Tensor:
|
| 907 |
+
_, _, _, boundary_mask = self.local_encoder.forward( # type: ignore
|
| 908 |
+
input_ids,
|
| 909 |
+
expanded_input_ids=expanded_input_ids,
|
| 910 |
+
boundary_state=MaskState(torch.full((input_ids.shape[0],), fill_value=last_token_is_boundary, device=input_ids.device, dtype=torch.bool)),
|
| 911 |
+
pad_state=MaskState(torch.zeros((input_ids.shape[0],), device=input_ids.device, dtype=torch.bool)),
|
| 912 |
+
sequence_start_indices=sequence_start_indices,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
return cast(torch.Tensor, boundary_mask)
|
| 916 |
+
|
| 917 |
+
@check_model_inputs()
|
| 918 |
+
def forward(
|
| 919 |
+
self,
|
| 920 |
+
input_ids: torch.Tensor,
|
| 921 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 922 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 923 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 924 |
+
past_key_values: Optional[Cache] = None,
|
| 925 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 926 |
+
use_cache: Optional[bool] = None,
|
| 927 |
+
boundary_mask: Optional[torch.Tensor] = None,
|
| 928 |
+
boundary_state: Optional[MaskState] = None,
|
| 929 |
+
pad_state: Optional[MaskState] = None,
|
| 930 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 931 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 932 |
+
) -> BaseModelOutputWithPast:
|
| 933 |
+
batch_size = input_ids.shape[0]
|
| 934 |
+
device = input_ids.device
|
| 935 |
+
|
| 936 |
+
if self.local_encoder.add_expanded_embeddings and expanded_input_ids is None and input_ids is not None:
|
| 937 |
+
# not optimized
|
| 938 |
+
expanded_input_ids_list: list[torch.Tensor] = []
|
| 939 |
+
for example_idx in range(batch_size):
|
| 940 |
+
expanded_input_ids_list.append(torch.tensor(self.tokenizer.expand_byte_ids(input_ids[example_idx].tolist()), dtype=torch.long, device=device))
|
| 941 |
+
expanded_input_ids = pad_right(expanded_input_ids_list, value=self.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 942 |
+
|
| 943 |
+
h_byte, h_patch, _, boundary_mask = self.local_encoder(
|
| 944 |
+
input_ids=input_ids,
|
| 945 |
+
expanded_input_ids=expanded_input_ids,
|
| 946 |
+
true_boundary_mask=boundary_mask,
|
| 947 |
+
boundary_state=boundary_state,
|
| 948 |
+
pad_state=pad_state,
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
if use_cache and past_key_values is None:
|
| 952 |
+
past_key_values = DynamicCache(config=self.config)
|
| 953 |
+
|
| 954 |
+
if cache_position is None:
|
| 955 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 956 |
+
cache_position: torch.Tensor = torch.arange(
|
| 957 |
+
past_seen_tokens, past_seen_tokens + h_patch.shape[1], device=device
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
if position_ids is None:
|
| 961 |
+
position_ids = cache_position.unsqueeze(0) # type: ignore
|
| 962 |
+
|
| 963 |
+
# It may already have been prepared by e.g. `generate`
|
| 964 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 965 |
+
# Prepare mask arguments
|
| 966 |
+
mask_kwargs = {
|
| 967 |
+
"config": self.config,
|
| 968 |
+
"input_embeds": h_patch,
|
| 969 |
+
"attention_mask": attention_mask,
|
| 970 |
+
"cache_position": cache_position,
|
| 971 |
+
"past_key_values": past_key_values,
|
| 972 |
+
"position_ids": position_ids,
|
| 973 |
+
}
|
| 974 |
+
# Create the masks
|
| 975 |
+
causal_mask_mapping = {
|
| 976 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 977 |
+
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
+
position_embeddings_mapping = {
|
| 981 |
+
"sliding_attention": self.rotary_embs["sliding_attention"](h_byte, position_ids),
|
| 982 |
+
"full_attention": self.rotary_embs["full_attention"](h_byte, position_ids),
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
if h_patch.numel() > 0:
|
| 986 |
+
# we need to convert from right-pad to left-pad and back for prefill
|
| 987 |
+
# since flash attention expects left-pad and local/enc dec expect right-pad global tokens
|
| 988 |
+
# should add better left-pad support but this only affects prefill so OK for now
|
| 989 |
+
# although super inefficient!
|
| 990 |
+
if boundary_mask is not None: # prefill
|
| 991 |
+
n_boundaries = boundary_mask.sum(-1)
|
| 992 |
+
|
| 993 |
+
for i, current_n_boundaries in enumerate(n_boundaries):
|
| 994 |
+
h_patch[i, -current_n_boundaries:] = h_patch[i, :current_n_boundaries].clone()
|
| 995 |
+
|
| 996 |
+
h_patch_after_global = h_patch
|
| 997 |
+
|
| 998 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 999 |
+
h_patch_after_global = decoder_layer(
|
| 1000 |
+
h_patch_after_global,
|
| 1001 |
+
attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
|
| 1002 |
+
position_ids=position_ids,
|
| 1003 |
+
past_key_values=past_key_values,
|
| 1004 |
+
cache_position=cache_position,
|
| 1005 |
+
position_embeddings=position_embeddings_mapping[decoder_layer.self_attn.attention_type],
|
| 1006 |
+
**kwargs,
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
if boundary_mask is not None: # prefill
|
| 1010 |
+
n_boundaries = boundary_mask.sum(-1)
|
| 1011 |
+
|
| 1012 |
+
for i, current_n_boundaries in enumerate(n_boundaries):
|
| 1013 |
+
h_patch_after_global[i, :current_n_boundaries] = h_patch_after_global[i, -current_n_boundaries:].clone()
|
| 1014 |
+
else:
|
| 1015 |
+
h_patch_after_global = h_patch
|
| 1016 |
+
|
| 1017 |
+
h_out = self.local_decoder.forward( # type: ignore
|
| 1018 |
+
embeds=h_byte,
|
| 1019 |
+
patch_embeds=h_patch_after_global,
|
| 1020 |
+
boundary_mask=boundary_mask,
|
| 1021 |
+
boundary_state=boundary_state,
|
| 1022 |
+
sequence_start_indices=sequence_start_indices,
|
| 1023 |
+
)
|
| 1024 |
+
h_out = self.norm(h_out)
|
| 1025 |
+
|
| 1026 |
+
return BaseModelOutputWithPast(
|
| 1027 |
+
last_hidden_state=h_out,
|
| 1028 |
+
past_key_values=past_key_values,
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
|
| 1033 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1034 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1035 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1036 |
+
|
| 1037 |
+
def __init__(self, config):
|
| 1038 |
+
super().__init__(config)
|
| 1039 |
+
self.model = BolmoModel(config)
|
| 1040 |
+
self.vocab_size = config.vocab_size
|
| 1041 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1042 |
+
|
| 1043 |
+
# Initialize weights and apply final processing
|
| 1044 |
+
self.post_init()
|
| 1045 |
+
|
| 1046 |
+
def get_output_embeddings(self):
|
| 1047 |
+
return self.lm_head
|
| 1048 |
+
|
| 1049 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
| 1050 |
+
self.lm_head = new_embeddings
|
| 1051 |
+
|
| 1052 |
+
@can_return_tuple
|
| 1053 |
+
def forward(
|
| 1054 |
+
self,
|
| 1055 |
+
input_ids: torch.Tensor,
|
| 1056 |
+
expanded_input_ids: Optional[torch.Tensor] = None,
|
| 1057 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1058 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1059 |
+
past_key_values: Optional[Cache] = None,
|
| 1060 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1061 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 1062 |
+
use_cache: Optional[bool] = None,
|
| 1063 |
+
boundary_mask: Optional[torch.Tensor] = None,
|
| 1064 |
+
boundary_state: Optional[MaskState] = None,
|
| 1065 |
+
pad_state: Optional[MaskState] = None,
|
| 1066 |
+
sequence_start_indices: Optional[torch.Tensor] = None,
|
| 1067 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1068 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1069 |
+
) -> CausalLMOutputWithPast:
|
| 1070 |
+
r"""
|
| 1071 |
+
Example:
|
| 1072 |
+
|
| 1073 |
+
```python
|
| 1074 |
+
>>> from transformers import AutoTokenizer, BolmoForCausalLM
|
| 1075 |
+
|
| 1076 |
+
>>> model = BolmoForCausalLM.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
|
| 1077 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo3/Bolmo-2-7b-hf")
|
| 1078 |
+
|
| 1079 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1080 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1081 |
+
|
| 1082 |
+
>>> # Generate
|
| 1083 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1084 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1085 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1086 |
+
```"""
|
| 1087 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 1088 |
+
input_ids=input_ids,
|
| 1089 |
+
expanded_input_ids=expanded_input_ids,
|
| 1090 |
+
attention_mask=attention_mask,
|
| 1091 |
+
position_ids=position_ids,
|
| 1092 |
+
past_key_values=past_key_values,
|
| 1093 |
+
inputs_embeds=inputs_embeds,
|
| 1094 |
+
cache_position=cache_position,
|
| 1095 |
+
use_cache=use_cache,
|
| 1096 |
+
boundary_mask=boundary_mask,
|
| 1097 |
+
boundary_state=boundary_state,
|
| 1098 |
+
pad_state=pad_state,
|
| 1099 |
+
sequence_start_indices=sequence_start_indices,
|
| 1100 |
+
**kwargs,
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
hidden_states = cast(torch.Tensor, outputs.last_hidden_state)
|
| 1104 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1105 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1106 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1107 |
+
|
| 1108 |
+
return CausalLMOutputWithPast(
|
| 1109 |
+
logits=logits,
|
| 1110 |
+
past_key_values=outputs.past_key_values,
|
| 1111 |
+
hidden_states=outputs.hidden_states,
|
| 1112 |
+
attentions=outputs.attentions,
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
@torch.no_grad()
|
| 1116 |
+
def generate( # type: ignore
|
| 1117 |
+
self,
|
| 1118 |
+
inputs: torch.Tensor,
|
| 1119 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1120 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 1121 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 1122 |
+
use_model_defaults: Optional[bool] = None,
|
| 1123 |
+
**kwargs,
|
| 1124 |
+
) -> Union[GenerateOutput, torch.Tensor]:
|
| 1125 |
+
# generic preprocessing
|
| 1126 |
+
|
| 1127 |
+
generation_config, model_kwargs = self._prepare_generation_config(
|
| 1128 |
+
generation_config, use_model_defaults, **kwargs
|
| 1129 |
+
)
|
| 1130 |
+
self._prepare_special_tokens(generation_config, device=self.model.device)
|
| 1131 |
+
|
| 1132 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
| 1133 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
| 1134 |
+
|
| 1135 |
+
# start of custom generate
|
| 1136 |
+
|
| 1137 |
+
expand_input_ids = self.model.local_encoder.add_expanded_embeddings
|
| 1138 |
+
batch_size = len(inputs)
|
| 1139 |
+
|
| 1140 |
+
if expand_input_ids:
|
| 1141 |
+
expanded_input_ids = []
|
| 1142 |
+
|
| 1143 |
+
for i in range(len(inputs)):
|
| 1144 |
+
expanded_input_ids.append(torch.tensor(self.model.tokenizer.expand_byte_ids(inputs[i].tolist()), device=self.device, dtype=torch.long))
|
| 1145 |
+
|
| 1146 |
+
expanded_input_ids = pad_left(expanded_input_ids, value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 1147 |
+
else:
|
| 1148 |
+
expanded_input_ids = None
|
| 1149 |
+
|
| 1150 |
+
byte_input_ids = inputs
|
| 1151 |
+
sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
|
| 1152 |
+
batch_size, prompt_len = byte_input_ids.shape
|
| 1153 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
|
| 1154 |
+
|
| 1155 |
+
boundary_offset = self.model.tokenizer.offset + 256
|
| 1156 |
+
eos = self.model.tokenizer.eos_token_id
|
| 1157 |
+
|
| 1158 |
+
self.model.local_encoder.free_inference_cache()
|
| 1159 |
+
self.model.local_decoder.free_inference_cache()
|
| 1160 |
+
|
| 1161 |
+
boundary_mask = self.model.prefill_boundary_prediction_forward( # type: ignore
|
| 1162 |
+
byte_input_ids,
|
| 1163 |
+
expanded_input_ids=expanded_input_ids,
|
| 1164 |
+
sequence_start_indices=sequence_start_indices,
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
self.model.local_encoder.prepare_inference_cache(batch_size)
|
| 1168 |
+
self.model.local_decoder.prepare_inference_cache(batch_size)
|
| 1169 |
+
|
| 1170 |
+
# roll back by one and force decoding to account for lookahead
|
| 1171 |
+
boundary_mask = boundary_mask[:, :-1]
|
| 1172 |
+
# need to roll one byte back and force decoding to detect whether the last byte is a boundary
|
| 1173 |
+
forced_decoding_ids = byte_input_ids[:, -1].cpu().tolist()
|
| 1174 |
+
byte_input_ids = byte_input_ids[:, :-1]
|
| 1175 |
+
expanded_input_ids = expanded_input_ids[:, :-1] if expanded_input_ids is not None else None
|
| 1176 |
+
# stays the same unless last token is pad.
|
| 1177 |
+
sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
|
| 1178 |
+
|
| 1179 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 1180 |
+
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
| 1181 |
+
generation_config = self._prepare_generated_length(
|
| 1182 |
+
generation_config=generation_config,
|
| 1183 |
+
has_default_max_length=has_default_max_length,
|
| 1184 |
+
has_default_min_length=has_default_min_length,
|
| 1185 |
+
model_input_name="input_ids",
|
| 1186 |
+
inputs_tensor=byte_input_ids,
|
| 1187 |
+
input_ids_length=byte_input_ids.shape[1],
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
logits_processor = self._get_logits_processor(
|
| 1191 |
+
generation_config=generation_config, # type: ignore
|
| 1192 |
+
input_ids_seq_length=byte_input_ids.shape[1],
|
| 1193 |
+
encoder_input_ids=byte_input_ids, # type: ignore
|
| 1194 |
+
logits_processor=logits_processor,
|
| 1195 |
+
device=byte_input_ids.device, # type: ignore
|
| 1196 |
+
model_kwargs=model_kwargs,
|
| 1197 |
+
)
|
| 1198 |
+
stopping_criteria = self._get_stopping_criteria(
|
| 1199 |
+
generation_config=generation_config, # type: ignore
|
| 1200 |
+
stopping_criteria=stopping_criteria,
|
| 1201 |
+
tokenizer=self.model.tokenizer,
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
# output container
|
| 1205 |
+
generated = byte_input_ids
|
| 1206 |
+
|
| 1207 |
+
max_n_prefill_patches = boundary_mask.sum(-1).max().item()
|
| 1208 |
+
tokens_generated_plus_prefilled = max_n_prefill_patches
|
| 1209 |
+
bytes_generated = 0
|
| 1210 |
+
|
| 1211 |
+
# generation state
|
| 1212 |
+
boundary_state = MaskState(boundary_mask[:, -1].clone())
|
| 1213 |
+
pad_state = MaskState(torch.zeros(batch_size, dtype=torch.bool, device=self.device))
|
| 1214 |
+
next_tokens = torch.full((batch_size,), self.model.tokenizer.bpe_token_end_id, device=self.device, dtype=torch.long) # type: ignore
|
| 1215 |
+
non_boundary_generated_tokens = [[byte_input_ids[example_idx, -1].item()] for example_idx in range(batch_size)]
|
| 1216 |
+
bytes_since_boundary = (boundary_mask.flip(1).cumsum(-1) == 0).sum(-1)
|
| 1217 |
+
is_first_forward = True
|
| 1218 |
+
global_past_key_values = None
|
| 1219 |
+
|
| 1220 |
+
while not finished.all():
|
| 1221 |
+
input_ids_for_model = (
|
| 1222 |
+
generated
|
| 1223 |
+
if is_first_forward
|
| 1224 |
+
else torch.tensor([x[-1] for x in non_boundary_generated_tokens], device=generated.device, dtype=generated.dtype).unsqueeze(1)
|
| 1225 |
+
)
|
| 1226 |
+
assert not (
|
| 1227 |
+
(input_ids_for_model == self.model.tokenizer.bpe_token_end_id) |
|
| 1228 |
+
(input_ids_for_model >= boundary_offset)
|
| 1229 |
+
).any().item() # type: ignore
|
| 1230 |
+
if expand_input_ids:
|
| 1231 |
+
expanded_input_ids_for_model = torch.zeros_like(input_ids_for_model)
|
| 1232 |
+
for i in range(input_ids_for_model.shape[0]):
|
| 1233 |
+
expanded_input_ids_for_model[i, :] = torch.tensor(self.model.tokenizer.expand_byte_ids(
|
| 1234 |
+
generated[i, :].tolist(),
|
| 1235 |
+
n_last=input_ids_for_model.shape[1],
|
| 1236 |
+
), device=expanded_input_ids_for_model.device, dtype=expanded_input_ids_for_model.dtype)
|
| 1237 |
+
else:
|
| 1238 |
+
expanded_input_ids_for_model = None
|
| 1239 |
+
|
| 1240 |
+
out = self.forward( # type: ignore
|
| 1241 |
+
input_ids_for_model,
|
| 1242 |
+
expanded_input_ids=expanded_input_ids_for_model,
|
| 1243 |
+
boundary_mask=boundary_mask if is_first_forward else None,
|
| 1244 |
+
boundary_state=boundary_state,
|
| 1245 |
+
pad_state=pad_state,
|
| 1246 |
+
sequence_start_indices=sequence_start_indices,
|
| 1247 |
+
logits_to_keep=1,
|
| 1248 |
+
use_cache=True,
|
| 1249 |
+
past_key_values=global_past_key_values,
|
| 1250 |
+
)
|
| 1251 |
+
next_token_logits = cast(torch.Tensor, out.logits)
|
| 1252 |
+
global_past_key_values = out.past_key_values
|
| 1253 |
+
|
| 1254 |
+
if boundary_state.all():
|
| 1255 |
+
# new token, must not be boundary
|
| 1256 |
+
bytes_since_boundary[:] = 0
|
| 1257 |
+
else:
|
| 1258 |
+
boundary_state.selective_add(1, bytes_since_boundary, inv=True)
|
| 1259 |
+
|
| 1260 |
+
if any(x is not None for x in forced_decoding_ids):
|
| 1261 |
+
# only supported for the first token atm, so len(next_token_logits) == batch_size
|
| 1262 |
+
assert len(next_token_logits) == batch_size and is_first_forward
|
| 1263 |
+
for example_idx in range(batch_size):
|
| 1264 |
+
forced_decoding_id = forced_decoding_ids[example_idx]
|
| 1265 |
+
|
| 1266 |
+
if forced_decoding_id is not None:
|
| 1267 |
+
no_boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id].item()
|
| 1268 |
+
boundary_logit = next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset].item()
|
| 1269 |
+
|
| 1270 |
+
next_token_logits[example_idx, 0, :] = -100_000
|
| 1271 |
+
next_token_logits[example_idx, 0, forced_decoding_id] = no_boundary_logit
|
| 1272 |
+
next_token_logits[example_idx, 0, forced_decoding_id + boundary_offset] = boundary_logit
|
| 1273 |
+
|
| 1274 |
+
forced_decoding_ids[example_idx] = None # only force once
|
| 1275 |
+
|
| 1276 |
+
# passing input_ids to logit processor not implemented
|
| 1277 |
+
next_token_scores = logits_processor(None, next_token_logits[:, -1]) # type: ignore
|
| 1278 |
+
|
| 1279 |
+
if generation_config is not None and generation_config.do_sample:
|
| 1280 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 1281 |
+
new_next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1282 |
+
else:
|
| 1283 |
+
new_next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 1284 |
+
|
| 1285 |
+
if boundary_state.all() or is_first_forward:
|
| 1286 |
+
tokens_generated_plus_prefilled += 1
|
| 1287 |
+
|
| 1288 |
+
next_tokens = new_next_tokens
|
| 1289 |
+
next_tokens_cpu = next_tokens.cpu()
|
| 1290 |
+
for example_idx in range(batch_size):
|
| 1291 |
+
if finished[example_idx].item():
|
| 1292 |
+
continue
|
| 1293 |
+
|
| 1294 |
+
next_token_cpu = next_tokens_cpu[example_idx].item()
|
| 1295 |
+
|
| 1296 |
+
if next_token_cpu >= boundary_offset:
|
| 1297 |
+
next_token_cpu -= boundary_offset
|
| 1298 |
+
|
| 1299 |
+
non_boundary_generated_tokens[example_idx].append(next_token_cpu)
|
| 1300 |
+
else:
|
| 1301 |
+
next_tokens[:] = self.model.tokenizer.bpe_token_end_id # type: ignore
|
| 1302 |
+
boundary_state.selective_put(new_next_tokens, next_tokens, inv=True)
|
| 1303 |
+
next_tokens_cpu = next_tokens.cpu()
|
| 1304 |
+
|
| 1305 |
+
for example_idx in range(batch_size):
|
| 1306 |
+
if finished[example_idx].item():
|
| 1307 |
+
continue
|
| 1308 |
+
|
| 1309 |
+
next_token_cpu = next_tokens_cpu[example_idx].item()
|
| 1310 |
+
|
| 1311 |
+
if not boundary_state.cpu_mask[example_idx].item():
|
| 1312 |
+
if next_token_cpu >= boundary_offset:
|
| 1313 |
+
next_token_cpu -= boundary_offset
|
| 1314 |
+
|
| 1315 |
+
non_boundary_generated_tokens[example_idx].append(next_token_cpu)
|
| 1316 |
+
|
| 1317 |
+
is_first_forward = False
|
| 1318 |
+
|
| 1319 |
+
boundary_state = MaskState(
|
| 1320 |
+
(next_tokens == self.model.tokenizer.bpe_token_end_id) |
|
| 1321 |
+
(next_tokens >= boundary_offset) |
|
| 1322 |
+
finished
|
| 1323 |
+
) # type: ignore
|
| 1324 |
+
pad_state = MaskState(
|
| 1325 |
+
(next_tokens == self.model.tokenizer.bpe_token_end_id) |
|
| 1326 |
+
finished
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
# Force EOS for (previously) finished sequences
|
| 1330 |
+
next_tokens = torch.where(finished, torch.full_like(next_tokens, eos), next_tokens)
|
| 1331 |
+
|
| 1332 |
+
# Append next tokens
|
| 1333 |
+
generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)
|
| 1334 |
+
|
| 1335 |
+
# Handle finished sequences
|
| 1336 |
+
stop_hit = next_tokens.eq(eos) | next_tokens.eq(eos + boundary_offset)
|
| 1337 |
+
|
| 1338 |
+
for i in range(batch_size):
|
| 1339 |
+
# passing `scores` to stopping criteria not implemented
|
| 1340 |
+
if stopping_criteria(torch.tensor(non_boundary_generated_tokens[i], dtype=torch.long).unsqueeze(0), None).squeeze(0).item(): # type: ignore
|
| 1341 |
+
stop_hit[i] = True
|
| 1342 |
+
|
| 1343 |
+
finished |= stop_hit
|
| 1344 |
+
bytes_generated += 1
|
| 1345 |
+
|
| 1346 |
+
return pad_left([
|
| 1347 |
+
torch.cat([byte_input_ids[i, :-1], torch.tensor(x, dtype=torch.long, device=byte_input_ids.device)])
|
| 1348 |
+
for i, x in enumerate(non_boundary_generated_tokens)
|
| 1349 |
+
], value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
|
| 1350 |
+
|
| 1351 |
+
__all__ = ["BolmoForCausalLM", "BolmoModel", "BolmoPreTrainedModel"]
|
recipe.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stage:
|
| 2 |
+
default_modifiers:
|
| 3 |
+
QuantizationModifier:
|
| 4 |
+
targets: [Linear]
|
| 5 |
+
ignore: [lm_head, 're:visual.*', 're:.*vision_tower.*', 're:.*video_tower.*', 're:.*audio_tower.*',
|
| 6 |
+
're:.*multi_modal_projector.*']
|
| 7 |
+
scheme: NVFP4
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<bos>",
|
| 3 |
+
"eos_token": "<bos>",
|
| 4 |
+
"pad_token": "<pad>"
|
| 5 |
+
}
|
tokenization_bolmo.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 6 |
+
|
| 7 |
+
# Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
|
| 8 |
+
# Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
|
| 9 |
+
_CHARS_TO_BYTES = {
|
| 10 |
+
"Ā": 0, "ā": 1, "Ă": 2, "ă": 3, "Ą": 4, "ą": 5, "Ć": 6, "ć": 7, "Ĉ": 8,
|
| 11 |
+
"ĉ": 9, "Ċ": 10, "ċ": 11, "Č": 12, "č": 13, "Ď": 14, "ď": 15, "Đ": 16,
|
| 12 |
+
"đ": 17, "Ē": 18, "ē": 19, "Ĕ": 20, "ĕ": 21, "Ė": 22, "ė": 23, "Ę": 24,
|
| 13 |
+
"ę": 25, "Ě": 26, "ě": 27, "Ĝ": 28, "ĝ": 29, "Ğ": 30, "ğ": 31, "Ġ": 32,
|
| 14 |
+
"!": 33, '"': 34, "#": 35, "$": 36, "%": 37, "&": 38, "'": 39, "(": 40,
|
| 15 |
+
")": 41, "*": 42, "+": 43, ",": 44, "-": 45, ".": 46, "/": 47, "0": 48,
|
| 16 |
+
"1": 49, "2": 50, "3": 51, "4": 52, "5": 53, "6": 54, "7": 55, "8": 56,
|
| 17 |
+
"9": 57, ":": 58, ";": 59, "<": 60, "=": 61, ">": 62, "?": 63, "@": 64,
|
| 18 |
+
"A": 65, "B": 66, "C": 67, "D": 68, "E": 69, "F": 70, "G": 71, "H": 72,
|
| 19 |
+
"I": 73, "J": 74, "K": 75, "L": 76, "M": 77, "N": 78, "O": 79, "P": 80,
|
| 20 |
+
"Q": 81, "R": 82, "S": 83, "T": 84, "U": 85, "V": 86, "W": 87, "X": 88,
|
| 21 |
+
"Y": 89, "Z": 90, "[": 91, "\\": 92, "]": 93, "^": 94, "_": 95, "`": 96,
|
| 22 |
+
"a": 97, "b": 98, "c": 99, "d": 100, "e": 101, "f": 102, "g": 103,
|
| 23 |
+
"h": 104, "i": 105, "j": 106, "k": 107, "l": 108, "m": 109, "n": 110,
|
| 24 |
+
"o": 111, "p": 112, "q": 113, "r": 114, "s": 115, "t": 116, "u": 117,
|
| 25 |
+
"v": 118, "w": 119, "x": 120, "y": 121, "z": 122, "{": 123, "|": 124,
|
| 26 |
+
"}": 125, "~": 126, "ġ": 127, "Ģ": 128, "ģ": 129, "Ĥ": 130, "ĥ": 131,
|
| 27 |
+
"Ħ": 132, "ħ": 133, "Ĩ": 134, "ĩ": 135, "Ī": 136, "ī": 137, "Ĭ": 138,
|
| 28 |
+
"ĭ": 139, "Į": 140, "į": 141, "İ": 142, "ı": 143, "IJ": 144, "ij": 145,
|
| 29 |
+
"Ĵ": 146, "ĵ": 147, "Ķ": 148, "ķ": 149, "ĸ": 150, "Ĺ": 151, "ĺ": 152,
|
| 30 |
+
"Ļ": 153, "ļ": 154, "Ľ": 155, "ľ": 156, "Ŀ": 157, "ŀ": 158, "Ł": 159,
|
| 31 |
+
"ł": 160, "¡": 161, "¢": 162, "£": 163, "¤": 164, "¥": 165, "¦": 166,
|
| 32 |
+
"§": 167, "¨": 168, "©": 169, "ª": 170, "«": 171, "¬": 172, "Ń": 173,
|
| 33 |
+
"®": 174, "¯": 175, "°": 176, "±": 177, "²": 178, "³": 179, "´": 180,
|
| 34 |
+
"µ": 181, "¶": 182, "·": 183, "¸": 184, "¹": 185, "º": 186, "»": 187,
|
| 35 |
+
"¼": 188, "½": 189, "¾": 190, "¿": 191, "À": 192, "Á": 193, "Â": 194,
|
| 36 |
+
"Ã": 195, "Ä": 196, "Å": 197, "Æ": 198, "Ç": 199, "È": 200, "É": 201,
|
| 37 |
+
"Ê": 202, "Ë": 203, "Ì": 204, "Í": 205, "Î": 206, "Ï": 207, "Ð": 208,
|
| 38 |
+
"Ñ": 209, "Ò": 210, "Ó": 211, "Ô": 212, "Õ": 213, "Ö": 214, "×": 215,
|
| 39 |
+
"Ø": 216, "Ù": 217, "Ú": 218, "Û": 219, "Ü": 220, "Ý": 221, "Þ": 222,
|
| 40 |
+
"ß": 223, "à": 224, "á": 225, "â": 226, "ã": 227, "ä": 228, "å": 229,
|
| 41 |
+
"æ": 230, "ç": 231, "è": 232, "é": 233, "ê": 234, "ë": 235, "ì": 236,
|
| 42 |
+
"í": 237, "î": 238, "ï": 239, "ð": 240, "ñ": 241, "ò": 242, "ó": 243,
|
| 43 |
+
"ô": 244, "õ": 245, "ö": 246, "÷": 247, "ø": 248, "ù": 249, "ú": 250,
|
| 44 |
+
"û": 251, "ü": 252, "ý": 253, "þ": 254, "ÿ": 255,
|
| 45 |
+
}
|
| 46 |
+
_BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}
|
| 47 |
+
|
| 48 |
+
def _bytes_to_chars(byte_sequence: bytes) -> str:
|
| 49 |
+
return "".join(_BYTES_TO_CHARS[byte] for byte in byte_sequence)
|
| 50 |
+
|
| 51 |
+
def _chars_to_bytes(char_sequence: str) -> list:
|
| 52 |
+
return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class BolmoTokenizerConfig:
|
| 56 |
+
vocab_size: int
|
| 57 |
+
bos_token_id: int
|
| 58 |
+
pad_token_id: int
|
| 59 |
+
eos_token_id: int
|
| 60 |
+
bpe_token_end_id: int
|
| 61 |
+
special_tokens: list[str] = field(default_factory=lambda: [])
|
| 62 |
+
special_tokens_first: bool = True
|
| 63 |
+
original_identifier: Optional[str] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@classmethod
|
| 67 |
+
def bolmo(cls) -> "BolmoTokenizerConfig":
|
| 68 |
+
special_tokens = [
|
| 69 |
+
"<pad>",
|
| 70 |
+
"<bos>",
|
| 71 |
+
"<eos>",
|
| 72 |
+
"<bpe_token_end>",
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
return cls(
|
| 76 |
+
# *2 to accomodate fused boundary tokens
|
| 77 |
+
vocab_size=(len(special_tokens) + 256) * 2,
|
| 78 |
+
special_tokens=special_tokens,
|
| 79 |
+
bos_token_id=special_tokens.index("<bos>"),
|
| 80 |
+
pad_token_id=special_tokens.index("<pad>"),
|
| 81 |
+
eos_token_id=special_tokens.index("<bos>"),
|
| 82 |
+
bpe_token_end_id=special_tokens.index("<bpe_token_end>"),
|
| 83 |
+
original_identifier="allenai/dolma2-tokenizer",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def build(self):
|
| 87 |
+
return BolmoTokenizer(tokenizer_config=self)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class BolmoTokenizer(PreTrainedTokenizer):
|
| 91 |
+
TOKEN_ID_KEY = -1
|
| 92 |
+
|
| 93 |
+
def __init__(self, **kwargs):
|
| 94 |
+
tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())
|
| 95 |
+
|
| 96 |
+
self.config = tokenizer_config
|
| 97 |
+
self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
|
| 98 |
+
if self.config.special_tokens_first:
|
| 99 |
+
self.offset = len(tokenizer_config.special_tokens)
|
| 100 |
+
self.special_tokens_offset = 0
|
| 101 |
+
else:
|
| 102 |
+
self.offset = 0
|
| 103 |
+
self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)
|
| 104 |
+
|
| 105 |
+
self.byte_sequences = {}
|
| 106 |
+
|
| 107 |
+
for key, value in self.hf_tokenizer.get_vocab().items():
|
| 108 |
+
if key in self.config.special_tokens:
|
| 109 |
+
byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
|
| 110 |
+
elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
|
| 111 |
+
byte_sequence = [self.eos_token_id]
|
| 112 |
+
elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
|
| 113 |
+
byte_sequence = [self.bos_token_id]
|
| 114 |
+
elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
|
| 115 |
+
byte_sequence = [self.pad_token_id]
|
| 116 |
+
else:
|
| 117 |
+
byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]
|
| 118 |
+
|
| 119 |
+
assert self.byte_sequences.get(value) is None
|
| 120 |
+
self.byte_sequences[value] = byte_sequence
|
| 121 |
+
|
| 122 |
+
self.byte_trie = {}
|
| 123 |
+
|
| 124 |
+
for token_id, byte_sequence in self.byte_sequences.items():
|
| 125 |
+
current_dict = self.byte_trie
|
| 126 |
+
for byte in byte_sequence[::-1]: # retrieved from the back so store in reverse order
|
| 127 |
+
if byte not in current_dict:
|
| 128 |
+
current_dict[byte] = {}
|
| 129 |
+
current_dict = current_dict[byte]
|
| 130 |
+
current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id
|
| 131 |
+
|
| 132 |
+
self.add_bos_token = True
|
| 133 |
+
self.add_eos_token = False
|
| 134 |
+
self.padding_side = "left" # for generate
|
| 135 |
+
|
| 136 |
+
super().__init__(
|
| 137 |
+
bos_token=self.config.special_tokens[self.config.bos_token_id],
|
| 138 |
+
eos_token=self.config.special_tokens[self.config.eos_token_id],
|
| 139 |
+
pad_token=self.config.special_tokens[self.config.pad_token_id],
|
| 140 |
+
extra_ids=0,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def bos_token_id(self):
|
| 145 |
+
return self.config.bos_token_id
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def eos_token_id(self):
|
| 149 |
+
return self.config.eos_token_id
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def pad_token_id(self):
|
| 153 |
+
return self.config.pad_token_id
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def bpe_token_end_id(self):
|
| 157 |
+
return self.config.bpe_token_end_id
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def vocab_size(self):
|
| 161 |
+
return self.config.vocab_size
|
| 162 |
+
|
| 163 |
+
def _convert_id_to_token(self, index):
|
| 164 |
+
if index < self.offset:
|
| 165 |
+
return self.config.special_tokens[index - self.special_tokens_offset]
|
| 166 |
+
|
| 167 |
+
if index >= self.offset + 256 and index < self.offset * 2 + 256:
|
| 168 |
+
# special token with fused boundary
|
| 169 |
+
return self.config.special_tokens[index - self.offset - 256] + "b"
|
| 170 |
+
|
| 171 |
+
return _BYTES_TO_CHARS[index - self.offset - 256 - self.offset] + "b" if index >= self.offset + 256 else _BYTES_TO_CHARS[index - self.offset]
|
| 172 |
+
|
| 173 |
+
def _convert_token_to_id(self, token):
|
| 174 |
+
if token in self.config.special_tokens:
|
| 175 |
+
return self.config.special_tokens.index(token)
|
| 176 |
+
|
| 177 |
+
if token in [x + "b" for x in self.config.special_tokens]:
|
| 178 |
+
# special token with fused boundary
|
| 179 |
+
return 256 + self.config.special_tokens.index(token[:-1])
|
| 180 |
+
|
| 181 |
+
if len(token) > 1 and token[-1] == "b":
|
| 182 |
+
return self.offset + 256 + _CHARS_TO_BYTES[token[0]]
|
| 183 |
+
else:
|
| 184 |
+
return self.offset + _CHARS_TO_BYTES[token]
|
| 185 |
+
|
| 186 |
+
def get_vocab(self):
|
| 187 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 188 |
+
return vocab
|
| 189 |
+
|
| 190 |
+
def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
|
| 191 |
+
# search in the byte tree for the longest matching token at every byte position
|
| 192 |
+
expanded_ids = []
|
| 193 |
+
for i in range(len(byte_ids)):
|
| 194 |
+
if n_last is not None and i < len(byte_ids) - n_last:
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
current_dict = self.byte_trie
|
| 198 |
+
current_expansion = None
|
| 199 |
+
|
| 200 |
+
for i in range(i, -1, -1):
|
| 201 |
+
byte = byte_ids[i]
|
| 202 |
+
|
| 203 |
+
if byte == self.bpe_token_end_id:
|
| 204 |
+
# skip bpe token end markers, needed for generation
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
if byte >= self.offset + 256:
|
| 208 |
+
# ignore fused boundary
|
| 209 |
+
byte -= self.offset + 256
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
current_dict = current_dict[byte]
|
| 213 |
+
if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
|
| 214 |
+
current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
|
| 215 |
+
except KeyError:
|
| 216 |
+
assert current_expansion is not None
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
expanded_ids.append(current_expansion)
|
| 220 |
+
|
| 221 |
+
return expanded_ids
|
| 222 |
+
|
| 223 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
| 224 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 225 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 226 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 227 |
+
|
| 228 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 229 |
+
|
| 230 |
+
if token_ids_1 is not None:
|
| 231 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 232 |
+
|
| 233 |
+
return output
|
| 234 |
+
|
| 235 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
|
| 236 |
+
def get_special_tokens_mask(
|
| 237 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
|
| 238 |
+
) -> list[int]:
|
| 239 |
+
"""
|
| 240 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 241 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 242 |
+
Args:
|
| 243 |
+
token_ids_0 (`List[int]`):
|
| 244 |
+
List of IDs.
|
| 245 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 246 |
+
Optional second list of IDs for sequence pairs.
|
| 247 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 248 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 249 |
+
Returns:
|
| 250 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 251 |
+
"""
|
| 252 |
+
if already_has_special_tokens:
|
| 253 |
+
return super().get_special_tokens_mask(
|
| 254 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 258 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 259 |
+
|
| 260 |
+
if token_ids_1 is None:
|
| 261 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 262 |
+
return (
|
| 263 |
+
bos_token_id
|
| 264 |
+
+ ([0] * len(token_ids_0))
|
| 265 |
+
+ eos_token_id
|
| 266 |
+
+ bos_token_id
|
| 267 |
+
+ ([0] * len(token_ids_1))
|
| 268 |
+
+ eos_token_id
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
|
| 272 |
+
def create_token_type_ids_from_sequences(
|
| 273 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
|
| 274 |
+
) -> list[int]:
|
| 275 |
+
"""
|
| 276 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
| 277 |
+
sequence pair mask has the following format:
|
| 278 |
+
```
|
| 279 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 280 |
+
| first sequence | second sequence |
|
| 281 |
+
```
|
| 282 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
| 283 |
+
Args:
|
| 284 |
+
token_ids_0 (`List[int]`):
|
| 285 |
+
List of ids.
|
| 286 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 287 |
+
Optional second list of IDs for sequence pairs.
|
| 288 |
+
Returns:
|
| 289 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 290 |
+
"""
|
| 291 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 292 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 293 |
+
|
| 294 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 295 |
+
|
| 296 |
+
if token_ids_1 is not None:
|
| 297 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 298 |
+
|
| 299 |
+
return output
|
| 300 |
+
|
| 301 |
+
def _tokenize(self, text: str, **kwargs) -> list[str]:
|
| 302 |
+
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
| 303 |
+
tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
|
| 304 |
+
return tokens
|
| 305 |
+
|
| 306 |
+
def _patch_ids_to_byte_ids(self, input_ids: list[int]):
|
| 307 |
+
return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]
|
| 308 |
+
|
| 309 |
+
def _bolmo_encode(self, string: str, add_special_tokens=False):
|
| 310 |
+
input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
|
| 311 |
+
return self._patch_ids_to_byte_ids(input_ids)
|
| 312 |
+
|
| 313 |
+
def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
|
| 314 |
+
return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")
|
| 315 |
+
|
| 316 |
+
def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
|
| 317 |
+
tokens_without_boundary = []
|
| 318 |
+
for token in tokens:
|
| 319 |
+
if token >= (self.offset + 256):
|
| 320 |
+
token -= self.offset + 256
|
| 321 |
+
|
| 322 |
+
tokens_without_boundary.append(token)
|
| 323 |
+
|
| 324 |
+
utf8_bytes = []
|
| 325 |
+
|
| 326 |
+
for token in tokens_without_boundary:
|
| 327 |
+
if token < self.offset:
|
| 328 |
+
if skip_special_tokens:
|
| 329 |
+
continue
|
| 330 |
+
else:
|
| 331 |
+
utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
|
| 332 |
+
else:
|
| 333 |
+
utf8_bytes.append(min(token - self.offset, 255))
|
| 334 |
+
|
| 335 |
+
return bytes(utf8_bytes)
|
| 336 |
+
|
| 337 |
+
def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
|
| 338 |
+
if add_bos and self.bos_token_id is not None:
|
| 339 |
+
byte_tokens = [self.bos_token_id]
|
| 340 |
+
patch_lengths = [1]
|
| 341 |
+
else:
|
| 342 |
+
byte_tokens = []
|
| 343 |
+
patch_lengths = []
|
| 344 |
+
|
| 345 |
+
for idx, token in enumerate(original_input_ids):
|
| 346 |
+
# optionally skip last token to keep the length the same if add_bos=True
|
| 347 |
+
if skip_last and idx == len(original_input_ids) - 1:
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])
|
| 351 |
+
|
| 352 |
+
if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
|
| 353 |
+
# skip padding tokens
|
| 354 |
+
continue
|
| 355 |
+
|
| 356 |
+
patch_lengths.append(len(token_byte_tokens))
|
| 357 |
+
byte_tokens.extend(token_byte_tokens)
|
| 358 |
+
|
| 359 |
+
return byte_tokens, patch_lengths
|
| 360 |
+
|
| 361 |
+
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
| 362 |
+
return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False) # type: ignore
|
| 363 |
+
|
| 364 |
+
def _decode(
|
| 365 |
+
self,
|
| 366 |
+
token_ids: Union[int, list[int]],
|
| 367 |
+
skip_special_tokens: bool = False,
|
| 368 |
+
clean_up_tokenization_spaces: Optional[bool] = None,
|
| 369 |
+
spaces_between_special_tokens: bool = True,
|
| 370 |
+
**kwargs,
|
| 371 |
+
) -> str:
|
| 372 |
+
if isinstance(token_ids, int):
|
| 373 |
+
token_ids = [token_ids]
|
| 374 |
+
|
| 375 |
+
return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)
|
| 376 |
+
|
| 377 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
|
| 378 |
+
return () # type: ignore
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<pad>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<bos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
}
|
| 19 |
+
},
|
| 20 |
+
"auto_map": {
|
| 21 |
+
"AutoTokenizer": [
|
| 22 |
+
"tokenization_bolmo.BolmoTokenizer",
|
| 23 |
+
null
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
"bos_token": "<bos>",
|
| 27 |
+
"clean_up_tokenization_spaces": false,
|
| 28 |
+
"eos_token": "<bos>",
|
| 29 |
+
"extra_ids": 0,
|
| 30 |
+
"extra_special_tokens": {},
|
| 31 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 32 |
+
"pad_token": "<pad>",
|
| 33 |
+
"tokenizer_class": "BolmoTokenizer"
|
| 34 |
+
}
|
utils_bolmo.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
|
| 8 |
+
if boundary_threshold.startswith("sample:"):
|
| 9 |
+
_, temperature = boundary_threshold.split(":")
|
| 10 |
+
temperature = float(temperature)
|
| 11 |
+
|
| 12 |
+
if temperature == 0:
|
| 13 |
+
return (boundary_logprobs > math.log(0.5))
|
| 14 |
+
elif temperature == 1:
|
| 15 |
+
return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
|
| 16 |
+
else:
|
| 17 |
+
raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
|
| 18 |
+
elif boundary_threshold.startswith("topk:"):
|
| 19 |
+
_, topk = boundary_threshold.split(":")
|
| 20 |
+
topk = int(topk)
|
| 21 |
+
thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - (topk / boundary_logprobs.shape[1]))
|
| 22 |
+
return (boundary_logprobs >= thresholds.unsqueeze(-1))
|
| 23 |
+
elif boundary_threshold.startswith("topk_percent:"):
|
| 24 |
+
_, topk_percent = boundary_threshold.split(":")
|
| 25 |
+
topk_percent = float(topk_percent)
|
| 26 |
+
assert 0 <= topk_percent <= 1
|
| 27 |
+
thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - topk_percent)
|
| 28 |
+
return (boundary_logprobs >= thresholds.unsqueeze(-1))
|
| 29 |
+
else:
|
| 30 |
+
raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
|
| 34 |
+
max_len = max(t.size(0) for t in tensors)
|
| 35 |
+
if multiple_of > 1:
|
| 36 |
+
# Round up max_len to the nearest multiple_of
|
| 37 |
+
max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of
|
| 38 |
+
padded = []
|
| 39 |
+
for t in tensors:
|
| 40 |
+
if direction == "left":
|
| 41 |
+
pad_shape = (max_len - t.size(0), 0)
|
| 42 |
+
elif direction == "right":
|
| 43 |
+
pad_shape = (0, max_len - t.size(0))
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
|
| 46 |
+
padded.append(F.pad(t, pad_shape, value=value))
|
| 47 |
+
return torch.stack(padded, dim=0)
|
| 48 |
+
|
| 49 |
+
def pad_right(
|
| 50 |
+
tensors: list[torch.Tensor],
|
| 51 |
+
multiple_of: int = 128,
|
| 52 |
+
value=0,
|
| 53 |
+
):
|
| 54 |
+
return _pad(tensors, multiple_of, direction="right", value=value)
|
| 55 |
+
|
| 56 |
+
def pad_left(
|
| 57 |
+
tensors: list[torch.Tensor],
|
| 58 |
+
multiple_of: int = 128,
|
| 59 |
+
value=0,
|
| 60 |
+
):
|
| 61 |
+
return _pad(tensors, multiple_of, direction="left", value=value)
|
| 62 |
+
|
| 63 |
+
class MaskState:
|
| 64 |
+
def __init__(self, mask):
|
| 65 |
+
self.cpu_mask = mask.cpu()
|
| 66 |
+
|
| 67 |
+
self.mask = mask
|
| 68 |
+
self.inv_mask = ~mask
|
| 69 |
+
self._all = self.cpu_mask.all().item()
|
| 70 |
+
self._any = self.cpu_mask.any().item()
|
| 71 |
+
|
| 72 |
+
def any(self):
|
| 73 |
+
return self._any
|
| 74 |
+
|
| 75 |
+
def all(self):
|
| 76 |
+
return self._all
|
| 77 |
+
|
| 78 |
+
def selective_get(self, x, inv=False):
|
| 79 |
+
# try to avoid sync through nonzero on index
|
| 80 |
+
if inv:
|
| 81 |
+
if self.all():
|
| 82 |
+
return x[[]]
|
| 83 |
+
elif not self.any():
|
| 84 |
+
return x
|
| 85 |
+
else:
|
| 86 |
+
return x[self.inv_mask]
|
| 87 |
+
else:
|
| 88 |
+
if self.all():
|
| 89 |
+
return x
|
| 90 |
+
elif not self.any():
|
| 91 |
+
return x[[]]
|
| 92 |
+
else:
|
| 93 |
+
return x[self.mask]
|
| 94 |
+
|
| 95 |
+
def selective_put(self, x, out, inv=False):
|
| 96 |
+
# try to avoid sync through nonzero on index
|
| 97 |
+
if inv:
|
| 98 |
+
if self.all():
|
| 99 |
+
return
|
| 100 |
+
elif not self.any():
|
| 101 |
+
out.copy_(x)
|
| 102 |
+
else:
|
| 103 |
+
out[self.inv_mask] = x
|
| 104 |
+
else:
|
| 105 |
+
if self.all():
|
| 106 |
+
out.copy_(x)
|
| 107 |
+
elif not self.any():
|
| 108 |
+
return
|
| 109 |
+
else:
|
| 110 |
+
out[self.mask] = x
|
| 111 |
+
|
| 112 |
+
def selective_add(self, x, out, inv=False):
|
| 113 |
+
# try to avoid sync through nonzero on index
|
| 114 |
+
if inv:
|
| 115 |
+
if self.all():
|
| 116 |
+
return
|
| 117 |
+
elif not self.any():
|
| 118 |
+
out.add_(x)
|
| 119 |
+
else:
|
| 120 |
+
out[self.inv_mask] += x
|
| 121 |
+
else:
|
| 122 |
+
if self.all():
|
| 123 |
+
out.add_(x)
|
| 124 |
+
elif not self.any():
|
| 125 |
+
return
|
| 126 |
+
else:
|
| 127 |
+
out[self.mask] += x
|