WinstonDeng commited on
Commit
7805a18
·
verified ·
1 Parent(s): c171c6a

add step-3.7-flash bf16 model libs

Browse files
Files changed (1) hide show
  1. modeling_step3p7.py +1395 -0
modeling_step3p7.py CHANGED
@@ -0,0 +1,1395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
19
+
20
+ from PIL import Image
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.masking_utils import (
29
+ create_causal_mask,
30
+ create_sliding_window_causal_mask,
31
+ )
32
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
39
+ from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
40
+ from .vision_encoder import StepRoboticsVisionEncoder
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _MASK_INPUT_EMBEDS_ARG = (
45
+ "inputs_embeds"
46
+ if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
47
+ else "input_embeds"
48
+ )
49
+
50
+ __all__ = [
51
+ "Step3p7Model",
52
+ ]
53
+
54
+
55
+ class StepVLImagePixelInputs(TypedDict):
56
+ type: Literal["pixel_values"]
57
+ pixel_values: torch.Tensor
58
+ patch_pixel_values: Optional[torch.Tensor]
59
+ num_patches: list[int]
60
+
61
+
62
+ class StepVLImageEmbeddingInputs(TypedDict):
63
+ type: Literal["image_embeds"]
64
+ image_embeds: torch.Tensor
65
+
66
+
67
+ StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
68
+
69
+
70
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
71
+ """
72
+ Recursively flattens and concatenates NestedTensors on all but the last
73
+ dimension.
74
+ """
75
+
76
+ if isinstance(embeddings, torch.Tensor):
77
+ # Flatten all but the last dimension.
78
+ return embeddings.flatten(0, -2)
79
+
80
+ return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
81
+
82
+ def _embedding_count_expression(embeddings) -> str:
83
+ """
84
+ Constructs a debugging representation of the number of embeddings in the
85
+ NestedTensors.
86
+ """
87
+
88
+ if isinstance(embeddings, torch.Tensor):
89
+ return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
90
+
91
+ return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
92
+
93
+
94
+ def _merge_multimodal_embeddings(
95
+ inputs_embeds: torch.Tensor,
96
+ is_multimodal: torch.Tensor,
97
+ multimodal_embeddings,
98
+ ) -> torch.Tensor:
99
+ """
100
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
101
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
102
+ ``input_ids``.
103
+ Note:
104
+ This updates ``inputs_embeds`` in place.
105
+ """
106
+ num_expected_tokens = is_multimodal.sum().item()
107
+ assert isinstance(num_expected_tokens, int)
108
+
109
+ flattened = _flatten_embeddings(multimodal_embeddings)
110
+ if flattened.shape[0] != num_expected_tokens:
111
+ expr = _embedding_count_expression(multimodal_embeddings)
112
+ raise ValueError(
113
+ f"Attempted to assign {expr} = {flattened.shape[0]} "
114
+ f"multimodal tokens to {num_expected_tokens} placeholders"
115
+ )
116
+
117
+ is_multimodal = is_multimodal.to(inputs_embeds.device)
118
+ flattened = flattened.to(inputs_embeds.device)
119
+ inputs_embeds[is_multimodal] = flattened
120
+ return inputs_embeds
121
+
122
+ def merge_multimodal_embeddings(
123
+ input_ids: torch.Tensor,
124
+ inputs_embeds: torch.Tensor,
125
+ multimodal_embeddings,
126
+ placeholder_token_id: Union[int, list[int]],
127
+ ) -> torch.Tensor:
128
+ """
129
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
130
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
131
+ ``input_ids``.
132
+
133
+ ``placeholder_token_id`` can be a list of token ids (e.g, token ids
134
+ of img_start, img_break, and img_end tokens) when needed: This means
135
+ the order of these tokens in the ``input_ids`` MUST MATCH the order of
136
+ their embeddings in ``multimodal_embeddings`` since we need to
137
+ slice-merge instead of individually scattering.
138
+ For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
139
+ - T is text token
140
+ - S is image start token
141
+ - I is image embedding token
142
+ - B is image break token
143
+ - E is image end token.
144
+
145
+ Then the image embeddings (that correspond to I's) from vision encoder
146
+ must be padded with embeddings of S, B, and E in the same order of
147
+ input_ids for a correct embedding merge.
148
+ Note:
149
+ This updates ``inputs_embeds`` in place.
150
+ """
151
+ if isinstance(placeholder_token_id, list):
152
+ placeholder_token_id = torch.tensor(
153
+ placeholder_token_id, device=input_ids.device
154
+ )
155
+ return _merge_multimodal_embeddings(
156
+ inputs_embeds,
157
+ torch.isin(input_ids, placeholder_token_id),
158
+ multimodal_embeddings,
159
+ )
160
+
161
+ return _merge_multimodal_embeddings(
162
+ inputs_embeds,
163
+ (input_ids == placeholder_token_id),
164
+ multimodal_embeddings,
165
+ )
166
+
167
+
168
+ class Step3p7PreTrainedModel(PreTrainedModel):
169
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
170
+ # can load the config instead of failing with a NoneType error.
171
+ config_class = Step3p7Config
172
+ supports_gradient_checkpointing = True
173
+ _skip_keys_device_placement = ["past_key_values"]
174
+ _keys_to_ignore_on_load_unexpected = [
175
+ r"model\.layers\.45\.*",
176
+ r"model\.layers\.46\.*",
177
+ r"model\.layers\.47\.*",
178
+ ]
179
+ _supports_flash_attn = False
180
+ _supports_sdpa = True
181
+ _supports_flex_attn = True
182
+ _supports_static_cache = True
183
+ _supports_attention_backend = True
184
+
185
+ @classmethod
186
+ def from_pretrained(
187
+ cls, pretrained_model_name_or_path, *model_args, **kwargs
188
+ ):
189
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
190
+ if key_mapping is not None and kwargs.get("key_mapping") is None:
191
+ # Transformers only applies checkpoint renaming when key_mapping is
192
+ # passed explicitly; inheriting the class attribute alone is not enough.
193
+ kwargs["key_mapping"] = copy.deepcopy(key_mapping)
194
+ return super().from_pretrained(
195
+ pretrained_model_name_or_path, *model_args, **kwargs
196
+ )
197
+
198
+
199
+ class Step3p7RotaryEmbedding(nn.Module):
200
+ def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
+ super().__init__()
202
+ # BC: "rope_type" was originally "type"
203
+ self.layer_idx = layer_idx
204
+ self.original_rope_parameters = None
205
+ if config.rope_parameters is not None:
206
+ self.original_rope_parameters = config.rope_parameters
207
+ config.rope_parameters = dict(config.rope_parameters)
208
+ self.rope_type = config.rope_parameters.get(
209
+ "rope_type", config.rope_parameters.get("type")
210
+ )
211
+ else:
212
+ self.rope_type = "default"
213
+ self.max_seq_len_cached = config.max_position_embeddings
214
+ self.original_max_seq_len = config.max_position_embeddings
215
+
216
+ partial_rotary_factors = getattr(
217
+ config, "partial_rotary_factors", None
218
+ )
219
+ if partial_rotary_factors is not None:
220
+ config.partial_rotary_factor = partial_rotary_factors[self.layer_idx]
221
+ else:
222
+ config.partial_rotary_factor = 1.0
223
+
224
+ self.rope_theta = config.rope_theta
225
+ if isinstance(config.rope_theta, list):
226
+ self.rope_theta = config.rope_theta.copy()
227
+ config.rope_theta = self.rope_theta[self.layer_idx]
228
+
229
+ self.config = copy.copy(config)
230
+ if config.rope_parameters is not None:
231
+ self.config.rope_parameters = dict(config.rope_parameters)
232
+ self.rope_init_fn = self.compute_default_rope_parameters
233
+ if self.rope_type != "default":
234
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
235
+ inv_freq, self.attention_scaling = self.rope_init_fn(
236
+ self.config, device
237
+ )
238
+
239
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
240
+ self.original_inv_freq = self.inv_freq
241
+ config.rope_theta = self.rope_theta
242
+ config.rope_parameters = self.original_rope_parameters
243
+
244
+ @torch.no_grad()
245
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
246
+ def forward(self, x, position_ids):
247
+ inv_freq_expanded = (
248
+ self.inv_freq[None, :, None]
249
+ .float()
250
+ .expand(position_ids.shape[0], -1, 1)
251
+ .to(x.device)
252
+ )
253
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
254
+
255
+ device_type = (
256
+ x.device.type
257
+ if isinstance(x.device.type, str) and x.device.type != "mps"
258
+ else "cpu"
259
+ )
260
+ with torch.autocast(
261
+ device_type=device_type, enabled=False
262
+ ): # Force float32
263
+ freqs = (
264
+ inv_freq_expanded.float() @ position_ids_expanded.float()
265
+ ).transpose(1, 2)
266
+ emb = torch.cat((freqs, freqs), dim=-1)
267
+ cos = emb.cos() * self.attention_scaling
268
+ sin = emb.sin() * self.attention_scaling
269
+
270
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
271
+
272
+ @staticmethod
273
+ def compute_default_rope_parameters(
274
+ config: Step3p7TextConfig | None = None,
275
+ device: Optional["torch.device"] = None,
276
+ ) -> tuple["torch.Tensor", float]:
277
+ """
278
+ Computes the inverse frequencies according to the original RoPE implementation
279
+ Args:
280
+ config ([`~transformers.PreTrainedConfig`]):
281
+ The model configuration.
282
+ device (`torch.device`):
283
+ The device to use for initialization of the inverse frequencies.
284
+ seq_len (`int`, *optional*):
285
+ The current sequence length. Unused for this type of RoPE.
286
+ Returns:
287
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
288
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
289
+ """
290
+ base = config.rope_theta
291
+ dim = (
292
+ getattr(config, "head_dim", None)
293
+ or config.hidden_size // config.num_attention_heads
294
+ )
295
+
296
+ attention_factor = 1.0 # Unused in this type of RoPE
297
+
298
+ # Compute the inverse frequencies
299
+ inv_freq = 1.0 / (
300
+ base
301
+ ** (
302
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
303
+ device=device, dtype=torch.float
304
+ )
305
+ / dim
306
+ )
307
+ )
308
+ return inv_freq, attention_factor
309
+
310
+ def rotate_half(x):
311
+ """Rotates half the hidden dims of the input."""
312
+ x1 = x[..., :x.shape[-1] // 2]
313
+ x2 = x[..., x.shape[-1] // 2:]
314
+ return torch.cat((-x2, x1), dim=-1)
315
+
316
+
317
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
318
+ """Applies Rotary Position Embedding to the query and key tensors.
319
+
320
+ Args:
321
+ q (`torch.Tensor`): The query tensor.
322
+ k (`torch.Tensor`): The key tensor.
323
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
324
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
325
+ position_ids (`torch.Tensor`, *optional*):
326
+ Deprecated and unused.
327
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
328
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
329
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
330
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
331
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
332
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
333
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
334
+ Returns:
335
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
336
+ """
337
+ rotary_dim = cos.shape[-1]
338
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
339
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
340
+
341
+ # Apply rotary embeddings on the first half or full tensor
342
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
343
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
344
+
345
+ # Concatenate back to full shape
346
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
347
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
348
+ return q_embed, k_embed
349
+
350
+
351
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
352
+ """
353
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
354
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
355
+ """
356
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
357
+ if n_rep == 1:
358
+ return hidden_states
359
+ hidden_states = hidden_states[:, :, None, :, :].expand(
360
+ batch, num_key_value_heads, n_rep, slen, head_dim
361
+ )
362
+ return hidden_states.reshape(
363
+ batch, num_key_value_heads * n_rep, slen, head_dim
364
+ )
365
+
366
+
367
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
368
+ # Llama4 does not cast attention weights to fp32 here.
369
+ def eager_attention_forward(
370
+ module: nn.Module,
371
+ query: torch.Tensor,
372
+ key: torch.Tensor,
373
+ value: torch.Tensor,
374
+ attention_mask: Optional[torch.Tensor],
375
+ scaling: float,
376
+ dropout: float = 0.0,
377
+ **kwargs,
378
+ ):
379
+ key_states = repeat_kv(key, module.num_key_value_groups)
380
+ value_states = repeat_kv(value, module.num_key_value_groups)
381
+ # breakpoint()
382
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
383
+ if attention_mask is not None:
384
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
385
+ attn_weights = attn_weights + causal_mask
386
+
387
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
388
+ attn_weights = nn.functional.dropout(
389
+ attn_weights, p=dropout, training=module.training
390
+ )
391
+ attn_output = torch.matmul(attn_weights, value_states)
392
+ attn_output = attn_output.transpose(1, 2).contiguous()
393
+
394
+ return attn_output, attn_weights
395
+
396
+
397
+ @dataclass
398
+ class Step3p7CausalLMOutputWithPast(ModelOutput):
399
+ r"""
400
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
401
+ Language modeling loss (for next-token prediction).
402
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
403
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
404
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
405
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
406
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
407
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
408
+ `past_key_values` input) to speed up sequential decoding.
409
+ """
410
+
411
+ loss: Optional[torch.FloatTensor] = None
412
+ last_hidden_state: Optional[torch.FloatTensor] = None
413
+ logits: torch.FloatTensor = None
414
+ past_key_values: Optional[list[torch.FloatTensor]] = None
415
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
416
+ attentions: Optional[tuple[torch.FloatTensor]] = None
417
+
418
+
419
+ class Step3p7MLP(nn.Module):
420
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
421
+ super().__init__()
422
+ self.config = config
423
+ self.hidden_size = config.hidden_size
424
+ self.intermediate_size = (
425
+ intermediate_size
426
+ if intermediate_size is not None
427
+ else config.intermediate_size
428
+ )
429
+ self.gate_proj = nn.Linear(self.hidden_size,
430
+ self.intermediate_size,
431
+ bias=False)
432
+ self.up_proj = nn.Linear(self.hidden_size,
433
+ self.intermediate_size,
434
+ bias=False)
435
+ self.down_proj = nn.Linear(self.intermediate_size,
436
+ self.hidden_size,
437
+ bias=False)
438
+ self.act_fn = ACT2FN["silu"]
439
+ self.limit = swiglu_limit
440
+
441
+ def forward(self, x):
442
+ up = self.up_proj(x)
443
+ gate = self.act_fn(self.gate_proj(x))
444
+ if self.limit is not None:
445
+ gate = gate.clamp(min=None, max=self.limit)
446
+ up = up.clamp(min=-self.limit, max=self.limit)
447
+
448
+ return self.down_proj(gate * up)
449
+
450
+
451
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
452
+ renormalize: bool):
453
+ gating_output = gating_output.float()
454
+ gate_prob = torch.sigmoid(gating_output)
455
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
456
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
457
+ expert_topk_weight = topk_prob
458
+ if renormalize:
459
+ expert_topk_weight = expert_topk_weight / torch.sum(
460
+ expert_topk_weight, dim=-1, keepdim=True)
461
+ return expert_topk_weight, indices
462
+
463
+
464
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
465
+ renormalize: bool):
466
+ gating_output = gating_output.float()
467
+ gate_prob = torch.softmax(gating_output, dim=-1)
468
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
469
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
470
+ expert_topk_weight = topk_prob
471
+ if renormalize:
472
+ expert_topk_weight = expert_topk_weight / torch.sum(
473
+ expert_topk_weight, dim=-1, keepdim=True)
474
+ return expert_topk_weight, indices.to(torch.int32)
475
+
476
+
477
+ class MoELinear(nn.Module):
478
+
479
+ def __init__(self, num_experts, in_features, out_features):
480
+ super().__init__()
481
+ self.num_experts = num_experts
482
+ self.in_features = in_features
483
+ self.out_features = out_features
484
+ self.weight = nn.Parameter(
485
+ torch.empty(num_experts, out_features, in_features))
486
+
487
+ def forward(self, x, expert_id):
488
+ x = F.linear(x.float(), self.weight[expert_id].float())
489
+ return x
490
+
491
+
492
+ class Step3p7MoEMLP(nn.Module):
493
+
494
+ def __init__(self, config, swiglu_limit=None):
495
+ super().__init__()
496
+ self.num_experts = config.moe_num_experts
497
+ self.top_k = config.moe_top_k
498
+ self.hidden_size = config.hidden_size
499
+ self.moe_intermediate_size = config.moe_intermediate_size
500
+
501
+ self.use_moe_router_bias = config.use_moe_router_bias
502
+ if self.use_moe_router_bias:
503
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
504
+ dtype=torch.float32),
505
+ requires_grad=False)
506
+ self.custom_routing_function = self.router_bias_func
507
+ elif config.moe_router_activation == "sigmoid":
508
+ self.custom_routing_function = sigmoid_routing_function
509
+ else:
510
+ self.custom_routing_function = None
511
+ self.need_fp32_gate = config.need_fp32_gate
512
+ self.routed_scaling_factor = getattr(config,
513
+ "moe_router_scaling_factor", 1.0)
514
+
515
+ # gating
516
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
517
+
518
+ self.act_fn = ACT2FN["silu"]
519
+ self.limit = swiglu_limit
520
+
521
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
522
+ self.moe_intermediate_size)
523
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
524
+ self.moe_intermediate_size)
525
+ self.down_proj = MoELinear(self.num_experts,
526
+ self.moe_intermediate_size,
527
+ self.hidden_size)
528
+
529
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
530
+ renormalize: bool):
531
+ gate_prob = torch.sigmoid(gating_output.float())
532
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
533
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
534
+ topk_prob = torch.gather(gate_prob, 1, indices)
535
+ expert_topk_weight = topk_prob
536
+ if renormalize:
537
+ expert_topk_weight = expert_topk_weight / (
538
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
539
+ return expert_topk_weight, indices
540
+
541
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
542
+ #if self.limit is None:
543
+ up = self.up_proj(inputs, expert_id)
544
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
545
+ if self.limit is not None:
546
+ gate = gate.clamp(min=None, max=self.limit)
547
+ up = up.clamp(min=-self.limit, max=self.limit)
548
+
549
+ return self.down_proj(gate * up, expert_id)
550
+
551
+ def forward(self, hidden_states):
552
+ """ """
553
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
554
+ hidden_states = hidden_states.view(-1, hidden_dim)
555
+ if self.need_fp32_gate:
556
+ router_logits = torch.matmul(
557
+ hidden_states.to(torch.float32),
558
+ self.gate.weight.t().to(torch.float32),
559
+ )
560
+ else:
561
+ # router_logits: (batch * sequence_length, n_experts)
562
+ router_logits = self.gate(hidden_states)
563
+
564
+ if self.custom_routing_function:
565
+ routing_weights, selected_experts = self.custom_routing_function(
566
+ router_logits, self.top_k, renormalize=True)
567
+ else:
568
+ routing_weights = F.softmax(router_logits,
569
+ dim=1,
570
+ dtype=torch.float)
571
+ routing_weights, selected_experts = torch.topk(routing_weights,
572
+ self.top_k,
573
+ dim=-1)
574
+
575
+ routing_weights = routing_weights * self.routed_scaling_factor
576
+
577
+ final_hidden_states = torch.zeros(
578
+ (batch_size * sequence_length, hidden_dim),
579
+ dtype=hidden_states.dtype,
580
+ device=hidden_states.device)
581
+
582
+ # One hot encode the selected experts to create an expert mask
583
+ # this will be used to easily index which expert is going to be sollicitated
584
+ expert_mask = torch.nn.functional.one_hot(
585
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
586
+
587
+ # Loop over all available experts in the model and perform the computation on each expert
588
+ for expert_idx in range(self.num_experts):
589
+ idx, top_x = torch.where(expert_mask[expert_idx])
590
+
591
+ # Index the correct hidden states and compute the expert hidden state for
592
+ # the current expert. We need to make sure to multiply the output hidden
593
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
594
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
595
+ current_hidden_states = (
596
+ self.get_expert_output(current_state, expert_idx) *
597
+ routing_weights[top_x, idx, None])
598
+
599
+ # However `index_add_` only support torch tensors for indexing so we'll use
600
+ # the `top_x` tensor here.
601
+ final_hidden_states.index_add_(
602
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
603
+ final_hidden_states = final_hidden_states.reshape(
604
+ batch_size, sequence_length, hidden_dim)
605
+ return final_hidden_states
606
+
607
+
608
+ class Step3p7RMSNorm(nn.Module):
609
+
610
+ def __init__(
611
+ self,
612
+ hidden_size: int,
613
+ eps: float = 1e-5,
614
+ ) -> None:
615
+ super().__init__()
616
+ self.weight = nn.Parameter(torch.ones(hidden_size))
617
+ self.variance_epsilon = eps
618
+
619
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
620
+ dtype = x.dtype
621
+ x = x.float()
622
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
623
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
624
+ normed = normed * (self.weight.float() + 1)
625
+ return normed.to(dtype)
626
+ class Step3p7Attention(nn.Module):
627
+
628
+ def __init__(self, config: Step3p7TextConfig, layer_idx):
629
+ super().__init__()
630
+ self.config = config
631
+ self.layer_idx = layer_idx
632
+ self.num_attention_heads = config.num_attention_heads
633
+ self.num_key_value_heads = config.num_attention_groups
634
+
635
+ layer_types = getattr(config, "layer_types", [])
636
+ if layer_types:
637
+ enable_sliding_window = layer_types[
638
+ self.layer_idx] == "sliding_attention"
639
+ else:
640
+ enable_sliding_window = self.layer_idx % 2 == 0
641
+
642
+ yarn_only_types = getattr(config, "yarn_only_types", None)
643
+ if yarn_only_types and layer_types[
644
+ self.layer_idx] not in yarn_only_types:
645
+ config.rope_parameters = None
646
+ else:
647
+ config.rope_parameters = getattr(config, "rope_scaling", None)
648
+
649
+ self.sliding_window = config.sliding_window
650
+ if enable_sliding_window:
651
+ self.num_attention_heads = config.attention_other_setting[
652
+ "num_attention_heads"]
653
+ self.num_key_value_heads = config.attention_other_setting[
654
+ "num_attention_groups"]
655
+
656
+ if self.sliding_window is not None and enable_sliding_window:
657
+ self.sliding_window = (self.sliding_window)
658
+ else:
659
+ self.sliding_window = None
660
+ self.head_dim = getattr(config, "head_dim",
661
+ config.hidden_size // self.num_attention_heads)
662
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
663
+
664
+ self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
665
+
666
+ self.q_size = self.num_attention_heads * self.head_dim
667
+ self.kv_size = self.num_key_value_heads * self.head_dim
668
+ self.scaling = self.head_dim**-0.5
669
+
670
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
671
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
672
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
673
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
674
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
675
+ self.q_norm = Step3p7RMSNorm(self.head_dim,
676
+ eps=config.rms_norm_eps)
677
+ self.k_norm = Step3p7RMSNorm(self.head_dim,
678
+ eps=config.rms_norm_eps)
679
+
680
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
681
+ if self.use_head_wise_attn_gate:
682
+ self.g_proj = nn.Linear(config.hidden_size,
683
+ self.num_attention_heads,
684
+ bias=False)
685
+
686
+ self.use_rope = True
687
+ use_rope_layers = getattr(config, "use_rope_layers", None)
688
+ if use_rope_layers:
689
+ self.use_rope = use_rope_layers[self.layer_idx]
690
+
691
+ def forward(
692
+ self,
693
+ hidden_states: torch.Tensor,
694
+ attention_mask: Optional[torch.Tensor],
695
+ past_key_value: Optional[Cache] = None,
696
+ cache_position: Optional[torch.LongTensor] = None,
697
+ position_ids: Optional[torch.LongTensor] = None,
698
+ **kwargs: Unpack[FlashAttentionKwargs],
699
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
700
+ Optional[Tuple[torch.Tensor]]]:
701
+ input_shape = hidden_states.shape[:-1]
702
+ hidden_shape = (*input_shape, -1, self.head_dim)
703
+
704
+ query_states = self.q_norm(
705
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
706
+ key_states = self.k_norm(
707
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
708
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
709
+ 1, 2)
710
+ if self.use_head_wise_attn_gate:
711
+ gate_states = self.g_proj(hidden_states)
712
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
713
+
714
+ # cos, sin = position_embeddings
715
+ query_states, key_states = apply_rotary_pos_emb(
716
+ query_states, key_states, cos, sin)
717
+
718
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
719
+ if past_key_value is not None:
720
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
721
+ cache_kwargs = {
722
+ "sin": sin,
723
+ "cos": cos,
724
+ "cache_position": cache_position
725
+ }
726
+ key_states, value_states = past_key_value.update(
727
+ key_states, value_states, self.layer_idx, cache_kwargs)
728
+
729
+ attention_interface: Callable = eager_attention_forward
730
+ # TODO: considering FP8;
731
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
732
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
733
+ if self.config._attn_implementation != "eager":
734
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
735
+ self.config._attn_implementation]
736
+
737
+ attn_output, attn_weights = attention_interface(
738
+ self,
739
+ query_states,
740
+ key_states,
741
+ value_states,
742
+ attention_mask,
743
+ dropout=0.0 if not self.training else self.attention_dropout,
744
+ scaling=self.scaling,
745
+ sliding_window=self.sliding_window, # main diff with Llama
746
+ **kwargs,
747
+ )
748
+ attn_output = attn_output.reshape(*input_shape, -1)
749
+ if self.use_head_wise_attn_gate:
750
+ output = attn_output.view(
751
+ *attn_output.shape[:-1], self.num_attention_heads,
752
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
753
+ attn_output = output.view(*attn_output.shape)
754
+ attn_output = self.o_proj(attn_output)
755
+
756
+ return attn_output, attn_weights
757
+
758
+
759
+ class Step3p7DecoderLayer(GradientCheckpointingLayer):
760
+
761
+ def __init__(self, config, layer_idx):
762
+ super().__init__()
763
+ self.hidden_size = config.hidden_size
764
+ self.layer_idx = layer_idx
765
+ self.self_attn = Step3p7Attention(config, layer_idx)
766
+ layer_types = getattr(config, "layer_types", None) or []
767
+ if layer_types:
768
+ self.attention_type = layer_types[layer_idx]
769
+ else:
770
+ self.attention_type = (
771
+ "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
772
+ )
773
+
774
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
775
+ if moe_layers_enum is not None:
776
+ if isinstance(moe_layers_enum, str):
777
+ moe_layers_idx = [
778
+ int(i) for i in moe_layers_enum.split(',') if i.strip()
779
+ ]
780
+ else:
781
+ moe_layers_idx = [int(i) for i in moe_layers_enum]
782
+ else:
783
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
784
+ self.is_moe_layer = layer_idx in moe_layers_idx
785
+ self.use_moe = False
786
+
787
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
788
+ layer_idx] is not None and config.swiglu_limits_shared[
789
+ layer_idx] != 0:
790
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
791
+ else:
792
+ swiglu_limit_shared = None
793
+ if config.swiglu_limits and config.swiglu_limits[
794
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
795
+ swiglu_limit = config.swiglu_limits[layer_idx]
796
+ else:
797
+ swiglu_limit = None
798
+ if self.is_moe_layer:
799
+ self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
800
+ self.share_expert = Step3p7MLP(
801
+ config,
802
+ intermediate_size=config.share_expert_dim,
803
+ swiglu_limit=swiglu_limit_shared)
804
+ self.use_moe = True
805
+ else:
806
+ self.mlp = Step3p7MLP(config,
807
+ intermediate_size=config.intermediate_size,
808
+ swiglu_limit=swiglu_limit_shared)
809
+
810
+ self.input_layernorm = Step3p7RMSNorm(
811
+ config.hidden_size,
812
+ eps=config.rms_norm_eps)
813
+ self.post_attention_layernorm = Step3p7RMSNorm(
814
+ config.hidden_size,
815
+ eps=config.rms_norm_eps)
816
+
817
+ def forward(
818
+ self,
819
+ hidden_states: torch.Tensor,
820
+ attention_mask: Optional[torch.Tensor] = None,
821
+ position_ids: Optional[torch.LongTensor] = None,
822
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
823
+ cache_position: Optional[torch.LongTensor] = None,
824
+ **kwargs: Unpack[FlashAttentionKwargs],
825
+ ) -> torch.FloatTensor:
826
+ residual = hidden_states
827
+ hidden_states = self.input_layernorm(hidden_states)
828
+ hidden_states, _ = self.self_attn(
829
+ hidden_states=hidden_states,
830
+ attention_mask=attention_mask,
831
+ position_ids=position_ids,
832
+ past_key_value=past_key_value,
833
+ cache_position=cache_position,
834
+ **kwargs,
835
+ )
836
+ hidden_states = residual + hidden_states
837
+
838
+ # Fully Connected
839
+ residual = hidden_states
840
+ hidden_states = self.post_attention_layernorm(hidden_states)
841
+ if self.use_moe:
842
+ share_output = self.share_expert(hidden_states)
843
+ moe_output = self.moe(hidden_states)
844
+ ffn_output = moe_output + share_output
845
+ else:
846
+ ffn_output = self.mlp(hidden_states)
847
+ if isinstance(ffn_output, tuple):
848
+ hidden_states, _ = ffn_output
849
+ else:
850
+ hidden_states = ffn_output
851
+
852
+ hidden_states = residual + hidden_states
853
+ return hidden_states
854
+
855
+
856
+ class Step3p7TextPreTrainedModel(PreTrainedModel):
857
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
858
+ # can load the config instead of failing with a NoneType error.
859
+ config_class = Step3p7TextConfig
860
+ supports_gradient_checkpointing = True
861
+ _skip_keys_device_placement = ["past_key_values"]
862
+ _keys_to_ignore_on_load_unexpected = [
863
+ r"model\.layers\.45\.*",
864
+ r"model\.layers\.46\.*",
865
+ r"model\.layers\.47\.*",
866
+ ]
867
+ _supports_flash_attn = False
868
+ _supports_sdpa = True
869
+ _supports_flex_attn = True
870
+ _supports_static_cache = True
871
+ _supports_attention_backend = True
872
+
873
+
874
+ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
875
+ _no_split_modules = ["Step3p7DecoderLayer"]
876
+ base_model_prefix = "model"
877
+ _tied_weights_keys = ["lm_head.weight"]
878
+ config: Step3p7TextConfig
879
+
880
+ def __init__(self, config: Step3p7TextConfig):
881
+ super().__init__(config)
882
+ self.padding_idx = config.pad_token_id
883
+ self.vocab_size = config.vocab_size
884
+
885
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
886
+ self.padding_idx)
887
+ self.layers = nn.ModuleList([
888
+ Step3p7DecoderLayer(config, layer_idx)
889
+ for layer_idx in range(config.num_hidden_layers)
890
+ ])
891
+ self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
892
+ self.gradient_checkpointing = False
893
+ layer_types = self.config.layer_types or []
894
+ self.has_sliding_layers = (not layer_types or
895
+ "sliding_attention" in layer_types)
896
+
897
+ # Initialize weights and apply final processing
898
+ self.post_init()
899
+
900
+
901
+ def get_input_embeddings(self, input_ids):
902
+ return self.embed_tokens(input_ids)
903
+
904
+ @can_return_tuple
905
+ def forward(
906
+ self,
907
+ input_ids: torch.LongTensor = None,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ position_ids: Optional[torch.LongTensor] = None,
910
+ past_key_values: Optional[Cache] = None,
911
+ inputs_embeds: Optional[torch.FloatTensor] = None,
912
+ use_cache: Optional[bool] = None,
913
+ output_attentions: Optional[bool] = None,
914
+ output_hidden_states: Optional[bool] = None,
915
+ return_dict: Optional[bool] = None,
916
+ cache_position: Optional[torch.LongTensor] = None,
917
+ **kwargs: Unpack[TransformersKwargs],
918
+ ) -> Union[tuple, BaseModelOutputWithPast]:
919
+ output_attentions = (
920
+ output_attentions
921
+ if output_attentions is not None
922
+ else self.config.output_attentions
923
+ )
924
+ output_hidden_states = (
925
+ output_hidden_states
926
+ if output_hidden_states is not None
927
+ else self.config.output_hidden_states
928
+ )
929
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
930
+ return_dict = (
931
+ return_dict
932
+ if return_dict is not None
933
+ else getattr(self.config, "return_dict", True)
934
+ )
935
+ if (input_ids is None) ^ (inputs_embeds is not None):
936
+ raise ValueError(
937
+ "You must specify exactly one of input_ids or inputs_embeds")
938
+
939
+ if self.gradient_checkpointing and self.training and use_cache:
940
+ logger.warning_once(
941
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
942
+ )
943
+ use_cache = False
944
+
945
+ if inputs_embeds is None:
946
+ inputs_embeds = self.embed_tokens(
947
+ input_ids.to(self.embed_tokens.weight.device))
948
+
949
+ if use_cache and past_key_values is None:
950
+ past_key_values = DynamicCache()
951
+
952
+ if cache_position is None:
953
+ past_seen_tokens = past_key_values.get_seq_length(
954
+ ) if past_key_values is not None else 0
955
+ cache_position = torch.arange(past_seen_tokens,
956
+ past_seen_tokens +
957
+ inputs_embeds.shape[1],
958
+ device=inputs_embeds.device)
959
+
960
+ if position_ids is None:
961
+ position_ids = cache_position.unsqueeze(0)
962
+
963
+ hidden_states = inputs_embeds
964
+
965
+ # It may already have been prepared by e.g. `generate`
966
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
967
+ # Prepare mask arguments
968
+ mask_kwargs = {
969
+ "config": self.config,
970
+ "attention_mask": attention_mask,
971
+ "cache_position": cache_position,
972
+ "past_key_values": past_key_values,
973
+ "position_ids": position_ids,
974
+ }
975
+ mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
976
+ # Create the masks
977
+ causal_mask_mapping = {
978
+ "full_attention": create_causal_mask(**mask_kwargs),
979
+ }
980
+
981
+ # The sliding window alternating layers are not always activated depending on the config
982
+ if self.has_sliding_layers:
983
+ causal_mask_mapping[
984
+ "sliding_attention"] = create_sliding_window_causal_mask(
985
+ **mask_kwargs)
986
+
987
+ # # create position embeddings to be shared across the decoder layers
988
+ # decoder layers
989
+ all_hidden_states = () if output_hidden_states else None
990
+ all_self_attns = () if output_attentions else None
991
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
992
+ if output_hidden_states:
993
+ all_hidden_states += (hidden_states, )
994
+
995
+ layer_outputs = decoder_layer(
996
+ hidden_states,
997
+ attention_mask=causal_mask_mapping[
998
+ decoder_layer.attention_type],
999
+ position_ids=position_ids,
1000
+ past_key_value=past_key_values,
1001
+ output_attentions=output_attentions,
1002
+ use_cache=use_cache,
1003
+ cache_position=cache_position,
1004
+ **kwargs,
1005
+ )
1006
+
1007
+ hidden_states = layer_outputs
1008
+
1009
+ hidden_states = self.norm(hidden_states)
1010
+
1011
+ return BaseModelOutputWithPast(
1012
+ last_hidden_state=hidden_states,
1013
+ past_key_values=past_key_values if use_cache else None,
1014
+ hidden_states=all_hidden_states,
1015
+ attentions=all_self_attns,
1016
+ )
1017
+
1018
+
1019
+ class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
1020
+ config: Step3p7Config
1021
+ _tied_weights_keys = ["lm_head.weight"]
1022
+ base_model_prefix = ""
1023
+
1024
+ def __init__(self, config: Step3p7Config):
1025
+ super().__init__(config)
1026
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
1027
+ self.language_model = Step3p7TextModel(config.text_config)
1028
+ self.vocab_size = config.text_config.vocab_size
1029
+ self.vit_large_projector = nn.Linear(
1030
+ config.vision_config.width * 4,
1031
+ config.text_config.hidden_size,
1032
+ bias=config.projector_bias)
1033
+ self.image_placeholder_token_id = config.image_token_id
1034
+
1035
+ # Initialize weights and apply final processing
1036
+ self.post_init()
1037
+
1038
+ def get_input_embeddings(
1039
+ self,
1040
+ input_ids: torch.Tensor,
1041
+ multimodal_embeddings = None,
1042
+ ) -> torch.Tensor:
1043
+ # breakpoint()
1044
+ input_ids = input_ids.squeeze(0)
1045
+ if multimodal_embeddings is None:
1046
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
1047
+ else:
1048
+ is_text = input_ids != self.config.image_token_id
1049
+ text_ids = input_ids[is_text]
1050
+ text_embeds = self.language_model.get_input_embeddings(text_ids)
1051
+
1052
+ inputs_embeds = torch.empty(input_ids.shape[0],
1053
+ text_embeds.shape[-1],
1054
+ dtype=text_embeds.dtype,
1055
+ device=text_embeds.device)
1056
+ inputs_embeds[is_text] = text_embeds
1057
+ inputs_embeds = merge_multimodal_embeddings(
1058
+ input_ids, inputs_embeds, multimodal_embeddings,
1059
+ self.config.image_token_id)
1060
+ inputs_embeds = inputs_embeds.unsqueeze(0)
1061
+ return inputs_embeds
1062
+
1063
+
1064
+ def set_input_embeddings(self, value):
1065
+ return self.language_model.set_input_embeddings(value)
1066
+
1067
+ def set_decoder(self, decoder):
1068
+ self.language_model = decoder
1069
+
1070
+ def get_decoder(self):
1071
+ return self.language_model
1072
+
1073
+ def _parse_and_validate_image_input(
1074
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
1075
+ pixel_values = kwargs.pop("pixel_values", None)
1076
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1077
+ num_patches = kwargs.pop("num_patches", None)
1078
+ image_embeds = kwargs.pop("image_embeds", None)
1079
+
1080
+ if pixel_values is None and image_embeds is None:
1081
+ return None
1082
+
1083
+ if pixel_values is not None:
1084
+ # pixel_values = flatten_bn(pixel_values, concat=True)
1085
+ if pixel_values.dim() >= 3:
1086
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1087
+ if patch_pixel_values is not None:
1088
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
1089
+ # concat=True)
1090
+ patch_pixel_values = patch_pixel_values.view(
1091
+ -1, *patch_pixel_values.shape[-3:])
1092
+ # Handle empty patch_pixel_values by setting to None
1093
+ if patch_pixel_values.shape[0] == 0:
1094
+ patch_pixel_values = None
1095
+
1096
+ return StepVLImagePixelInputs(
1097
+ type="pixel_values",
1098
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
1099
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1100
+ self.device) if patch_pixel_values is not None else None,
1101
+ num_patches=num_patches,
1102
+ )
1103
+
1104
+ if image_embeds is not None:
1105
+ if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
1106
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1107
+ else:
1108
+ raise ValueError(
1109
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
1110
+
1111
+ return StepVLImageEmbeddingInputs(
1112
+ type="image_embeds",
1113
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
1114
+ )
1115
+ return None
1116
+
1117
+ def _process_image_features(self,
1118
+ image_features: torch.Tensor) -> torch.Tensor:
1119
+ B, P = image_features.shape[:2]
1120
+ HW = int(P ** 0.5)
1121
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1122
+ image_features = self.vision_model.vit_downsampler1(image_features)
1123
+ image_features = self.vision_model.vit_downsampler2(image_features)
1124
+
1125
+ B, C, HW, HW = image_features.shape
1126
+ image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1127
+ image_features = self.vit_large_projector(image_features)
1128
+ return image_features
1129
+
1130
+ def _get_vision_model_output(self,
1131
+ input_tensor: torch.Tensor) -> torch.Tensor:
1132
+ return self.vision_model(input_tensor)
1133
+
1134
+ def _process_image_input(
1135
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1136
+
1137
+ if image_input["type"] == "image_embeds":
1138
+ image_features = image_input["image_embeds"]
1139
+ else:
1140
+ image_features = self._get_vision_model_output(
1141
+ image_input["pixel_values"])
1142
+ patch_image_features = self._get_vision_model_output(
1143
+ image_input["patch_pixel_values"]
1144
+ ) if image_input["patch_pixel_values"] is not None else None
1145
+ num_patches = image_input["num_patches"]
1146
+
1147
+ image_features = self._process_image_features(image_features)
1148
+ patch_image_features = self._process_image_features(
1149
+ patch_image_features) if patch_image_features is not None else None
1150
+
1151
+ merged_image_features = []
1152
+ cur_patch_idx = 0
1153
+ for i, num_patch in enumerate(num_patches):
1154
+ cur_feature = []
1155
+ if num_patch > 0:
1156
+ patch_slice = patch_image_features[
1157
+ cur_patch_idx:cur_patch_idx + num_patch]
1158
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1159
+ cur_feature.append(image_features[i].view(
1160
+ -1, image_features.shape[-1]))
1161
+ cur_patch_idx += num_patch
1162
+ merged_image_features.append(
1163
+ torch.cat(cur_feature) if len(cur_feature) >
1164
+ 1 else cur_feature[0])
1165
+
1166
+ return merged_image_features
1167
+
1168
+ def get_multimodal_embeddings(self, **kwargs):
1169
+ # breakpoint()
1170
+ image_input = self._parse_and_validate_image_input(**kwargs)
1171
+ if image_input is None:
1172
+ return None
1173
+ vision_embeddings = self._process_image_input(image_input)
1174
+ return vision_embeddings
1175
+
1176
+ @can_return_tuple
1177
+ def forward(
1178
+ self,
1179
+ input_ids: torch.LongTensor = None,
1180
+ attention_mask: Optional[torch.Tensor] = None,
1181
+ position_ids: Optional[torch.LongTensor] = None,
1182
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1183
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1184
+ labels: Optional[torch.LongTensor] = None,
1185
+ use_cache: Optional[bool] = None,
1186
+ output_attentions: Optional[bool] = None,
1187
+ output_hidden_states: Optional[bool] = None,
1188
+ return_dict: Optional[bool] = None,
1189
+ cache_position: Optional[torch.LongTensor] = None,
1190
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1191
+ images: Optional[list[Image.Image]] = None,
1192
+ **kwargs: Unpack[TransformersKwargs],
1193
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1194
+ r"""
1195
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1196
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1197
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1198
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1199
+ Example:
1200
+ ```python
1201
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1202
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1203
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1204
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1205
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1206
+ >>> # Generate
1207
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1208
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1209
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1210
+ ```"""
1211
+ output_attentions = (
1212
+ output_attentions
1213
+ if output_attentions is not None
1214
+ else self.config.output_attentions
1215
+ )
1216
+ output_hidden_states = (
1217
+ output_hidden_states
1218
+ if output_hidden_states is not None
1219
+ else self.config.output_hidden_states
1220
+ )
1221
+ return_dict = (
1222
+ return_dict if return_dict is not None else self.config.use_return_dict
1223
+ )
1224
+
1225
+ if inputs_embeds is None:
1226
+ input_ids = input_ids
1227
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1228
+ inputs_embeds = self.get_input_embeddings(input_ids,
1229
+ vision_embeddings)
1230
+ input_ids = None
1231
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1232
+ outputs = self.language_model(
1233
+ input_ids=None,
1234
+ position_ids=position_ids,
1235
+ attention_mask=attention_mask,
1236
+ past_key_values=past_key_values,
1237
+ inputs_embeds=inputs_embeds,
1238
+ use_cache=use_cache,
1239
+ output_attentions=output_attentions,
1240
+ output_hidden_states=output_hidden_states,
1241
+ return_dict=True,
1242
+ cache_position=cache_position,
1243
+ **kwargs,
1244
+ )
1245
+
1246
+ output = Step3p7CausalLMOutputWithPast(
1247
+ last_hidden_state=outputs.last_hidden_state,
1248
+ past_key_values=outputs.past_key_values,
1249
+ attentions=outputs.attentions,
1250
+ )
1251
+ return output if return_dict else output.to_tuple()
1252
+
1253
+
1254
+ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1255
+ _checkpoint_conversion_mapping = {
1256
+ "^vision_model": "model.vision_model",
1257
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
1258
+ "^vit_large_projector": "model.vit_large_projector",
1259
+ }
1260
+ _tied_weights_keys = ["lm_head.weight"]
1261
+ config: Step3p7Config
1262
+
1263
+ def __init__(self, config: Step3p7Config):
1264
+ super().__init__(config)
1265
+ self.model = Step3p7Model(config)
1266
+ self.lm_head = nn.Linear(config.hidden_size,
1267
+ config.text_config.vocab_size,
1268
+ bias=False)
1269
+
1270
+ self.post_init()
1271
+
1272
+ def get_input_embeddings(self):
1273
+ return self.model.get_input_embeddings()
1274
+
1275
+ def set_input_embeddings(self, value):
1276
+ self.model.set_input_embeddings(value)
1277
+
1278
+ def get_output_embeddings(self):
1279
+ return self.model.get_output_embeddings()
1280
+
1281
+ def set_output_embeddings(self, new_embeddings):
1282
+ self.model.set_output_embeddings(new_embeddings)
1283
+
1284
+ def set_decoder(self, decoder):
1285
+ self.model.set_decoder(decoder)
1286
+
1287
+ def get_decoder(self):
1288
+ return self.model.get_decoder()
1289
+
1290
+ @property
1291
+ def language_model(self):
1292
+ return self.model.language_model
1293
+
1294
+ @property
1295
+ def visual(self):
1296
+ return self.model.vision_model
1297
+
1298
+ def forward(
1299
+ self,
1300
+ input_ids: torch.LongTensor = None,
1301
+ pixel_values: Optional[torch.Tensor] = None,
1302
+ num_patches=None,
1303
+ patch_pixel_values=None,
1304
+ patch_newline_mask=None,
1305
+ image_embeds: Optional[torch.FloatTensor] = None,
1306
+ attention_mask: Optional[torch.Tensor] = None,
1307
+ position_ids: Optional[torch.LongTensor] = None,
1308
+ past_key_values: Optional[Cache] = None,
1309
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1310
+ labels: Optional[torch.LongTensor] = None,
1311
+ use_cache: Optional[bool] = None,
1312
+ output_attentions: Optional[bool] = None,
1313
+ output_hidden_states: Optional[bool] = None,
1314
+ return_dict: Optional[bool] = None,
1315
+ cache_position: Optional[torch.LongTensor] = None,
1316
+ **kwargs: Unpack[TransformersKwargs],
1317
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1318
+ output_attentions = (
1319
+ output_attentions
1320
+ if output_attentions is not None
1321
+ else self.config.output_attentions
1322
+ )
1323
+ output_hidden_states = (
1324
+ output_hidden_states
1325
+ if output_hidden_states is not None
1326
+ else self.config.output_hidden_states
1327
+ )
1328
+
1329
+ outputs = self.model(
1330
+ input_ids=input_ids,
1331
+ num_patches=num_patches,
1332
+ patch_pixel_values=patch_pixel_values,
1333
+ patch_newline_mask=patch_newline_mask,
1334
+ position_ids=position_ids,
1335
+ attention_mask=attention_mask,
1336
+ past_key_values=past_key_values,
1337
+ inputs_embeds=inputs_embeds,
1338
+ use_cache=use_cache,
1339
+ output_attentions=output_attentions,
1340
+ output_hidden_states=output_hidden_states,
1341
+ return_dict=return_dict,
1342
+ cache_position=cache_position,
1343
+ **kwargs,
1344
+ )
1345
+
1346
+ hidden_states = outputs.last_hidden_state
1347
+ logits = self.lm_head(hidden_states)
1348
+
1349
+ los = None
1350
+ if labels is not None:
1351
+ loss = self.loss_function(
1352
+ logits=logits, labels=labels, vocab_size=self.config.vocab_size
1353
+ )
1354
+
1355
+ return Step3p7CausalLMOutputWithPast(
1356
+ logits=logits,
1357
+ )
1358
+
1359
+
1360
+ def prepare_inputs_for_generation(
1361
+ self,
1362
+ input_ids,
1363
+ past_key_values=None,
1364
+ inputs_embeds=None,
1365
+ pixel_values=None,
1366
+ patch_pixel_values=None,
1367
+ num_patches=None,
1368
+ image_embeds=None,
1369
+ attention_mask=None,
1370
+ cache_position=None,
1371
+ logits_to_keep=None,
1372
+ **kwargs,
1373
+ ):
1374
+ model_inputs = super().prepare_inputs_for_generation(
1375
+ input_ids,
1376
+ past_key_values=past_key_values,
1377
+ inputs_embeds=inputs_embeds,
1378
+ attention_mask=attention_mask,
1379
+ cache_position=cache_position,
1380
+ logits_to_keep=logits_to_keep,
1381
+ **kwargs,
1382
+ )
1383
+
1384
+ if cache_position[0] == 0:
1385
+ # During cached decoding, input ids no longer contain image tokens,
1386
+ # so pixel values should only be passed at the first step.
1387
+ model_inputs["pixel_values"] = pixel_values
1388
+
1389
+ return model_inputs
1390
+
1391
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1392
+ if key.startswith("language_model."):
1393
+ return key[len("language_model.") :], True
1394
+
1395
+ return key, False