Text Generation
Transformers
Safetensors
English
bolmo
custom_code
benjamin committed on
Commit
247323b
·
verified ·
1 Parent(s): b5b00b8

Upload folder using huggingface_hub

Browse files
configuration_bolmo.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from transformers.configuration_utils import PretrainedConfig, layer_type_validation
5
  from transformers.modeling_rope_utils import rope_config_validation
6
- from .tokenization_bolmo import ByteTokenizerConfig
7
 
8
  class BolmoConfig(PretrainedConfig):
9
  r"""
@@ -167,7 +167,7 @@ class BolmoConfig(PretrainedConfig):
167
  local_intermediate_size: int = 5504,
168
  local_rms_norm_eps=1e-5,
169
  subword_vocab_size: int = 100278, # dolma2_tokenizer subword vocab size
170
- tokenizer_config: ByteTokenizerConfig | dict[str, Any] | None = None,
171
  **kwargs,
172
  ):
173
  super().__init__(
@@ -220,8 +220,8 @@ class BolmoConfig(PretrainedConfig):
220
  self.subword_vocab_size = subword_vocab_size
221
 
222
  if tokenizer_config is None:
223
- self.tokenizer_config = asdict(ByteTokenizerConfig.bolmo())
224
- elif isinstance(tokenizer_config, ByteTokenizerConfig):
225
  self.tokenizer_config = asdict(tokenizer_config)
226
  else:
227
  self.tokenizer_config = tokenizer_config
 
3
 
4
  from transformers.configuration_utils import PretrainedConfig, layer_type_validation
5
  from transformers.modeling_rope_utils import rope_config_validation
6
+ from .tokenization_bolmo import BolmoTokenizerConfig
7
 
8
  class BolmoConfig(PretrainedConfig):
9
  r"""
 
167
  local_intermediate_size: int = 5504,
168
  local_rms_norm_eps=1e-5,
169
  subword_vocab_size: int = 100278, # dolma2_tokenizer subword vocab size
170
+ tokenizer_config: BolmoTokenizerConfig | dict[str, Any] | None = None,
171
  **kwargs,
172
  ):
173
  super().__init__(
 
220
  self.subword_vocab_size = subword_vocab_size
221
 
222
  if tokenizer_config is None:
223
+ self.tokenizer_config = asdict(BolmoTokenizerConfig.bolmo())
224
+ elif isinstance(tokenizer_config, BolmoTokenizerConfig):
225
  self.tokenizer_config = asdict(tokenizer_config)
226
  else:
227
  self.tokenizer_config = tokenizer_config
modeling_bolmo.py CHANGED
@@ -10,7 +10,8 @@ from transformers.utils.generic import TransformersKwargs
10
 
11
  from transformers.activations import ACT2FN
12
  from transformers.cache_utils import Cache, DynamicCache
13
- from transformers.generation import GenerationMixin
 
14
  from transformers.integrations import use_kernel_forward_from_hub
15
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
16
  from transformers.modeling_layers import GradientCheckpointingLayer
@@ -18,15 +19,18 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
18
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
19
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
20
  from transformers.processing_utils import Unpack
21
- from transformers.utils import auto_docstring, can_return_tuple
22
  from transformers.utils.deprecation import deprecate_kwarg
23
  from transformers.utils.generic import check_model_inputs
24
 
25
  from .configuration_bolmo import BolmoConfig
26
- from .tokenization_bolmo import ByteTokenizerConfig
27
  from .utils_bolmo import compute_boundary_mask, pad_right, pad_left, MaskState
28
 
29
- from xlstm.xlstm_large.model import mLSTMLayer, mLSTMLayerConfig, mLSTMLayerStateType, soft_cap, mLSTMBackendConfig
 
 
 
30
 
31
 
32
  @use_kernel_forward_from_hub("RMSNorm")
@@ -161,7 +165,7 @@ class BolmoAttention(nn.Module):
161
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
162
  attention_mask: Optional[torch.Tensor],
163
  past_key_values: Optional[Cache] = None,
164
- cache_position: Optional[torch.LongTensor] = None,
165
  **kwargs: Unpack[TransformersKwargs],
166
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
167
  input_shape = hidden_states.shape[:-1]
@@ -235,10 +239,10 @@ class BolmoDecoderLayer(GradientCheckpointingLayer):
235
  self,
236
  hidden_states: torch.Tensor,
237
  attention_mask: Optional[torch.Tensor] = None,
238
- position_ids: Optional[torch.LongTensor] = None,
239
  past_key_values: Optional[Cache] = None,
240
  use_cache: Optional[bool] = False,
241
- cache_position: Optional[torch.LongTensor] = None,
242
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
243
  **kwargs: Unpack[TransformersKwargs],
244
  ) -> torch.Tensor:
@@ -834,7 +838,6 @@ class BolmoRotaryEmbedding(nn.Module):
834
  return cos, sin
835
 
836
 
837
- @auto_docstring
838
  class BolmoPreTrainedModel(PreTrainedModel):
839
  config: BolmoConfig
840
  base_model_prefix = "model"
@@ -853,7 +856,6 @@ class BolmoPreTrainedModel(PreTrainedModel):
853
  }
854
 
855
 
856
- @auto_docstring
857
  class BolmoModel(BolmoPreTrainedModel):
858
  def __init__(self, config: BolmoConfig):
859
  super().__init__(config)
@@ -875,7 +877,7 @@ class BolmoModel(BolmoPreTrainedModel):
875
  }
876
  )
877
 
878
- self.tokenizer_config = ByteTokenizerConfig(**config.tokenizer_config)
879
  self._tokenizer = None
880
 
881
  # Initialize weights and apply final processing
@@ -897,7 +899,7 @@ class BolmoModel(BolmoPreTrainedModel):
897
  def prefill_boundary_prediction_forward(
898
  self,
899
  input_ids: torch.Tensor,
900
- expanded_input_ids: Optional[torch.LongTensor] = None,
901
  sequence_start_indices: Optional[torch.Tensor] = None,
902
  last_token_is_boundary: bool = False,
903
  **kwargs,
@@ -913,16 +915,14 @@ class BolmoModel(BolmoPreTrainedModel):
913
  return cast(torch.Tensor, boundary_mask)
914
 
915
  @check_model_inputs()
916
- @auto_docstring
917
  def forward(
918
  self,
919
- input_ids: torch.LongTensor,
920
- expanded_input_ids: Optional[torch.LongTensor] = None,
921
  attention_mask: Optional[torch.Tensor] = None,
922
- position_ids: Optional[torch.LongTensor] = None,
923
  past_key_values: Optional[Cache] = None,
924
- inputs_embeds: Optional[torch.FloatTensor] = None,
925
- cache_position: Optional[torch.LongTensor] = None,
926
  use_cache: Optional[bool] = None,
927
  boundary_mask: Optional[torch.Tensor] = None,
928
  boundary_state: Optional[MaskState] = None,
@@ -1029,7 +1029,6 @@ class BolmoModel(BolmoPreTrainedModel):
1029
  )
1030
 
1031
 
1032
- @auto_docstring
1033
  class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1034
  _tied_weights_keys = ["lm_head.weight"]
1035
  _tp_plan = {"lm_head": "colwise_rep"}
@@ -1051,16 +1050,15 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1051
  self.lm_head = new_embeddings
1052
 
1053
  @can_return_tuple
1054
- @auto_docstring
1055
  def forward(
1056
  self,
1057
- input_ids: torch.LongTensor,
1058
- expanded_input_ids: Optional[torch.LongTensor] = None,
1059
  attention_mask: Optional[torch.Tensor] = None,
1060
- position_ids: Optional[torch.LongTensor] = None,
1061
  past_key_values: Optional[Cache] = None,
1062
  inputs_embeds: Optional[torch.FloatTensor] = None,
1063
- cache_position: Optional[torch.LongTensor] = None,
1064
  use_cache: Optional[bool] = None,
1065
  boundary_mask: Optional[torch.Tensor] = None,
1066
  boundary_state: Optional[MaskState] = None,
@@ -1114,22 +1112,42 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1114
  attentions=outputs.attentions,
1115
  )
1116
 
1117
- def generate(self, input_ids: list[list[int]], max_new_tokens: int = 20):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1118
  expand_input_ids = self.model.local_encoder.add_expanded_embeddings
1119
- batch_size = len(input_ids)
1120
 
1121
  if expand_input_ids:
1122
  expanded_input_ids = []
1123
 
1124
- for i in range(len(input_ids)):
1125
- expanded_input_ids.append(torch.tensor(self.model.tokenizer.expand_byte_ids(input_ids[i]), device=self.device, dtype=torch.long))
1126
 
1127
  expanded_input_ids = pad_left(expanded_input_ids, value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
1128
  else:
1129
  expanded_input_ids = None
1130
 
1131
- byte_input_ids: torch.Tensor = pad_left([torch.tensor(x, device=self.device, dtype=torch.long) for x in input_ids], value=self.model.tokenizer.pad_token_id, multiple_of=1)
1132
-
1133
  sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1134
  batch_size, prompt_len = byte_input_ids.shape
1135
  finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
@@ -1155,6 +1173,31 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1155
  # stays the same unless last token is pad.
1156
  sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1158
  # output container
1159
  generated = byte_input_ids
1160
 
@@ -1162,8 +1205,6 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1162
  tokens_generated_plus_prefilled = max_n_prefill_patches
1163
  bytes_generated = 0
1164
 
1165
- max_length = max_n_prefill_patches + max_new_tokens
1166
-
1167
  # generation state
1168
  boundary_state = MaskState(boundary_mask[:, -1].clone())
1169
  pad_state = MaskState(torch.zeros(batch_size, dtype=torch.bool, device=self.device))
@@ -1173,10 +1214,7 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1173
  is_first_forward = True
1174
  global_past_key_values = None
1175
 
1176
- # TODO: impl
1177
- stop_token_sequences = []
1178
-
1179
- while not ((max_length is not None and tokens_generated_plus_prefilled >= max_length) or finished.all()):
1180
  input_ids_for_model = (
1181
  generated
1182
  if is_first_forward
@@ -1232,15 +1270,24 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1232
 
1233
  forced_decoding_ids[example_idx] = None # only force once
1234
 
1235
- # TODO: impl non-greedy
1236
- new_next_tokens = next_token_logits.squeeze(1).argmax(dim=-1)
1237
 
1238
- if boundary_state.all():
 
 
 
 
 
 
1239
  tokens_generated_plus_prefilled += 1
1240
 
1241
  next_tokens = new_next_tokens
1242
  next_tokens_cpu = next_tokens.cpu()
1243
  for example_idx in range(batch_size):
 
 
 
1244
  next_token_cpu = next_tokens_cpu[example_idx].item()
1245
 
1246
  if next_token_cpu >= boundary_offset:
@@ -1253,6 +1300,9 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1253
  next_tokens_cpu = next_tokens.cpu()
1254
 
1255
  for example_idx in range(batch_size):
 
 
 
1256
  next_token_cpu = next_tokens_cpu[example_idx].item()
1257
 
1258
  if not boundary_state.cpu_mask[example_idx].item():
@@ -1282,14 +1332,17 @@ class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1282
  # Handle finished sequences
1283
  stop_hit = next_tokens.eq(eos) | next_tokens.eq(eos + boundary_offset)
1284
 
1285
- # Also check for stop tokens if provided
1286
- # TODO(benjaminm): this is very annoying due to the boundaries
1287
- # make better
1288
- if len(stop_token_sequences) > 0:
1289
- # TODO: implement
1290
- raise NotImplementedError("stop_token_sequences not implemented yet for Bolmo generation.")
1291
 
1292
  finished |= stop_hit
1293
  bytes_generated += 1
1294
 
 
 
 
 
 
1295
  __all__ = ["BolmoForCausalLM", "BolmoModel", "BolmoPreTrainedModel"]
 
10
 
11
  from transformers.activations import ACT2FN
12
  from transformers.cache_utils import Cache, DynamicCache
13
+ from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
14
+ from transformers.generation.utils import GenerateOutput
15
  from transformers.integrations import use_kernel_forward_from_hub
16
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
17
  from transformers.modeling_layers import GradientCheckpointingLayer
 
19
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
  from transformers.processing_utils import Unpack
22
+ from transformers.utils import can_return_tuple
23
  from transformers.utils.deprecation import deprecate_kwarg
24
  from transformers.utils.generic import check_model_inputs
25
 
26
  from .configuration_bolmo import BolmoConfig
27
+ from .tokenization_bolmo import BolmoTokenizerConfig
28
  from .utils_bolmo import compute_boundary_mask, pad_right, pad_left, MaskState
29
 
30
+ try:
31
+ from xlstm.xlstm_large.model import mLSTMLayer, mLSTMLayerConfig, mLSTMLayerStateType, soft_cap, mLSTMBackendConfig
32
+ except ImportError:
33
+ raise ImportError("The `xlstm` package is required to use Bolmo. Please install it via `pip install xlstm`.")
34
 
35
 
36
  @use_kernel_forward_from_hub("RMSNorm")
 
165
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
166
  attention_mask: Optional[torch.Tensor],
167
  past_key_values: Optional[Cache] = None,
168
+ cache_position: Optional[torch.Tensor] = None,
169
  **kwargs: Unpack[TransformersKwargs],
170
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
171
  input_shape = hidden_states.shape[:-1]
 
239
  self,
240
  hidden_states: torch.Tensor,
241
  attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.Tensor] = None,
243
  past_key_values: Optional[Cache] = None,
244
  use_cache: Optional[bool] = False,
245
+ cache_position: Optional[torch.Tensor] = None,
246
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
247
  **kwargs: Unpack[TransformersKwargs],
248
  ) -> torch.Tensor:
 
838
  return cos, sin
839
 
840
 
 
841
  class BolmoPreTrainedModel(PreTrainedModel):
842
  config: BolmoConfig
843
  base_model_prefix = "model"
 
856
  }
857
 
858
 
 
859
  class BolmoModel(BolmoPreTrainedModel):
860
  def __init__(self, config: BolmoConfig):
861
  super().__init__(config)
 
877
  }
878
  )
879
 
880
+ self.tokenizer_config = BolmoTokenizerConfig(**config.tokenizer_config)
881
  self._tokenizer = None
882
 
883
  # Initialize weights and apply final processing
 
899
  def prefill_boundary_prediction_forward(
900
  self,
901
  input_ids: torch.Tensor,
902
+ expanded_input_ids: Optional[torch.Tensor] = None,
903
  sequence_start_indices: Optional[torch.Tensor] = None,
904
  last_token_is_boundary: bool = False,
905
  **kwargs,
 
915
  return cast(torch.Tensor, boundary_mask)
916
 
917
  @check_model_inputs()
 
918
  def forward(
919
  self,
920
+ input_ids: torch.Tensor,
921
+ expanded_input_ids: Optional[torch.Tensor] = None,
922
  attention_mask: Optional[torch.Tensor] = None,
923
+ position_ids: Optional[torch.Tensor] = None,
924
  past_key_values: Optional[Cache] = None,
925
+ cache_position: Optional[torch.Tensor] = None,
 
926
  use_cache: Optional[bool] = None,
927
  boundary_mask: Optional[torch.Tensor] = None,
928
  boundary_state: Optional[MaskState] = None,
 
1029
  )
1030
 
1031
 
 
1032
  class BolmoForCausalLM(BolmoPreTrainedModel, GenerationMixin):
1033
  _tied_weights_keys = ["lm_head.weight"]
1034
  _tp_plan = {"lm_head": "colwise_rep"}
 
1050
  self.lm_head = new_embeddings
1051
 
1052
  @can_return_tuple
 
1053
  def forward(
1054
  self,
1055
+ input_ids: torch.Tensor,
1056
+ expanded_input_ids: Optional[torch.Tensor] = None,
1057
  attention_mask: Optional[torch.Tensor] = None,
1058
+ position_ids: Optional[torch.Tensor] = None,
1059
  past_key_values: Optional[Cache] = None,
1060
  inputs_embeds: Optional[torch.FloatTensor] = None,
1061
+ cache_position: Optional[torch.Tensor] = None,
1062
  use_cache: Optional[bool] = None,
1063
  boundary_mask: Optional[torch.Tensor] = None,
1064
  boundary_state: Optional[MaskState] = None,
 
1112
  attentions=outputs.attentions,
1113
  )
1114
 
1115
+ @torch.no_grad()
1116
+ def generate( # type: ignore
1117
+ self,
1118
+ inputs: torch.Tensor,
1119
+ generation_config: Optional[GenerationConfig] = None,
1120
+ logits_processor: Optional[LogitsProcessorList] = None,
1121
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1122
+ use_model_defaults: Optional[bool] = None,
1123
+ **kwargs,
1124
+ ) -> Union[GenerateOutput, torch.Tensor]:
1125
+ # generic preprocessing
1126
+
1127
+ generation_config, model_kwargs = self._prepare_generation_config(
1128
+ generation_config, use_model_defaults, **kwargs
1129
+ )
1130
+ self._prepare_special_tokens(generation_config, device=self.model.device)
1131
+
1132
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1133
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1134
+
1135
+ # start of custom generate
1136
+
1137
  expand_input_ids = self.model.local_encoder.add_expanded_embeddings
1138
+ batch_size = len(inputs)
1139
 
1140
  if expand_input_ids:
1141
  expanded_input_ids = []
1142
 
1143
+ for i in range(len(inputs)):
1144
+ expanded_input_ids.append(torch.tensor(self.model.tokenizer.expand_byte_ids(inputs[i].tolist()), device=self.device, dtype=torch.long))
1145
 
1146
  expanded_input_ids = pad_left(expanded_input_ids, value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
1147
  else:
1148
  expanded_input_ids = None
1149
 
1150
+ byte_input_ids = inputs
 
1151
  sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1152
  batch_size, prompt_len = byte_input_ids.shape
1153
  finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
 
1173
  # stays the same unless last token is pad.
1174
  sequence_start_indices = (byte_input_ids == self.model.tokenizer.pad_token_id).sum(-1)
1175
 
1176
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1177
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
1178
+ generation_config = self._prepare_generated_length(
1179
+ generation_config=generation_config,
1180
+ has_default_max_length=has_default_max_length,
1181
+ has_default_min_length=has_default_min_length,
1182
+ model_input_name="input_ids",
1183
+ inputs_tensor=byte_input_ids,
1184
+ input_ids_length=byte_input_ids.shape[1],
1185
+ )
1186
+
1187
+ logits_processor = self._get_logits_processor(
1188
+ generation_config=generation_config, # type: ignore
1189
+ input_ids_seq_length=byte_input_ids.shape[1],
1190
+ encoder_input_ids=byte_input_ids, # type: ignore
1191
+ logits_processor=logits_processor,
1192
+ device=byte_input_ids.device, # type: ignore
1193
+ model_kwargs=model_kwargs,
1194
+ )
1195
+ stopping_criteria = self._get_stopping_criteria(
1196
+ generation_config=generation_config, # type: ignore
1197
+ stopping_criteria=stopping_criteria,
1198
+ tokenizer=self.model.tokenizer,
1199
+ )
1200
+
1201
  # output container
1202
  generated = byte_input_ids
1203
 
 
1205
  tokens_generated_plus_prefilled = max_n_prefill_patches
1206
  bytes_generated = 0
1207
 
 
 
1208
  # generation state
1209
  boundary_state = MaskState(boundary_mask[:, -1].clone())
1210
  pad_state = MaskState(torch.zeros(batch_size, dtype=torch.bool, device=self.device))
 
1214
  is_first_forward = True
1215
  global_past_key_values = None
1216
 
1217
+ while not finished.all():
 
 
 
1218
  input_ids_for_model = (
1219
  generated
1220
  if is_first_forward
 
1270
 
1271
  forced_decoding_ids[example_idx] = None # only force once
1272
 
1273
+ # passing input_ids to logit processor not implemented
1274
+ next_token_scores = logits_processor(None, next_token_logits[:, -1]) # type: ignore
1275
 
1276
+ if generation_config is not None and generation_config.do_sample:
1277
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1278
+ new_next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1279
+ else:
1280
+ new_next_tokens = torch.argmax(next_token_scores, dim=-1)
1281
+
1282
+ if boundary_state.all() or is_first_forward:
1283
  tokens_generated_plus_prefilled += 1
1284
 
1285
  next_tokens = new_next_tokens
1286
  next_tokens_cpu = next_tokens.cpu()
1287
  for example_idx in range(batch_size):
1288
+ if finished[example_idx].item():
1289
+ continue
1290
+
1291
  next_token_cpu = next_tokens_cpu[example_idx].item()
1292
 
1293
  if next_token_cpu >= boundary_offset:
 
1300
  next_tokens_cpu = next_tokens.cpu()
1301
 
1302
  for example_idx in range(batch_size):
1303
+ if finished[example_idx].item():
1304
+ continue
1305
+
1306
  next_token_cpu = next_tokens_cpu[example_idx].item()
1307
 
1308
  if not boundary_state.cpu_mask[example_idx].item():
 
1332
  # Handle finished sequences
1333
  stop_hit = next_tokens.eq(eos) | next_tokens.eq(eos + boundary_offset)
1334
 
1335
+ for i in range(batch_size):
1336
+ # passing `scores` to stopping criteria not implemented
1337
+ if stopping_criteria(torch.tensor(non_boundary_generated_tokens[i], dtype=torch.long).unsqueeze(0), None).squeeze(0).item(): # type: ignore
1338
+ stop_hit[i] = True
 
 
1339
 
1340
  finished |= stop_hit
1341
  bytes_generated += 1
1342
 
1343
+ return pad_left([
1344
+ torch.cat([byte_input_ids[i, :-1], torch.tensor(x, dtype=torch.long, device=byte_input_ids.device)])
1345
+ for i, x in enumerate(non_boundary_generated_tokens)
1346
+ ], value=self.model.tokenizer.pad_token_id, multiple_of=1) # type: ignore
1347
+
1348
  __all__ = ["BolmoForCausalLM", "BolmoModel", "BolmoPreTrainedModel"]
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<bos>",
3
+ "eos_token": "<bos>",
4
+ "pad_token": "<pad>"
5
+ }
tokenization_bolmo.py CHANGED
@@ -1,7 +1,8 @@
1
  from dataclasses import dataclass, field
2
  from functools import lru_cache
3
- from typing import Optional
4
  from transformers import AutoTokenizer
 
5
 
6
  # Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
7
  # Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
@@ -51,7 +52,7 @@ def _chars_to_bytes(char_sequence: str) -> list:
51
  return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
52
 
53
  @dataclass
54
- class ByteTokenizerConfig:
55
  vocab_size: int
56
  bos_token_id: int
57
  pad_token_id: int
@@ -63,7 +64,7 @@ class ByteTokenizerConfig:
63
 
64
 
65
  @classmethod
66
- def bolmo(cls) -> "ByteTokenizerConfig":
67
  special_tokens = [
68
  "<pad>",
69
  "<bos>",
@@ -83,13 +84,15 @@ class ByteTokenizerConfig:
83
  )
84
 
85
  def build(self):
86
- return ByteTokenizer(self)
87
 
88
 
89
- class ByteTokenizer:
90
  TOKEN_ID_KEY = -1
91
 
92
- def __init__(self, tokenizer_config: ByteTokenizerConfig):
 
 
93
  self.config = tokenizer_config
94
  self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
95
  if self.config.special_tokens_first:
@@ -124,7 +127,18 @@ class ByteTokenizer:
124
  if byte not in current_dict:
125
  current_dict[byte] = {}
126
  current_dict = current_dict[byte]
127
- current_dict[ByteTokenizer.TOKEN_ID_KEY] = token_id
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  @property
130
  def bos_token_id(self):
@@ -142,6 +156,37 @@ class ByteTokenizer:
142
  def bpe_token_end_id(self):
143
  return self.config.bpe_token_end_id
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
146
  # search in the byte tree for the longest matching token at every byte position
147
  expanded_ids = []
@@ -165,8 +210,8 @@ class ByteTokenizer:
165
 
166
  try:
167
  current_dict = current_dict[byte]
168
- if ByteTokenizer.TOKEN_ID_KEY in current_dict:
169
- current_expansion = current_dict[ByteTokenizer.TOKEN_ID_KEY]
170
  except KeyError:
171
  assert current_expansion is not None
172
  break
@@ -175,17 +220,100 @@ class ByteTokenizer:
175
 
176
  return expanded_ids
177
 
178
- def patch_ids_to_byte_ids(self, input_ids: list[int]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]
180
 
181
- def encode(self, string: str, add_special_tokens=False):
182
  input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
183
- return self.patch_ids_to_byte_ids(input_ids)
184
 
185
- def decode(self, tokens: list[int]) -> str:
186
- return self.decode_to_bytes(tokens).decode("utf-8", errors="replace")
187
 
188
- def decode_to_bytes(self, tokens: list[int]) -> bytes:
189
  tokens_without_boundary = []
190
  for token in tokens:
191
  if token >= (self.offset + 256):
@@ -193,7 +321,17 @@ class ByteTokenizer:
193
 
194
  tokens_without_boundary.append(token)
195
 
196
- utf8_bytes = [min(token - self.offset, 255) for token in tokens_without_boundary if token >= self.offset]
 
 
 
 
 
 
 
 
 
 
197
  return bytes(utf8_bytes)
198
 
199
  def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
@@ -209,7 +347,7 @@ class ByteTokenizer:
209
  if skip_last and idx == len(original_input_ids) - 1:
210
  break
211
 
212
- token_byte_tokens = self.patch_ids_to_byte_ids([int(token)])
213
 
214
  if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
215
  # skip padding tokens
@@ -220,82 +358,21 @@ class ByteTokenizer:
220
 
221
  return byte_tokens, patch_lengths
222
 
223
- @lru_cache(maxsize=1024)
224
- def _is_spacelike(self, token_id: int) -> bool:
225
- """
226
- Check if a token ID is spacelike.
227
- """
228
- byte = token_id - self.offset
229
- # see https://github.com/kjslag/spacebyte/blob/321111315c92bce0bc2f9f1630cb0bc82b897c57/spacebyte.py#L137-L145.
230
- is_spacelike = (
231
- (byte < ord('0')) |
232
- ((ord('9') < byte) & (byte < ord('A'))) |
233
- ((ord('Z') < byte) & (byte < ord('a'))) |
234
- ((ord('z') < byte) & (byte < 0b1000_0000)) |
235
- (0b1100_0000 <= byte)
236
- )
237
- return is_spacelike
238
-
239
- @lru_cache(maxsize=1024)
240
- def _is_strict_spacelike(self, token_id: int) -> bool:
241
- """
242
- Check if a token ID is strictly spacelike (only space, tab, newline, carriage return).
243
- """
244
- byte = token_id - self.offset
245
- return byte in {ord(' '), ord('\t'), ord('\n'), ord('\r')}
246
-
247
- def get_space_patch_lengths(self, input_ids: list[int], max_patch_length: int = 16, kind: str = "strict_end_before_space") -> list[int]:
248
- patch_lengths = []
249
- current_length = 0
250
-
251
- special_tokens = {self.bos_token_id, self.eos_token_id, self.pad_token_id}
252
-
253
- all_spacelike = [self._is_spacelike(token) for token in input_ids]
254
-
255
- if kind == "spacebyte":
256
- for token_idx, token in enumerate(input_ids):
257
- current_length += 1
258
-
259
- spacelike = all_spacelike[token_idx]
260
- previous_spacelike = all_spacelike[token_idx - 1] if token_idx > 0 else False
261
-
262
- if (not previous_spacelike and spacelike) or current_length >= max_patch_length or token in special_tokens:
263
- patch_lengths.append(current_length)
264
- current_length = 0
265
- elif kind == "spacebyte_end_before_space":
266
- for token_idx, token in enumerate(input_ids):
267
- current_length += 1
268
-
269
- spacelike = all_spacelike[token_idx]
270
- next_spacelike = all_spacelike[token_idx + 1] if token_idx < len(input_ids) - 1 else True
271
-
272
- if (not spacelike and next_spacelike) or current_length >= max_patch_length or token in special_tokens:
273
- patch_lengths.append(current_length)
274
- current_length = 0
275
- elif kind == "strict_end_before_space":
276
- all_strict_spacelike = [self._is_strict_spacelike(token) for token in input_ids]
277
- in_strict_prefix = True
278
-
279
- for token_idx, token in enumerate(input_ids):
280
- current_length += 1
281
-
282
- spacelike = all_spacelike[token_idx]
283
- strict_spacelike = all_strict_spacelike[token_idx]
284
- next_spacelike = all_spacelike[token_idx + 1] if token_idx < len(input_ids) - 1 else True
285
- next_strict_spacelike = all_strict_spacelike[token_idx + 1] if token_idx < len(input_ids) - 1 else True
286
-
287
- if not strict_spacelike:
288
- in_strict_prefix = False
289
-
290
- if in_strict_prefix:
291
- continue
292
 
293
- if (spacelike != next_spacelike) or (strict_spacelike != next_strict_spacelike) or current_length >= max_patch_length or token in special_tokens:
294
- patch_lengths.append(current_length)
295
- in_strict_prefix = True
296
- current_length = 0
 
 
 
 
 
 
297
 
298
- if current_length > 0:
299
- patch_lengths.append(current_length)
300
 
301
- return patch_lengths
 
 
1
  from dataclasses import dataclass, field
2
  from functools import lru_cache
3
+ from typing import Optional, Union
4
  from transformers import AutoTokenizer
5
+ from transformers.tokenization_utils import PreTrainedTokenizer
6
 
7
  # Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
8
  # Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
 
52
  return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
53
 
54
  @dataclass
55
+ class BolmoTokenizerConfig:
56
  vocab_size: int
57
  bos_token_id: int
58
  pad_token_id: int
 
64
 
65
 
66
  @classmethod
67
+ def bolmo(cls) -> "BolmoTokenizerConfig":
68
  special_tokens = [
69
  "<pad>",
70
  "<bos>",
 
84
  )
85
 
86
  def build(self):
87
+ return BolmoTokenizer(tokenizer_config=self)
88
 
89
 
90
+ class BolmoTokenizer(PreTrainedTokenizer):
91
  TOKEN_ID_KEY = -1
92
 
93
+ def __init__(self, **kwargs):
94
+ tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())
95
+
96
  self.config = tokenizer_config
97
  self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
98
  if self.config.special_tokens_first:
 
127
  if byte not in current_dict:
128
  current_dict[byte] = {}
129
  current_dict = current_dict[byte]
130
+ current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id
131
+
132
+ self.add_bos_token = True
133
+ self.add_eos_token = False
134
+ self.padding_side = "left" # for generate
135
+
136
+ super().__init__(
137
+ bos_token=self.config.special_tokens[self.config.bos_token_id],
138
+ eos_token=self.config.special_tokens[self.config.eos_token_id],
139
+ pad_token=self.config.special_tokens[self.config.pad_token_id],
140
+ extra_ids=0,
141
+ )
142
 
143
  @property
144
  def bos_token_id(self):
 
156
  def bpe_token_end_id(self):
157
  return self.config.bpe_token_end_id
158
 
159
@property
def vocab_size(self):
    # Total size of the byte-level vocabulary, as declared in the tokenizer config
    # (covers special tokens, raw bytes, and boundary-fused variants).
    return self.config.vocab_size
162
+
163
def _convert_id_to_token(self, index):
    """Map a vocabulary id to its token string.

    Layout implied by the branches below (TODO confirm against __init__, which is
    not visible here):
      [.., self.offset)                        -> special tokens
      [offset, offset + 256)                   -> plain bytes
      [offset + 256, 2 * offset + 256)         -> special tokens fused with a boundary ("b" suffix)
      [2 * offset + 256, ..)                   -> bytes fused with a boundary ("b" suffix)
    """
    if index < self.offset:
        # Plain special token; special_tokens_offset shifts into the config list.
        return self.config.special_tokens[index - self.special_tokens_offset]

    if index >= self.offset + 256 and index < self.offset * 2 + 256:
        # special token with fused boundary
        return self.config.special_tokens[index - self.offset - 256] + "b"

    # NOTE(review): fused bytes are decoded from [2*offset + 256, ..) here, but
    # _convert_token_to_id encodes them as offset + 256 + byte — these are not
    # inverses unless offset == 0. Confirm which layout the rest of the
    # tokenizer (byte_sequences / decode path) actually uses.
    return _BYTES_TO_CHARS[index - self.offset - 256 - self.offset] + "b" if index >= self.offset + 256 else _BYTES_TO_CHARS[index - self.offset]
172
+
173
def _convert_token_to_id(self, token):
    """Map a token string back to a vocabulary id.

    Intended to be the inverse of _convert_id_to_token; a "b" suffix marks a
    token fused with a patch boundary.
    """
    if token in self.config.special_tokens:
        # NOTE(review): _convert_id_to_token uses special_tokens_offset when
        # decoding specials; this returns the bare list index — consistent only
        # if special_tokens_offset == 0. Verify.
        return self.config.special_tokens.index(token)

    if token in [x + "b" for x in self.config.special_tokens]:
        # special token with fused boundary
        # NOTE(review): _convert_id_to_token places fused specials at
        # offset + 256 + idx, not 256 + idx — the round trip
        # convert_tokens_to_ids(convert_ids_to_tokens(i)) breaks here unless
        # offset == 0. Confirm the intended vocab layout before relying on it.
        return 256 + self.config.special_tokens.index(token[:-1])

    if len(token) > 1 and token[-1] == "b":
        # Byte fused with a boundary; token[0] is the byte's display character.
        # NOTE(review): _convert_id_to_token decodes fused bytes from
        # 2*offset + 256 + byte — mismatched with offset + 256 + byte here.
        return self.offset + 256 + _CHARS_TO_BYTES[token[0]]
    else:
        # Plain byte token.
        return self.offset + _CHARS_TO_BYTES[token]
185
+
186
def get_vocab(self):
    """Return the full token-string -> id mapping for this tokenizer."""
    mapping = {}
    for token_id in range(self.vocab_size):
        mapping[self.convert_ids_to_tokens(token_id)] = token_id
    return mapping
189
+
190
  def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
191
  # search in the byte tree for the longest matching token at every byte position
192
  expanded_ids = []
 
210
 
211
  try:
212
  current_dict = current_dict[byte]
213
+ if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
214
+ current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
215
  except KeyError:
216
  assert current_expansion is not None
217
  break
 
220
 
221
  return expanded_ids
222
 
223
# Mirrors LlamaTokenizer.build_inputs_with_special_tokens in transformers.
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    """Wrap one (or two) id sequences with the configured BOS/EOS special ids."""
    prefix = [self.bos_token_id] if self.add_bos_token else []
    suffix = [self.eos_token_id] if self.add_eos_token else []

    sequences = [token_ids_0] if token_ids_1 is None else [token_ids_0, token_ids_1]

    wrapped = []
    for sequence in sequences:
        wrapped.extend(prefix)
        wrapped.extend(sequence)
        wrapped.extend(suffix)
    return wrapped
234
+
235
# Mirrors LlamaTokenizer.get_special_tokens_mask in transformers.
def get_special_tokens_mask(
    self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
) -> list[int]:
    """
    Build a mask marking special-token positions (1) vs. sequence tokens (0),
    assuming the ids do not yet contain special tokens.

    Args:
        token_ids_0 (`List[int]`): First sequence of ids.
        token_ids_1 (`List[int]`, *optional*): Second sequence for pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            If `True`, defer to the base class, which inspects the ids directly.

    Returns:
        `List[int]`: 1 for each special-token position, 0 elsewhere.
    """
    if already_has_special_tokens:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )

    bos_mask = [1] if self.add_bos_token else []
    eos_mask = [1] if self.add_eos_token else []

    mask = bos_mask + [0] * len(token_ids_0) + eos_mask
    if token_ids_1 is not None:
        mask = mask + bos_mask + [0] * len(token_ids_1) + eos_mask
    return mask
270
+
271
# Mirrors LlamaTokenizer.create_token_type_ids_from_sequences in transformers.
def create_token_type_ids_from_sequences(
    self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
) -> list[int]:
    """
    Produce token-type ids: 0 for the first sequence (including its BOS/EOS
    slots), 1 for the optional second sequence (including its BOS/EOS slots).

    Args:
        token_ids_0 (`List[int]`): First sequence of ids.
        token_ids_1 (`List[int]`, *optional*): Second sequence for pairs.

    Returns:
        `List[int]`: Token type ids aligned with the special-token layout.
    """
    # Only the count of special tokens matters here, not their values.
    n_special = (1 if self.add_bos_token else 0) + (1 if self.add_eos_token else 0)

    type_ids = [0] * (len(token_ids_0) + n_special)
    if token_ids_1 is not None:
        type_ids = type_ids + [1] * (len(token_ids_1) + n_special)
    return type_ids
300
+
301
+ def _tokenize(self, text: str, **kwargs) -> list[str]:
302
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
303
+ tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
304
+ return tokens
305
+
306
+ def _patch_ids_to_byte_ids(self, input_ids: list[int]):
307
  return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]
308
 
309
+ def _bolmo_encode(self, string: str, add_special_tokens=False):
310
  input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
311
+ return self._patch_ids_to_byte_ids(input_ids)
312
 
313
+ def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
314
+ return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")
315
 
316
+ def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
317
  tokens_without_boundary = []
318
  for token in tokens:
319
  if token >= (self.offset + 256):
 
321
 
322
  tokens_without_boundary.append(token)
323
 
324
+ utf8_bytes = []
325
+
326
+ for token in tokens_without_boundary:
327
+ if token < self.offset:
328
+ if skip_special_tokens:
329
+ continue
330
+ else:
331
+ utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
332
+ else:
333
+ utf8_bytes.append(min(token - self.offset, 255))
334
+
335
  return bytes(utf8_bytes)
336
 
337
  def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
 
347
  if skip_last and idx == len(original_input_ids) - 1:
348
  break
349
 
350
+ token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])
351
 
352
  if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
353
  # skip padding tokens
 
358
 
359
  return byte_tokens, patch_lengths
360
 
361
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Join a list of token strings back into a single decoded string."""
    token_ids = self.convert_tokens_to_ids(tokens)
    return self._bolmo_decode(token_ids, skip_special_tokens=False)  # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ def _decode(
365
+ self,
366
+ token_ids: Union[int, list[int]],
367
+ skip_special_tokens: bool = False,
368
+ clean_up_tokenization_spaces: Optional[bool] = None,
369
+ spaces_between_special_tokens: bool = True,
370
+ **kwargs,
371
+ ) -> str:
372
+ if isinstance(token_ids, int):
373
+ token_ids = [token_ids]
374
 
375
+ return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)
 
376
 
377
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
    """No-op: the byte-level vocabulary is derived from the config, so no
    vocabulary files are written."""
    return tuple()  # type: ignore
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ }
19
+ },
20
+ "auto_map": {
21
+ "AutoTokenizer": [
22
+ "tokenization_bolmo.BolmoTokenizer",
23
+ null
24
+ ]
25
+ },
26
+ "bos_token": "<bos>",
27
+ "clean_up_tokenization_spaces": false,
28
+ "eos_token": "<bos>",
29
+ "extra_ids": 0,
30
+ "extra_special_tokens": {},
31
+ "model_max_length": 1000000000000000019884624838656,
32
+ "pad_token": "<pad>",
33
+ "tokenizer_class": "BolmoTokenizer"
34
+ }