huangyu-nv commited on
Commit
4e84267
·
1 Parent(s): 1584e8c

Sync Step3.7 remote code and processor config

Browse files
config.json CHANGED
@@ -4,6 +4,7 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_step3p7.Step3p7Config",
 
7
  "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration"
8
  },
9
  "dtype": "bfloat16",
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_step3p7.Step3p7Config",
7
+ "AutoProcessor": "processing_step3.Step3VLProcessor",
8
  "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration"
9
  },
10
  "dtype": "bfloat16",
configuration_step3p6.py DELETED
@@ -1,216 +0,0 @@
1
- from typing import Any, Optional, Sequence, Union
2
-
3
- from transformers.configuration_utils import PretrainedConfig
4
-
5
- class StepRoboticsVisionEncoderConfig(PretrainedConfig):
6
- model_type = "perception_encoder"
7
-
8
- def __init__(
9
- self,
10
- width=1536,
11
- layers=47,
12
- heads=16,
13
- num_channels=3,
14
- image_size=728,
15
- mlp_ratio = 8960/1536,
16
- patch_size=14,
17
- hidden_act="quick_gelu",
18
- layer_norm_eps=1e-5,
19
- ues_cls_token=False,
20
- use_cls_token: Optional[bool] = None,
21
- use_ln_pre=True,
22
- use_ln_post=False,
23
- use_abs_posemb=True,
24
- use_rope2d=True,
25
- ls_init_value=0.1,
26
- **kwargs,
27
- ):
28
- self.width = width
29
- self.layers = layers
30
- self.heads = heads
31
- self.num_channels = num_channels
32
- self.patch_size = patch_size
33
- self.image_size = image_size
34
- self.mlp_ratio = mlp_ratio
35
- self.layer_norm_eps = layer_norm_eps
36
- self.hidden_act = hidden_act
37
- if use_cls_token is None:
38
- use_cls_token = ues_cls_token
39
- self.ues_cls_token = use_cls_token
40
- self.use_cls_token = use_cls_token
41
- self.use_ln_pre = use_ln_pre
42
- self.ls_init_value = ls_init_value
43
- self.use_ln_post = use_ln_post
44
- self.use_abs_posemb = use_abs_posemb
45
- self.use_rope2d = use_rope2d
46
- super().__init__(**kwargs)
47
-
48
-
49
- class Step3p6TextConfig(PretrainedConfig):
50
- model_type = "step3p5"
51
- architectures = ["Step3p5ForCausalLM"]
52
-
53
- def __init__(
54
- self,
55
- hidden_size: int = 4096,
56
- intermediate_size: int = 11264,
57
- num_attention_heads: int = 64,
58
- num_attention_groups: int = 8,
59
- num_hidden_layers: int = 45,
60
- max_seq_len: int = 128000,
61
- vocab_size: int = 128815,
62
- rms_norm_eps: float = 1e-5,
63
- moe_intermediate_size: int = 1280,
64
- moe_num_experts: int = 288,
65
- moe_top_k: int = 8,
66
- rope_theta: float = 10000,
67
- rope_scaling: Optional[dict[str, Any]] = None,
68
- max_position_embeddings: int = 128000,
69
- share_expert_dims: int = 1280,
70
- share_expert_dim: Optional[int] = None,
71
- head_dim: int = 128,
72
- norm_expert_weight: bool = True,
73
- layer_types: list[str] = None,
74
- sliding_window: Optional[int] = None,
75
- pad_token_id: int = 1,
76
- attention_dropout: float = 0.0,
77
- use_head_wise_attn_gate: bool = False,
78
- use_moe_router_bias: bool = False,
79
- moe_router_activation: str = "softmax",
80
- moe_router_scaling_factor: float = 1.0,
81
- need_fp32_gate: bool = False,
82
- attention_other_setting: Optional[dict[str, Any]] = None,
83
- swiglu_limits: Optional[list[Optional[float]]] = None,
84
- swiglu_limits_shared: Optional[list[Optional[float]]] = None,
85
- use_rope_layers: Optional[list[bool]] = None,
86
- yarn_only_types: Optional[list[str]] = None,
87
- moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
88
- 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
89
- 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
90
- 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
91
- **kwargs,
92
- ) -> None:
93
- torch_dtype = kwargs.get("torch_dtype")
94
- layer_types = _normalize_per_layer_values(layer_types,
95
- num_hidden_layers)
96
- swiglu_limits = _normalize_per_layer_values(swiglu_limits,
97
- num_hidden_layers)
98
- swiglu_limits_shared = _normalize_per_layer_values(
99
- swiglu_limits_shared, num_hidden_layers)
100
- partial_rotary_factors = kwargs.get("partial_rotary_factors")
101
- kwargs["partial_rotary_factors"] = _normalize_per_layer_values(
102
- partial_rotary_factors, num_hidden_layers)
103
- if isinstance(rope_theta, list):
104
- rope_theta = _normalize_per_layer_values(rope_theta,
105
- num_hidden_layers)
106
- if isinstance(rope_scaling, dict):
107
- rope_scaling = dict(rope_scaling)
108
- if use_rope_layers:
109
- use_rope_layers = _normalize_per_layer_values(
110
- use_rope_layers, num_hidden_layers)
111
- if share_expert_dim is None:
112
- share_expert_dim = share_expert_dims
113
- self.hidden_size = hidden_size
114
- self.intermediate_size = intermediate_size
115
- self.num_attention_heads = num_attention_heads
116
- self.num_attention_groups = num_attention_groups
117
- self.num_hidden_layers = num_hidden_layers
118
- self.max_seq_len = max_seq_len
119
- self.vocab_size = vocab_size
120
- self.rms_norm_eps = rms_norm_eps
121
- self.moe_intermediate_size = moe_intermediate_size
122
- self.moe_num_experts = moe_num_experts
123
- self.moe_top_k = moe_top_k
124
- self.rope_theta = rope_theta
125
- self.rope_scaling = rope_scaling
126
- self.max_position_embeddings = max_position_embeddings
127
- self.share_expert_dim = share_expert_dim
128
- self.head_dim = head_dim
129
- self.norm_expert_weight = norm_expert_weight
130
- self.moe_layers_enum = moe_layers_enum
131
- self.layer_types = layer_types
132
- self.sliding_window = sliding_window
133
- self.pad_token_id = pad_token_id
134
- self.attention_dropout = attention_dropout
135
- self.use_head_wise_attn_gate = use_head_wise_attn_gate
136
- self.use_moe_router_bias = use_moe_router_bias
137
- self.moe_router_activation = moe_router_activation
138
- self.moe_router_scaling_factor = moe_router_scaling_factor
139
- self.need_fp32_gate = need_fp32_gate
140
- self.attention_other_setting = attention_other_setting
141
- self.swiglu_limits = swiglu_limits
142
- self.swiglu_limits_shared = swiglu_limits_shared
143
- self.use_rope_layers = use_rope_layers
144
- self.yarn_only_types = yarn_only_types
145
- super().__init__(**kwargs)
146
- if torch_dtype is not None:
147
- self.torch_dtype = torch_dtype
148
-
149
- def to_dict(self):
150
- output = super().to_dict()
151
- torch_dtype = getattr(self, "torch_dtype", None)
152
- if torch_dtype is not None:
153
- output["torch_dtype"] = torch_dtype
154
- return output
155
-
156
-
157
- def _normalize_per_layer_values(
158
- values: Optional[Sequence[Any]],
159
- num_hidden_layers: int,
160
- ) -> Optional[list[Any]]:
161
- if values is None:
162
- return None
163
- normalized = list(values)
164
- if not normalized:
165
- return normalized
166
- if len(normalized) < num_hidden_layers:
167
- normalized.extend([normalized[-1]] *
168
- (num_hidden_layers - len(normalized)))
169
- return normalized
170
-
171
- class Step3p6Config(PretrainedConfig):
172
- # This loader is a compatibility shim for original Step VL checkpoints
173
- # whose top-level config model_type is `step3p5v`.
174
- model_type = "step3p5v"
175
-
176
- def __init__(
177
- self,
178
- vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
179
- text_config: Optional[Union[dict, Step3p6TextConfig]] = None,
180
- understand_projector_stride: int = 2,
181
- projector_bias: bool = False,
182
- image_token_id: int = 151679,
183
- **kwargs,
184
- ) -> None:
185
- shared_rope_scaling = kwargs.get("rope_scaling")
186
- if isinstance(shared_rope_scaling, dict):
187
- shared_rope_scaling = dict(shared_rope_scaling)
188
-
189
- if vision_config is None:
190
- vision_config = StepRoboticsVisionEncoderConfig()
191
- elif isinstance(vision_config, dict):
192
- vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
193
- self.vision_config = vision_config
194
-
195
- if text_config is None:
196
- text_config = Step3p6TextConfig(rope_scaling=shared_rope_scaling)
197
- elif isinstance(text_config, dict):
198
- text_config = dict(text_config)
199
- if shared_rope_scaling is not None and "rope_scaling" not in text_config:
200
- text_config["rope_scaling"] = shared_rope_scaling
201
- text_config = Step3p6TextConfig(**text_config)
202
- elif shared_rope_scaling is not None and text_config.rope_scaling is None:
203
- text_config.rope_scaling = dict(shared_rope_scaling)
204
- self.text_config = text_config
205
-
206
- rope_scaling = kwargs.get("rope_scaling")
207
- if isinstance(rope_scaling, dict):
208
- kwargs["rope_scaling"] = dict(rope_scaling)
209
-
210
- self.understand_projector_stride = understand_projector_stride
211
- self.projector_bias = projector_bias
212
- self.hidden_size = text_config.hidden_size
213
- self.max_position_embeddings = text_config.max_position_embeddings
214
- self.image_token_id = image_token_id
215
- # Help Auto classes find the correct implementation when saving/loading.
216
- super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configuration_step3p7.py CHANGED
@@ -91,23 +91,10 @@ class Step3p7TextConfig(PretrainedConfig):
91
  **kwargs,
92
  ) -> None:
93
  torch_dtype = kwargs.get("torch_dtype")
94
- layer_types = _normalize_per_layer_values(layer_types,
95
  num_hidden_layers)
96
- swiglu_limits = _normalize_per_layer_values(swiglu_limits,
97
- num_hidden_layers)
98
- swiglu_limits_shared = _normalize_per_layer_values(
99
- swiglu_limits_shared, num_hidden_layers)
100
- partial_rotary_factors = kwargs.get("partial_rotary_factors")
101
- kwargs["partial_rotary_factors"] = _normalize_per_layer_values(
102
- partial_rotary_factors, num_hidden_layers)
103
- if isinstance(rope_theta, list):
104
- rope_theta = _normalize_per_layer_values(rope_theta,
105
- num_hidden_layers)
106
  if isinstance(rope_scaling, dict):
107
  rope_scaling = dict(rope_scaling)
108
- if use_rope_layers:
109
- use_rope_layers = _normalize_per_layer_values(
110
- use_rope_layers, num_hidden_layers)
111
  if share_expert_dim is None:
112
  share_expert_dim = share_expert_dims
113
  self.hidden_size = hidden_size
@@ -128,7 +115,7 @@ class Step3p7TextConfig(PretrainedConfig):
128
  self.head_dim = head_dim
129
  self.norm_expert_weight = norm_expert_weight
130
  self.moe_layers_enum = moe_layers_enum
131
- self.layer_types = layer_types
132
  self.sliding_window = sliding_window
133
  self.pad_token_id = pad_token_id
134
  self.attention_dropout = attention_dropout
@@ -145,6 +132,7 @@ class Step3p7TextConfig(PretrainedConfig):
145
  super().__init__(**kwargs)
146
  if torch_dtype is not None:
147
  self.torch_dtype = torch_dtype
 
148
 
149
  def to_dict(self):
150
  output = super().to_dict()
 
91
  **kwargs,
92
  ) -> None:
93
  torch_dtype = kwargs.get("torch_dtype")
94
+ trim_layer_types = _normalize_per_layer_values(layer_types,
95
  num_hidden_layers)
 
 
 
 
 
 
 
 
 
 
96
  if isinstance(rope_scaling, dict):
97
  rope_scaling = dict(rope_scaling)
 
 
 
98
  if share_expert_dim is None:
99
  share_expert_dim = share_expert_dims
100
  self.hidden_size = hidden_size
 
115
  self.head_dim = head_dim
116
  self.norm_expert_weight = norm_expert_weight
117
  self.moe_layers_enum = moe_layers_enum
118
+ self.layer_types = trim_layer_types
119
  self.sliding_window = sliding_window
120
  self.pad_token_id = pad_token_id
121
  self.attention_dropout = attention_dropout
 
132
  super().__init__(**kwargs)
133
  if torch_dtype is not None:
134
  self.torch_dtype = torch_dtype
135
+ self.layer_types = layer_types
136
 
137
  def to_dict(self):
138
  output = super().to_dict()
modeling_step3p6.py DELETED
@@ -1,1324 +0,0 @@
1
- # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
- #
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import copy
16
- import inspect
17
- from dataclasses import dataclass
18
- from typing import Callable, Optional, Tuple, Union, TypedDict, Literal
19
- from dataclasses import dataclass
20
- from PIL import Image
21
-
22
- import torch
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
- from transformers.activations import ACT2FN
26
- from transformers.cache_utils import Cache, DynamicCache
27
- from transformers.generation import GenerationMixin
28
- from transformers.masking_utils import (create_causal_mask,
29
- create_sliding_window_causal_mask)
30
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
31
- from transformers.modeling_layers import GradientCheckpointingLayer
32
- from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
33
- from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
34
- dynamic_rope_update)
35
- from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
36
- PreTrainedModel)
37
- from transformers.processing_utils import Unpack
38
- from transformers.utils import TransformersKwargs, can_return_tuple, logging
39
- from .configuration_step3p6 import Step3p6Config, Step3p6TextConfig
40
- from .vision_encoder import StepRoboticsVisionEncoder
41
-
42
-
43
- logger = logging.get_logger(__name__)
44
- _MASK_INPUT_EMBEDS_ARG = ("inputs_embeds" if "inputs_embeds" in
45
- inspect.signature(create_causal_mask).parameters else
46
- "input_embeds")
47
-
48
- __all__ = [
49
- "Step3p6Model",
50
- ]
51
-
52
-
53
- class StepVLImagePixelInputs(TypedDict):
54
- type: Literal["pixel_values"]
55
- pixel_values: torch.Tensor
56
- patch_pixel_values: Optional[torch.Tensor]
57
- num_patches: list[int]
58
-
59
-
60
- class StepVLImageEmbeddingInputs(TypedDict):
61
- type: Literal["image_embeds"]
62
- image_embeds: torch.Tensor
63
-
64
-
65
- StepVLImageInputs = Union[StepVLImagePixelInputs,
66
- StepVLImageEmbeddingInputs]
67
-
68
-
69
- @dataclass
70
- class Step3p6CausalLMOutputWithPast(ModelOutput):
71
- r"""
72
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
73
- Language modeling loss (for next-token prediction).
74
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
75
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
76
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
77
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
78
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
79
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
80
- `past_key_values` input) to speed up sequential decoding.
81
- """
82
-
83
- loss: Optional[torch.FloatTensor] = None
84
- last_hidden_state: Optional[torch.FloatTensor] = None
85
- logits: torch.FloatTensor = None
86
- past_key_values: Optional[list[torch.FloatTensor]] = None
87
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
88
- attentions: Optional[tuple[torch.FloatTensor]] = None
89
- image_hidden_states: Optional[torch.FloatTensor] = None
90
-
91
- def _flatten_embeddings(embeddings) -> torch.Tensor:
92
- """
93
- Recursively flattens and concatenates NestedTensors on all but the last
94
- dimension.
95
- """
96
-
97
- if isinstance(embeddings, torch.Tensor):
98
- # Flatten all but the last dimension.
99
- return embeddings.flatten(0, -2)
100
-
101
- return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
102
-
103
- def _embedding_count_expression(embeddings) -> str:
104
- """
105
- Constructs a debugging representation of the number of embeddings in the
106
- NestedTensors.
107
- """
108
-
109
- if isinstance(embeddings, torch.Tensor):
110
- return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
111
-
112
- return " + ".join(
113
- _embedding_count_expression(inner) for inner in embeddings)
114
-
115
- def _merge_multimodal_embeddings(
116
- inputs_embeds: torch.Tensor,
117
- is_multimodal: torch.Tensor,
118
- multimodal_embeddings,
119
- ) -> torch.Tensor:
120
- """
121
- Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
122
- positions in ``inputs_embeds`` corresponding to placeholder tokens in
123
- ``input_ids``.
124
- Note:
125
- This updates ``inputs_embeds`` in place.
126
- """
127
- num_expected_tokens = is_multimodal.sum().item()
128
- assert isinstance(num_expected_tokens, int)
129
-
130
- flattened = _flatten_embeddings(multimodal_embeddings)
131
- if flattened.shape[0] != num_expected_tokens:
132
- expr = _embedding_count_expression(multimodal_embeddings)
133
- raise ValueError(
134
- f"Attempted to assign {expr} = {flattened.shape[0]} "
135
- f"multimodal tokens to {num_expected_tokens} placeholders")
136
-
137
- is_multimodal = is_multimodal.to(inputs_embeds.device)
138
- flattened = flattened.to(inputs_embeds.device)
139
- inputs_embeds[is_multimodal] = flattened
140
- return inputs_embeds
141
-
142
- def merge_multimodal_embeddings(
143
- input_ids: torch.Tensor,
144
- inputs_embeds: torch.Tensor,
145
- multimodal_embeddings,
146
- placeholder_token_id: Union[int, list[int]],
147
- ) -> torch.Tensor:
148
- """
149
- Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
150
- positions in ``inputs_embeds`` corresponding to placeholder tokens in
151
- ``input_ids``.
152
-
153
- ``placeholder_token_id`` can be a list of token ids (e.g, token ids
154
- of img_start, img_break, and img_end tokens) when needed: This means
155
- the order of these tokens in the ``input_ids`` MUST MATCH the order of
156
- their embeddings in ``multimodal_embeddings`` since we need to
157
- slice-merge instead of individually scattering.
158
- For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
159
- - T is text token
160
- - S is image start token
161
- - I is image embedding token
162
- - B is image break token
163
- - E is image end token.
164
-
165
- Then the image embeddings (that correspond to I's) from vision encoder
166
- must be padded with embeddings of S, B, and E in the same order of
167
- input_ids for a correct embedding merge.
168
- Note:
169
- This updates ``inputs_embeds`` in place.
170
- """
171
- if isinstance(placeholder_token_id, list):
172
- placeholder_token_id = torch.tensor(placeholder_token_id,
173
- device=input_ids.device)
174
- return _merge_multimodal_embeddings(
175
- inputs_embeds,
176
- torch.isin(input_ids, placeholder_token_id),
177
- multimodal_embeddings,
178
- )
179
-
180
- return _merge_multimodal_embeddings(
181
- inputs_embeds,
182
- (input_ids == placeholder_token_id),
183
- multimodal_embeddings,
184
- )
185
-
186
- class Step3p6PreTrainedModel(PreTrainedModel):
187
- # Link this model family to its configuration class so PreTrainedModel.from_pretrained
188
- # can load the config instead of failing with a NoneType error.
189
- config_class = Step3p6Config
190
- supports_gradient_checkpointing = True
191
- _skip_keys_device_placement = ["past_key_values"]
192
- _supports_flash_attn = False
193
- _supports_sdpa = True
194
- _supports_flex_attn = True
195
- _supports_static_cache = True
196
- _supports_attention_backend = True
197
-
198
- @classmethod
199
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
200
- **kwargs):
201
- key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
202
- if key_mapping is not None and kwargs.get("key_mapping") is None:
203
- # Transformers only applies checkpoint renaming when key_mapping is
204
- # passed explicitly; inheriting the class attribute alone is not enough.
205
- kwargs["key_mapping"] = copy.deepcopy(key_mapping)
206
- return super().from_pretrained(pretrained_model_name_or_path,
207
- *model_args, **kwargs)
208
-
209
-
210
-
211
- class Step3p6RotaryEmbedding(nn.Module):
212
-
213
- def __init__(self, config: Step3p6TextConfig, device=None, layer_idx=None):
214
- super().__init__()
215
- # BC: "rope_type" was originally "type"
216
- self.layer_idx = layer_idx
217
- self.original_rope_parameters = None
218
- if config.rope_parameters is not None:
219
- self.original_rope_parameters = config.rope_parameters
220
- config.rope_parameters = dict(config.rope_parameters)
221
- self.rope_type = config.rope_parameters.get(
222
- "rope_type", config.rope_parameters.get("type"))
223
- else:
224
- self.rope_type = "default"
225
- self.max_seq_len_cached = config.max_position_embeddings
226
- self.original_max_seq_len = config.max_position_embeddings
227
-
228
- partial_rotary_factors = getattr(config, "partial_rotary_factors",
229
- None)
230
- if partial_rotary_factors is not None:
231
- config.partial_rotary_factor = partial_rotary_factors[
232
- self.layer_idx]
233
- else:
234
- config.partial_rotary_factor = 1.0
235
-
236
- self.rope_theta = config.rope_theta
237
- if isinstance(config.rope_theta, list):
238
- self.rope_theta = config.rope_theta.copy()
239
- config.rope_theta = self.rope_theta[self.layer_idx]
240
-
241
- self.config = copy.copy(config)
242
- if config.rope_parameters is not None:
243
- self.config.rope_parameters = dict(config.rope_parameters)
244
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
245
- inv_freq, self.attention_scaling = self.rope_init_fn(
246
- self.config, device)
247
-
248
- self.register_buffer("inv_freq", inv_freq, persistent=False)
249
- self.original_inv_freq = self.inv_freq
250
- config.rope_theta = self.rope_theta
251
- config.rope_parameters = self.original_rope_parameters
252
-
253
- @torch.no_grad()
254
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
255
- def forward(self, x, position_ids):
256
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
257
- position_ids.shape[0], -1, 1).to(x.device)
258
- position_ids_expanded = position_ids[:, None, :].float().to(x.device)
259
-
260
- device_type = x.device.type if isinstance(
261
- x.device.type, str) and x.device.type != "mps" else "cpu"
262
- with torch.autocast(device_type=device_type,
263
- enabled=False): # Force float32
264
- freqs = (inv_freq_expanded.float()
265
- @ position_ids_expanded.float()).transpose(1, 2)
266
- emb = torch.cat((freqs, freqs), dim=-1)
267
- cos = emb.cos() * self.attention_scaling
268
- sin = emb.sin() * self.attention_scaling
269
-
270
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
271
-
272
-
273
- def rotate_half(x):
274
- """Rotates half the hidden dims of the input."""
275
- x1 = x[..., :x.shape[-1] // 2]
276
- x2 = x[..., x.shape[-1] // 2:]
277
- return torch.cat((-x2, x1), dim=-1)
278
-
279
-
280
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
281
- """Applies Rotary Position Embedding to the query and key tensors.
282
-
283
- Args:
284
- q (`torch.Tensor`): The query tensor.
285
- k (`torch.Tensor`): The key tensor.
286
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
287
- sin (`torch.Tensor`): The sine part of the rotary embedding.
288
- position_ids (`torch.Tensor`, *optional*):
289
- Deprecated and unused.
290
- unsqueeze_dim (`int`, *optional*, defaults to 1):
291
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
292
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
293
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
294
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
295
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
296
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
297
- Returns:
298
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
299
- """
300
- rotary_dim = cos.shape[-1]
301
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
302
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
303
-
304
- # Apply rotary embeddings on the first half or full tensor
305
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
306
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
307
-
308
- # Concatenate back to full shape
309
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
310
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
311
- return q_embed, k_embed
312
-
313
-
314
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
315
- """
316
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
317
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
318
- """
319
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
320
- if n_rep == 1:
321
- return hidden_states
322
- hidden_states = hidden_states[:, :,
323
- None, :, :].expand(batch,
324
- num_key_value_heads,
325
- n_rep, slen, head_dim)
326
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
327
- head_dim)
328
-
329
-
330
- # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
331
- def eager_attention_forward(
332
- module: nn.Module,
333
- query: torch.Tensor,
334
- key: torch.Tensor,
335
- value: torch.Tensor,
336
- attention_mask: Optional[torch.Tensor],
337
- scaling: float,
338
- dropout: float = 0.0,
339
- **kwargs,
340
- ):
341
- key_states = repeat_kv(key, module.num_key_value_groups)
342
- value_states = repeat_kv(value, module.num_key_value_groups)
343
- # breakpoint()
344
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
345
- if attention_mask is not None:
346
- causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
347
- attn_weights = attn_weights + causal_mask
348
-
349
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
350
- attn_weights = nn.functional.dropout(attn_weights,
351
- p=dropout,
352
- training=module.training)
353
- attn_output = torch.matmul(attn_weights, value_states)
354
- attn_output = attn_output.transpose(1, 2).contiguous()
355
-
356
- return attn_output, attn_weights
357
-
358
- @dataclass
359
- class Step3p6CausalLMOutputWithPast(ModelOutput):
360
- r"""
361
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
362
- Language modeling loss (for next-token prediction).
363
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
364
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
365
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
366
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
367
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
368
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
369
- `past_key_values` input) to speed up sequential decoding.
370
- """
371
-
372
- loss: Optional[torch.FloatTensor] = None
373
- last_hidden_state: Optional[torch.FloatTensor] = None
374
- logits: torch.FloatTensor = None
375
- past_key_values: Optional[list[torch.FloatTensor]] = None
376
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
377
- attentions: Optional[tuple[torch.FloatTensor]] = None
378
-
379
-
380
- class Step3p6MLP(nn.Module):
381
-
382
- def __init__(self, config, intermediate_size=None, swiglu_limit=None):
383
- super().__init__()
384
- self.config = config
385
- self.hidden_size = config.hidden_size
386
- self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
387
- self.gate_proj = nn.Linear(self.hidden_size,
388
- self.intermediate_size,
389
- bias=False)
390
- self.up_proj = nn.Linear(self.hidden_size,
391
- self.intermediate_size,
392
- bias=False)
393
- self.down_proj = nn.Linear(self.intermediate_size,
394
- self.hidden_size,
395
- bias=False)
396
- self.act_fn = ACT2FN["silu"]
397
- self.limit = swiglu_limit
398
-
399
- def forward(self, x):
400
- up = self.up_proj(x)
401
- gate = self.act_fn(self.gate_proj(x))
402
- if self.limit is not None:
403
- gate = gate.clamp(min=None, max=self.limit)
404
- up = up.clamp(min=-self.limit, max=self.limit)
405
-
406
- return self.down_proj(gate * up)
407
-
408
-
409
- def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
410
- renormalize: bool):
411
- gating_output = gating_output.float()
412
- gate_prob = torch.sigmoid(gating_output)
413
- gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
414
- topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
415
- expert_topk_weight = topk_prob
416
- if renormalize:
417
- expert_topk_weight = expert_topk_weight / torch.sum(
418
- expert_topk_weight, dim=-1, keepdim=True)
419
- return expert_topk_weight, indices
420
-
421
-
422
- def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
423
- renormalize: bool):
424
- gating_output = gating_output.float()
425
- gate_prob = torch.softmax(gating_output, dim=-1)
426
- gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
427
- topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
428
- expert_topk_weight = topk_prob
429
- if renormalize:
430
- expert_topk_weight = expert_topk_weight / torch.sum(
431
- expert_topk_weight, dim=-1, keepdim=True)
432
- return expert_topk_weight, indices.to(torch.int32)
433
-
434
-
435
- class MoELinear(nn.Module):
436
-
437
- def __init__(self, num_experts, in_features, out_features):
438
- super().__init__()
439
- self.num_experts = num_experts
440
- self.in_features = in_features
441
- self.out_features = out_features
442
- self.weight = nn.Parameter(
443
- torch.empty(num_experts, out_features, in_features))
444
-
445
- def forward(self, x, expert_id):
446
- x = F.linear(x.float(), self.weight[expert_id].float())
447
- return x
448
-
449
-
450
- class Step3p6MoEMLP(nn.Module):
451
-
452
- def __init__(self, config, swiglu_limit=None):
453
- super().__init__()
454
- self.num_experts = config.moe_num_experts
455
- self.top_k = config.moe_top_k
456
- self.hidden_size = config.hidden_size
457
- self.moe_intermediate_size = config.moe_intermediate_size
458
-
459
- self.use_moe_router_bias = config.use_moe_router_bias
460
- if self.use_moe_router_bias:
461
- self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
462
- dtype=torch.float32),
463
- requires_grad=False)
464
- self.custom_routing_function = self.router_bias_func
465
- elif config.moe_router_activation == "sigmoid":
466
- self.custom_routing_function = sigmoid_routing_function
467
- else:
468
- self.custom_routing_function = None
469
- self.need_fp32_gate = config.need_fp32_gate
470
- self.routed_scaling_factor = getattr(config,
471
- "moe_router_scaling_factor", 1.0)
472
-
473
- # gating
474
- self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
475
-
476
- self.act_fn = ACT2FN["silu"]
477
- self.limit = swiglu_limit
478
-
479
- self.up_proj = MoELinear(self.num_experts, self.hidden_size,
480
- self.moe_intermediate_size)
481
- self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
482
- self.moe_intermediate_size)
483
- self.down_proj = MoELinear(self.num_experts,
484
- self.moe_intermediate_size,
485
- self.hidden_size)
486
-
487
- def router_bias_func(self, gating_output: torch.Tensor, topk: int,
488
- renormalize: bool):
489
- gate_prob = torch.sigmoid(gating_output.float())
490
- gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
491
- _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
492
- topk_prob = torch.gather(gate_prob, 1, indices)
493
- expert_topk_weight = topk_prob
494
- if renormalize:
495
- expert_topk_weight = expert_topk_weight / (
496
- torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
497
- return expert_topk_weight, indices
498
-
499
- def get_expert_output(self, inputs: torch.Tensor, expert_id):
500
- #if self.limit is None:
501
- up = self.up_proj(inputs, expert_id)
502
- gate = self.act_fn(self.gate_proj(inputs, expert_id))
503
- if self.limit is not None:
504
- gate = gate.clamp(min=None, max=self.limit)
505
- up = up.clamp(min=-self.limit, max=self.limit)
506
-
507
- return self.down_proj(gate * up, expert_id)
508
-
509
- def forward(self, hidden_states):
510
- """ """
511
- batch_size, sequence_length, hidden_dim = hidden_states.shape
512
- hidden_states = hidden_states.view(-1, hidden_dim)
513
- if self.need_fp32_gate:
514
- router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
515
- else:
516
- # router_logits: (batch * sequence_length, n_experts)
517
- router_logits = self.gate(hidden_states)
518
-
519
- if self.custom_routing_function:
520
- routing_weights, selected_experts = self.custom_routing_function(
521
- router_logits, self.top_k, renormalize=True)
522
- else:
523
- routing_weights = F.softmax(router_logits,
524
- dim=1,
525
- dtype=torch.float)
526
- routing_weights, selected_experts = torch.topk(routing_weights,
527
- self.top_k,
528
- dim=-1)
529
-
530
- routing_weights = routing_weights * self.routed_scaling_factor
531
-
532
- final_hidden_states = torch.zeros(
533
- (batch_size * sequence_length, hidden_dim),
534
- dtype=hidden_states.dtype,
535
- device=hidden_states.device)
536
-
537
- # One hot encode the selected experts to create an expert mask
538
- # this will be used to easily index which expert is going to be sollicitated
539
- expert_mask = torch.nn.functional.one_hot(
540
- selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
541
-
542
- # Loop over all available experts in the model and perform the computation on each expert
543
- for expert_idx in range(self.num_experts):
544
- idx, top_x = torch.where(expert_mask[expert_idx])
545
-
546
- # Index the correct hidden states and compute the expert hidden state for
547
- # the current expert. We need to make sure to multiply the output hidden
548
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
549
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
550
- current_hidden_states = (
551
- self.get_expert_output(current_state, expert_idx) *
552
- routing_weights[top_x, idx, None])
553
-
554
- # However `index_add_` only support torch tensors for indexing so we'll use
555
- # the `top_x` tensor here.
556
- final_hidden_states.index_add_(
557
- 0, top_x, current_hidden_states.to(hidden_states.dtype))
558
- final_hidden_states = final_hidden_states.reshape(
559
- batch_size, sequence_length, hidden_dim)
560
- return final_hidden_states
561
-
562
-
563
- class Step3p6RMSNorm(nn.Module):
564
-
565
- def __init__(
566
- self,
567
- hidden_size: int,
568
- eps: float = 1e-5,
569
- ) -> None:
570
- super().__init__()
571
- self.weight = nn.Parameter(torch.ones(hidden_size))
572
- self.variance_epsilon = eps
573
-
574
- def forward(self, x: torch.Tensor) -> torch.Tensor:
575
- dtype = x.dtype
576
- x = x.float()
577
- variance = x.pow(2).mean(dim=-1, keepdim=True)
578
- normed = x * torch.rsqrt(variance + self.variance_epsilon)
579
- normed = normed * (self.weight.float() + 1)
580
- return normed.to(dtype)
581
- class Step3p6Attention(nn.Module):
582
-
583
- def __init__(self, config: Step3p6TextConfig, layer_idx):
584
- super().__init__()
585
- self.config = config
586
- self.layer_idx = layer_idx
587
- self.num_attention_heads = config.num_attention_heads
588
- self.num_key_value_heads = config.num_attention_groups
589
-
590
- layer_types = getattr(config, "layer_types", [])
591
- if layer_types:
592
- enable_sliding_window = layer_types[
593
- self.layer_idx] == "sliding_attention"
594
- else:
595
- enable_sliding_window = self.layer_idx % 2 == 0
596
-
597
- yarn_only_types = getattr(config, "yarn_only_types", None)
598
- if yarn_only_types and layer_types[
599
- self.layer_idx] not in yarn_only_types:
600
- config.rope_parameters = None
601
- else:
602
- config.rope_parameters = getattr(config, "rope_scaling", None)
603
-
604
- self.sliding_window = config.sliding_window
605
- if enable_sliding_window:
606
- self.num_attention_heads = config.attention_other_setting[
607
- "num_attention_heads"]
608
- self.num_key_value_heads = config.attention_other_setting[
609
- "num_attention_groups"]
610
-
611
- if self.sliding_window is not None and enable_sliding_window:
612
- self.sliding_window = (self.sliding_window)
613
- else:
614
- self.sliding_window = None
615
- self.head_dim = getattr(config, "head_dim",
616
- config.hidden_size // self.num_attention_heads)
617
- self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
618
-
619
- self.rotary_emb = Step3p6RotaryEmbedding(config, layer_idx=layer_idx)
620
-
621
- self.q_size = self.num_attention_heads * self.head_dim
622
- self.kv_size = self.num_key_value_heads * self.head_dim
623
- self.scaling = self.head_dim**-0.5
624
-
625
- self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
626
- self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
627
- self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
628
- self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
629
- self.attention_dropout = getattr(config, "attention_dropout", 0.0)
630
- self.q_norm = Step3p6RMSNorm(self.head_dim,
631
- eps=config.rms_norm_eps)
632
- self.k_norm = Step3p6RMSNorm(self.head_dim,
633
- eps=config.rms_norm_eps)
634
-
635
- self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
636
- if self.use_head_wise_attn_gate:
637
- self.g_proj = nn.Linear(config.hidden_size,
638
- self.num_attention_heads,
639
- bias=False)
640
-
641
- self.use_rope = True
642
- use_rope_layers = getattr(config, "use_rope_layers", None)
643
- if use_rope_layers:
644
- self.use_rope = use_rope_layers[self.layer_idx]
645
-
646
- def forward(
647
- self,
648
- hidden_states: torch.Tensor,
649
- attention_mask: Optional[torch.Tensor],
650
- past_key_value: Optional[Cache] = None,
651
- cache_position: Optional[torch.LongTensor] = None,
652
- position_ids: Optional[torch.LongTensor] = None,
653
- **kwargs: Unpack[FlashAttentionKwargs],
654
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
655
- Optional[Tuple[torch.Tensor]]]:
656
- input_shape = hidden_states.shape[:-1]
657
- hidden_shape = (*input_shape, -1, self.head_dim)
658
-
659
- query_states = self.q_norm(
660
- self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
661
- key_states = self.k_norm(
662
- self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
663
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
664
- 1, 2)
665
- if self.use_head_wise_attn_gate:
666
- gate_states = self.g_proj(hidden_states)
667
- cos, sin = self.rotary_emb(hidden_states, position_ids)
668
-
669
- # cos, sin = position_embeddings
670
- query_states, key_states = apply_rotary_pos_emb(
671
- query_states, key_states, cos, sin)
672
-
673
- # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
674
- if past_key_value is not None:
675
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
676
- cache_kwargs = {
677
- "sin": sin,
678
- "cos": cos,
679
- "cache_position": cache_position
680
- }
681
- key_states, value_states = past_key_value.update(
682
- key_states, value_states, self.layer_idx, cache_kwargs)
683
-
684
- attention_interface: Callable = eager_attention_forward
685
- # TODO: considering FP8;
686
- # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
687
- # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
688
- if self.config._attn_implementation != "eager":
689
- attention_interface = ALL_ATTENTION_FUNCTIONS[
690
- self.config._attn_implementation]
691
-
692
- attn_output, attn_weights = attention_interface(
693
- self,
694
- query_states,
695
- key_states,
696
- value_states,
697
- attention_mask,
698
- dropout=0.0 if not self.training else self.attention_dropout,
699
- scaling=self.scaling,
700
- sliding_window=self.sliding_window, # main diff with Llama
701
- **kwargs,
702
- )
703
- attn_output = attn_output.reshape(*input_shape, -1)
704
- if self.use_head_wise_attn_gate:
705
- output = attn_output.view(
706
- *attn_output.shape[:-1], self.num_attention_heads,
707
- self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
708
- attn_output = output.view(*attn_output.shape)
709
- attn_output = self.o_proj(attn_output)
710
-
711
- return attn_output, attn_weights
712
-
713
-
714
- class Step3p6DecoderLayer(GradientCheckpointingLayer):
715
-
716
- def __init__(self, config, layer_idx):
717
- super().__init__()
718
- self.hidden_size = config.hidden_size
719
- self.layer_idx = layer_idx
720
- self.self_attn = Step3p6Attention(config, layer_idx)
721
- layer_types = getattr(config, "layer_types", None) or []
722
- if layer_types:
723
- self.attention_type = layer_types[layer_idx]
724
- else:
725
- self.attention_type = (
726
- "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
727
- )
728
-
729
- moe_layers_enum = getattr(config, "moe_layers_enum", None)
730
- if moe_layers_enum is not None:
731
- if isinstance(moe_layers_enum, str):
732
- moe_layers_idx = [
733
- int(i) for i in moe_layers_enum.split(',') if i.strip()
734
- ]
735
- else:
736
- moe_layers_idx = [int(i) for i in moe_layers_enum]
737
- else:
738
- moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
739
- self.is_moe_layer = layer_idx in moe_layers_idx
740
- self.use_moe = False
741
-
742
- if config.swiglu_limits_shared and config.swiglu_limits_shared[
743
- layer_idx] is not None and config.swiglu_limits_shared[
744
- layer_idx] != 0:
745
- swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
746
- else:
747
- swiglu_limit_shared = None
748
- if config.swiglu_limits and config.swiglu_limits[
749
- layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
750
- swiglu_limit = config.swiglu_limits[layer_idx]
751
- else:
752
- swiglu_limit = None
753
- if self.is_moe_layer:
754
- self.moe = Step3p6MoEMLP(config, swiglu_limit=swiglu_limit) #
755
- self.share_expert = Step3p6MLP(
756
- config,
757
- intermediate_size=config.share_expert_dim,
758
- swiglu_limit=swiglu_limit_shared)
759
- self.use_moe = True
760
- else:
761
- self.mlp = Step3p6MLP(config,
762
- intermediate_size=config.intermediate_size,
763
- swiglu_limit=swiglu_limit_shared)
764
-
765
- self.input_layernorm = Step3p6RMSNorm(
766
- config.hidden_size,
767
- eps=config.rms_norm_eps)
768
- self.post_attention_layernorm = Step3p6RMSNorm(
769
- config.hidden_size,
770
- eps=config.rms_norm_eps)
771
-
772
- def forward(
773
- self,
774
- hidden_states: torch.Tensor,
775
- attention_mask: Optional[torch.Tensor] = None,
776
- position_ids: Optional[torch.LongTensor] = None,
777
- past_key_value: Optional[tuple[torch.Tensor]] = None,
778
- cache_position: Optional[torch.LongTensor] = None,
779
- **kwargs: Unpack[FlashAttentionKwargs],
780
- ) -> torch.FloatTensor:
781
- residual = hidden_states
782
- hidden_states = self.input_layernorm(hidden_states)
783
- hidden_states, _ = self.self_attn(
784
- hidden_states=hidden_states,
785
- attention_mask=attention_mask,
786
- position_ids=position_ids,
787
- past_key_value=past_key_value,
788
- cache_position=cache_position,
789
- **kwargs,
790
- )
791
- hidden_states = residual + hidden_states
792
-
793
- # Fully Connected
794
- residual = hidden_states
795
- hidden_states = self.post_attention_layernorm(hidden_states)
796
- if self.use_moe:
797
- share_output = self.share_expert(hidden_states)
798
- moe_output = self.moe(hidden_states)
799
- ffn_output = moe_output + share_output
800
- else:
801
- ffn_output = self.mlp(hidden_states)
802
- if isinstance(ffn_output, tuple):
803
- hidden_states, _ = ffn_output
804
- else:
805
- hidden_states = ffn_output
806
-
807
- hidden_states = residual + hidden_states
808
- return hidden_states
809
-
810
-
811
- class Step3p6PreTrainedModel(PreTrainedModel):
812
- # Link this model family to its configuration class so PreTrainedModel.from_pretrained
813
- # can load the config instead of failing with a NoneType error.
814
- config_class = Step3p6TextConfig
815
- supports_gradient_checkpointing = True
816
- _skip_keys_device_placement = ["past_key_values"]
817
- _keys_to_ignore_on_load_unexpected = [
818
- r"model\.layers\.45\.*",
819
- r"model\.layers\.46\.*",
820
- r"model\.layers\.47\.*"
821
- ]
822
- _supports_flash_attn = False
823
- _supports_sdpa = True
824
- _supports_flex_attn = True
825
- _supports_static_cache = True
826
- _supports_attention_backend = True
827
-
828
-
829
- class Step3p6TextModel(Step3p6PreTrainedModel, GenerationMixin):
830
- _no_split_modules = ["Step3p6DecoderLayer"]
831
- base_model_prefix = "model"
832
- _tied_weights_keys = ["lm_head.weight"]
833
- config: Step3p6TextConfig
834
- def __init__(self, config: Step3p6TextConfig):
835
- super().__init__(config)
836
- self.padding_idx = config.pad_token_id
837
- self.vocab_size = config.vocab_size
838
-
839
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
840
- self.padding_idx)
841
- self.layers = nn.ModuleList([
842
- Step3p6DecoderLayer(config, layer_idx)
843
- for layer_idx in range(config.num_hidden_layers)
844
- ])
845
- self.norm = Step3p6RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
846
- self.gradient_checkpointing = False
847
- layer_types = self.config.layer_types or []
848
- self.has_sliding_layers = (not layer_types or
849
- "sliding_attention" in layer_types)
850
-
851
- # Initialize weights and apply final processing
852
- self.post_init()
853
-
854
-
855
- def get_input_embeddings(self, input_ids):
856
- return self.embed_tokens(input_ids)
857
-
858
- @can_return_tuple
859
- def forward(
860
- self,
861
- input_ids: torch.LongTensor = None,
862
- attention_mask: Optional[torch.Tensor] = None,
863
- position_ids: Optional[torch.LongTensor] = None,
864
- past_key_values: Optional[Cache] = None,
865
- inputs_embeds: Optional[torch.FloatTensor] = None,
866
- use_cache: Optional[bool] = None,
867
- output_attentions: Optional[bool] = None,
868
- output_hidden_states: Optional[bool] = None,
869
- return_dict: Optional[bool] = None,
870
- cache_position: Optional[torch.LongTensor] = None,
871
- **kwargs: Unpack[TransformersKwargs],
872
- ) -> Union[tuple, BaseModelOutputWithPast]:
873
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
874
- output_hidden_states = (output_hidden_states
875
- if output_hidden_states is not None else
876
- self.config.output_hidden_states)
877
- use_cache = use_cache if use_cache is not None else self.config.use_cache
878
- return_dict = return_dict if return_dict is not None else getattr(
879
- self.config, "return_dict", True)
880
- if (input_ids is None) ^ (inputs_embeds is not None):
881
- raise ValueError(
882
- "You must specify exactly one of input_ids or inputs_embeds")
883
-
884
- if self.gradient_checkpointing and self.training and use_cache:
885
- logger.warning_once(
886
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
887
- )
888
- use_cache = False
889
-
890
- if inputs_embeds is None:
891
- inputs_embeds = self.embed_tokens(
892
- input_ids.to(self.embed_tokens.weight.device))
893
-
894
- if use_cache and past_key_values is None:
895
- past_key_values = DynamicCache()
896
-
897
- if cache_position is None:
898
- past_seen_tokens = past_key_values.get_seq_length(
899
- ) if past_key_values is not None else 0
900
- cache_position = torch.arange(past_seen_tokens,
901
- past_seen_tokens +
902
- inputs_embeds.shape[1],
903
- device=inputs_embeds.device)
904
-
905
- if position_ids is None:
906
- position_ids = cache_position.unsqueeze(0)
907
-
908
- hidden_states = inputs_embeds
909
-
910
- # It may already have been prepared by e.g. `generate`
911
- if not isinstance(causal_mask_mapping := attention_mask, dict):
912
- # Prepare mask arguments
913
- mask_kwargs = {
914
- "config": self.config,
915
- "attention_mask": attention_mask,
916
- "cache_position": cache_position,
917
- "past_key_values": past_key_values,
918
- "position_ids": position_ids,
919
- }
920
- mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
921
- # Create the masks
922
- causal_mask_mapping = {
923
- "full_attention": create_causal_mask(**mask_kwargs),
924
- }
925
-
926
- # The sliding window alternating layers are not always activated depending on the config
927
- if self.has_sliding_layers:
928
- causal_mask_mapping[
929
- "sliding_attention"] = create_sliding_window_causal_mask(
930
- **mask_kwargs)
931
-
932
- # # create position embeddings to be shared across the decoder layers
933
- # decoder layers
934
- all_hidden_states = () if output_hidden_states else None
935
- all_self_attns = () if output_attentions else None
936
- for decoder_layer in self.layers[:self.config.num_hidden_layers]:
937
- if output_hidden_states:
938
- all_hidden_states += (hidden_states, )
939
-
940
- layer_outputs = decoder_layer(
941
- hidden_states,
942
- attention_mask=causal_mask_mapping[
943
- decoder_layer.attention_type],
944
- position_ids=position_ids,
945
- past_key_value=past_key_values,
946
- output_attentions=output_attentions,
947
- use_cache=use_cache,
948
- cache_position=cache_position,
949
- **kwargs,
950
- )
951
-
952
- hidden_states = layer_outputs
953
-
954
- hidden_states = self.norm(hidden_states)
955
-
956
- return BaseModelOutputWithPast(
957
- last_hidden_state=hidden_states,
958
- past_key_values=past_key_values if use_cache else None,
959
- hidden_states=all_hidden_states,
960
- attentions=all_self_attns,
961
- )
962
-
963
-
964
- class Step3p6Model(Step3p6PreTrainedModel, GenerationMixin):
965
- config: Step3p6Config
966
- _tied_weights_keys = ["lm_head.weight"]
967
- base_model_prefix = ""
968
-
969
- def __init__(self, config: Step3p6Config):
970
- super().__init__(config)
971
- self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
972
- self.language_model = Step3p6TextModel(config.text_config)
973
- self.vocab_size = config.text_config.vocab_size
974
- self.vit_large_projector = nn.Linear(
975
- config.vision_config.width * 4,
976
- config.text_config.hidden_size,
977
- bias=config.projector_bias)
978
- self.image_placeholder_token_id = config.image_token_id
979
-
980
- # Initialize weights and apply final processing
981
- self.post_init()
982
-
983
- def get_input_embeddings(
984
- self,
985
- input_ids: torch.Tensor,
986
- multimodal_embeddings = None,
987
- ) -> torch.Tensor:
988
- # breakpoint()
989
- input_ids = input_ids.squeeze(0)
990
- if multimodal_embeddings is None:
991
- inputs_embeds = self.language_model.get_input_embeddings(input_ids)
992
- else:
993
- is_text = input_ids != self.config.image_token_id
994
- text_ids = input_ids[is_text]
995
- text_embeds = self.language_model.get_input_embeddings(text_ids)
996
-
997
- inputs_embeds = torch.empty(input_ids.shape[0],
998
- text_embeds.shape[-1],
999
- dtype=text_embeds.dtype,
1000
- device=text_embeds.device)
1001
- inputs_embeds[is_text] = text_embeds
1002
- inputs_embeds = merge_multimodal_embeddings(
1003
- input_ids, inputs_embeds, multimodal_embeddings,
1004
- self.config.image_token_id)
1005
- inputs_embeds = inputs_embeds.unsqueeze(0)
1006
- return inputs_embeds
1007
-
1008
-
1009
- def set_input_embeddings(self, value):
1010
- return self.language_model.set_input_embeddings(value)
1011
-
1012
- def set_decoder(self, decoder):
1013
- self.language_model = decoder
1014
-
1015
- def get_decoder(self):
1016
- return self.language_model
1017
-
1018
- def _parse_and_validate_image_input(
1019
- self, **kwargs: object) -> Optional[StepVLImageInputs]:
1020
- pixel_values = kwargs.pop("pixel_values", None)
1021
- patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1022
- num_patches = kwargs.pop("num_patches", None)
1023
- image_embeds = kwargs.pop("image_embeds", None)
1024
-
1025
- if pixel_values is None and image_embeds is None:
1026
- return None
1027
-
1028
- if pixel_values is not None:
1029
- # pixel_values = flatten_bn(pixel_values, concat=True)
1030
- if pixel_values.dim() >= 3:
1031
- pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1032
- if patch_pixel_values is not None:
1033
- # patch_pixel_values = flatten_bn(patch_pixel_values,
1034
- # concat=True)
1035
- patch_pixel_values = patch_pixel_values.view(
1036
- -1, *patch_pixel_values.shape[-3:])
1037
- # Handle empty patch_pixel_values by setting to None
1038
- if patch_pixel_values.shape[0] == 0:
1039
- patch_pixel_values = None
1040
-
1041
- return StepVLImagePixelInputs(
1042
- type="pixel_values",
1043
- pixel_values=pixel_values.to(self.dtype).to(self.device),
1044
- patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1045
- self.device) if patch_pixel_values is not None else None,
1046
- num_patches=num_patches,
1047
- )
1048
-
1049
- if image_embeds is not None:
1050
- if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
1051
- image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1052
- else:
1053
- raise ValueError(
1054
- f"Unexpected shape for image_embeds: {image_embeds.shape}")
1055
-
1056
- return StepVLImageEmbeddingInputs(
1057
- type="image_embeds",
1058
- image_embeds=image_embeds.to(self.dtype).to(self.device),
1059
- )
1060
- return None
1061
-
1062
- def _process_image_features(self,
1063
- image_features: torch.Tensor) -> torch.Tensor:
1064
- B, P = image_features.shape[:2]
1065
- HW = int(P ** 0.5)
1066
- image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1067
- image_features = self.vision_model.vit_downsampler1(image_features)
1068
- image_features = self.vision_model.vit_downsampler2(image_features)
1069
-
1070
- B, C, HW, HW = image_features.shape
1071
- image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1072
- image_features = self.vit_large_projector(image_features)
1073
- return image_features
1074
-
1075
- def _get_vision_model_output(self,
1076
- input_tensor: torch.Tensor) -> torch.Tensor:
1077
- return self.vision_model(input_tensor)
1078
-
1079
- def _process_image_input(
1080
- self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1081
-
1082
- if image_input["type"] == "image_embeds":
1083
- image_features = image_input["image_embeds"]
1084
- else:
1085
- image_features = self._get_vision_model_output(
1086
- image_input["pixel_values"])
1087
- patch_image_features = self._get_vision_model_output(
1088
- image_input["patch_pixel_values"]
1089
- ) if image_input["patch_pixel_values"] is not None else None
1090
- num_patches = image_input["num_patches"]
1091
-
1092
- image_features = self._process_image_features(image_features)
1093
- patch_image_features = self._process_image_features(
1094
- patch_image_features) if patch_image_features is not None else None
1095
-
1096
- merged_image_features = []
1097
- cur_patch_idx = 0
1098
- for i, num_patch in enumerate(num_patches):
1099
- cur_feature = []
1100
- if num_patch > 0:
1101
- patch_slice = patch_image_features[
1102
- cur_patch_idx:cur_patch_idx + num_patch]
1103
- cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1104
- cur_feature.append(image_features[i].view(
1105
- -1, image_features.shape[-1]))
1106
- cur_patch_idx += num_patch
1107
- merged_image_features.append(
1108
- torch.cat(cur_feature) if len(cur_feature) >
1109
- 1 else cur_feature[0])
1110
-
1111
- return merged_image_features
1112
-
1113
- def get_multimodal_embeddings(self, **kwargs):
1114
- # breakpoint()
1115
- image_input = self._parse_and_validate_image_input(**kwargs)
1116
- if image_input is None:
1117
- return None
1118
- vision_embeddings = self._process_image_input(image_input)
1119
- return vision_embeddings
1120
-
1121
- @can_return_tuple
1122
- def forward(
1123
- self,
1124
- input_ids: torch.LongTensor = None,
1125
- attention_mask: Optional[torch.Tensor] = None,
1126
- position_ids: Optional[torch.LongTensor] = None,
1127
- past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1128
- inputs_embeds: Optional[torch.FloatTensor] = None,
1129
- labels: Optional[torch.LongTensor] = None,
1130
- use_cache: Optional[bool] = None,
1131
- output_attentions: Optional[bool] = None,
1132
- output_hidden_states: Optional[bool] = None,
1133
- return_dict: Optional[bool] = None,
1134
- cache_position: Optional[torch.LongTensor] = None,
1135
- logits_to_keep: Union[int, torch.Tensor] = 0,
1136
- images: Optional[list[Image.Image]] = None,
1137
- **kwargs: Unpack[TransformersKwargs],
1138
- ) -> Union[tuple, Step3p6CausalLMOutputWithPast]:
1139
- r"""
1140
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1141
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1142
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1143
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1144
- Example:
1145
- ```python
1146
- >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1147
- >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1148
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1149
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1150
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1151
- >>> # Generate
1152
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1153
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1154
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1155
- ```"""
1156
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1157
- output_hidden_states = (
1158
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1159
- )
1160
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1161
-
1162
- if inputs_embeds is None:
1163
- input_ids = input_ids
1164
- vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1165
- inputs_embeds = self.get_input_embeddings(input_ids,
1166
- vision_embeddings)
1167
- input_ids = None
1168
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1169
- outputs = self.language_model(
1170
- input_ids=None,
1171
- position_ids=position_ids,
1172
- attention_mask=attention_mask,
1173
- past_key_values=past_key_values,
1174
- inputs_embeds=inputs_embeds,
1175
- use_cache=use_cache,
1176
- output_attentions=output_attentions,
1177
- output_hidden_states=output_hidden_states,
1178
- return_dict=True,
1179
- cache_position=cache_position,
1180
- **kwargs,
1181
- )
1182
-
1183
- output = Step3p6CausalLMOutputWithPast(
1184
- last_hidden_state=outputs.last_hidden_state,
1185
- past_key_values=outputs.past_key_values,
1186
- attentions=outputs.attentions,
1187
- )
1188
- return output if return_dict else output.to_tuple()
1189
-
1190
-
1191
- class Step3p6ForConditionalGeneration(Step3p6PreTrainedModel, GenerationMixin):
1192
- _checkpoint_conversion_mapping = {
1193
- "^vision_model": "model.vision_model",
1194
- r"^model(?!\.(language_model|vision_model))": "model.language_model",
1195
- "^vit_large_projector": "model.vit_large_projector"
1196
- }
1197
- _tied_weights_keys = ["lm_head.weight"]
1198
- config: Step3p6Config
1199
-
1200
- def __init__(self, config: Step3p6Config):
1201
- super().__init__(config)
1202
- self.model = Step3p6Model(config)
1203
- self.lm_head = nn.Linear(config.hidden_size,
1204
- config.text_config.vocab_size,
1205
- bias=False)
1206
-
1207
- self.post_init()
1208
-
1209
- def get_input_embeddings(self):
1210
- return self.model.get_input_embeddings()
1211
-
1212
- def set_input_embeddings(self, value):
1213
- self.model.set_input_embeddings(value)
1214
-
1215
- def get_output_embeddings(self):
1216
- return self.model.get_output_embeddings()
1217
-
1218
- def set_output_embeddings(self, new_embeddings):
1219
- self.model.set_output_embeddings(new_embeddings)
1220
-
1221
- def set_decoder(self, decoder):
1222
- self.model.set_decoder(decoder)
1223
-
1224
- def get_decoder(self):
1225
- return self.model.get_decoder()
1226
-
1227
- @property
1228
- def language_model(self):
1229
- return self.model.language_model
1230
-
1231
- @property
1232
- def visual(self):
1233
- return self.model.vision_model
1234
-
1235
- def forward(
1236
- self,
1237
- input_ids: torch.LongTensor = None,
1238
- pixel_values: Optional[torch.Tensor] = None,
1239
- num_patches=None,
1240
- patch_pixel_values=None,
1241
- patch_newline_mask=None,
1242
- image_embeds: Optional[torch.FloatTensor] = None,
1243
- attention_mask: Optional[torch.Tensor] = None,
1244
- position_ids: Optional[torch.LongTensor] = None,
1245
- past_key_values: Optional[Cache] = None,
1246
- inputs_embeds: Optional[torch.FloatTensor] = None,
1247
- labels: Optional[torch.LongTensor] = None,
1248
- use_cache: Optional[bool] = None,
1249
- output_attentions: Optional[bool] = None,
1250
- output_hidden_states: Optional[bool] = None,
1251
- return_dict: Optional[bool] = None,
1252
- cache_position: Optional[torch.LongTensor] = None,
1253
- **kwargs: Unpack[TransformersKwargs],
1254
- ) -> Union[tuple, Step3p6CausalLMOutputWithPast]:
1255
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1256
- output_hidden_states = (
1257
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1258
- )
1259
-
1260
- outputs = self.model(
1261
- input_ids=input_ids,
1262
- num_patches = num_patches,
1263
- patch_pixel_values = patch_pixel_values,
1264
- patch_newline_mask=patch_newline_mask,
1265
- position_ids=position_ids,
1266
- attention_mask=attention_mask,
1267
- past_key_values=past_key_values,
1268
- inputs_embeds=inputs_embeds,
1269
- use_cache=use_cache,
1270
- output_attentions=output_attentions,
1271
- output_hidden_states=output_hidden_states,
1272
- return_dict=return_dict,
1273
- cache_position=cache_position,
1274
- **kwargs,
1275
- )
1276
-
1277
- hidden_states = outputs.last_hidden_state
1278
- logits = self.lm_head(hidden_states)
1279
-
1280
- los = None
1281
- if labels is not None:
1282
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
1283
-
1284
- return Step3p6CausalLMOutputWithPast(
1285
- logits=logits,
1286
- )
1287
-
1288
-
1289
- def prepare_inputs_for_generation(
1290
- self,
1291
- input_ids,
1292
- past_key_values=None,
1293
- inputs_embeds=None,
1294
- pixel_values=None,
1295
- patch_pixel_values=None,
1296
- num_patches=None,
1297
- image_embeds=None,
1298
- attention_mask=None,
1299
- cache_position=None,
1300
- logits_to_keep=None,
1301
- **kwargs,
1302
- ):
1303
- model_inputs = super().prepare_inputs_for_generation(
1304
- input_ids,
1305
- past_key_values=past_key_values,
1306
- inputs_embeds=inputs_embeds,
1307
- attention_mask=attention_mask,
1308
- cache_position=cache_position,
1309
- logits_to_keep=logits_to_keep,
1310
- **kwargs,
1311
- )
1312
-
1313
- if cache_position[0] == 0:
1314
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
1315
- # Otherwise we need pixel values to be passed to model
1316
- model_inputs["pixel_values"] = pixel_values
1317
-
1318
- return model_inputs
1319
-
1320
- def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1321
- if key.startswith("language_model."):
1322
- return key[len("language_model."):], True
1323
-
1324
- return key, False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_step3p7.py CHANGED
@@ -199,36 +199,40 @@ class Step3p7PreTrainedModel(PreTrainedModel):
199
  class Step3p7RotaryEmbedding(nn.Module):
200
  def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
  super().__init__()
202
- # BC: "rope_type" was originally "type"
203
  self.layer_idx = layer_idx
204
- self.original_rope_parameters = None
205
- if config.rope_parameters is not None:
206
- self.original_rope_parameters = config.rope_parameters
207
- config.rope_parameters = dict(config.rope_parameters)
208
- self.rope_type = config.rope_parameters.get(
209
- "rope_type", config.rope_parameters.get("type")
210
- )
211
- else:
212
- self.rope_type = "default"
213
  self.max_seq_len_cached = config.max_position_embeddings
214
  self.original_max_seq_len = config.max_position_embeddings
215
 
216
- partial_rotary_factors = getattr(
217
- config, "partial_rotary_factors", None
218
- )
 
 
 
219
  if partial_rotary_factors is not None:
220
- config.partial_rotary_factor = partial_rotary_factors[self.layer_idx]
221
- else:
222
- config.partial_rotary_factor = 1.0
223
 
224
- self.rope_theta = config.rope_theta
225
- if isinstance(config.rope_theta, list):
226
- self.rope_theta = config.rope_theta.copy()
227
- config.rope_theta = self.rope_theta[self.layer_idx]
228
 
229
  self.config = copy.copy(config)
 
 
 
230
  if config.rope_parameters is not None:
231
- self.config.rope_parameters = dict(config.rope_parameters)
 
 
 
 
 
 
 
 
 
 
232
  self.rope_init_fn = self.compute_default_rope_parameters
233
  if self.rope_type != "default":
234
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -238,8 +242,6 @@ class Step3p7RotaryEmbedding(nn.Module):
238
 
239
  self.register_buffer("inv_freq", inv_freq, persistent=False)
240
  self.original_inv_freq = self.inv_freq
241
- config.rope_theta = self.rope_theta
242
- config.rope_parameters = self.original_rope_parameters
243
 
244
  @torch.no_grad()
245
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -288,10 +290,14 @@ class Step3p7RotaryEmbedding(nn.Module):
288
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
289
  """
290
  base = config.rope_theta
291
- dim = (
 
 
 
292
  getattr(config, "head_dim", None)
293
  or config.hidden_size // config.num_attention_heads
294
  )
 
295
 
296
  attention_factor = 1.0 # Unused in this type of RoPE
297
 
@@ -968,7 +974,6 @@ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
968
  mask_kwargs = {
969
  "config": self.config,
970
  "attention_mask": attention_mask,
971
- "cache_position": cache_position,
972
  "past_key_values": past_key_values,
973
  "position_ids": position_ids,
974
  }
@@ -1381,7 +1386,12 @@ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1381
  **kwargs,
1382
  )
1383
 
1384
- if cache_position[0] == 0:
 
 
 
 
 
1385
  # During cached decoding, input ids no longer contain image tokens,
1386
  # so pixel values should only be passed at the first step.
1387
  model_inputs["pixel_values"] = pixel_values
@@ -1392,4 +1402,4 @@ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1392
  if key.startswith("language_model."):
1393
  return key[len("language_model.") :], True
1394
 
1395
- return key, False
 
199
  class Step3p7RotaryEmbedding(nn.Module):
200
  def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
  super().__init__()
 
202
  self.layer_idx = layer_idx
 
 
 
 
 
 
 
 
 
203
  self.max_seq_len_cached = config.max_position_embeddings
204
  self.original_max_seq_len = config.max_position_embeddings
205
 
206
+ rope_theta = config.rope_theta
207
+ if isinstance(rope_theta, list):
208
+ rope_theta = rope_theta[0 if layer_idx is None else layer_idx]
209
+
210
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
211
+ partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
212
  if partial_rotary_factors is not None:
213
+ partial_rotary_factor = partial_rotary_factors[
214
+ 0 if layer_idx is None else layer_idx
215
+ ]
216
 
217
+ self.rope_theta = rope_theta
218
+ self.partial_rotary_factor = partial_rotary_factor
 
 
219
 
220
  self.config = copy.copy(config)
221
+ self.config.rope_theta = rope_theta
222
+ self.config.partial_rotary_factor = partial_rotary_factor
223
+
224
  if config.rope_parameters is not None:
225
+ self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
226
+ self.config.rope_parameters["rope_theta"] = rope_theta
227
+ self.config.rope_parameters["partial_rotary_factor"] = (
228
+ partial_rotary_factor
229
+ )
230
+ self.rope_type = self.config.rope_parameters.get(
231
+ "rope_type", self.config.rope_parameters.get("type")
232
+ )
233
+ else:
234
+ self.rope_type = "default"
235
+
236
  self.rope_init_fn = self.compute_default_rope_parameters
237
  if self.rope_type != "default":
238
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
242
 
243
  self.register_buffer("inv_freq", inv_freq, persistent=False)
244
  self.original_inv_freq = self.inv_freq
 
 
245
 
246
  @torch.no_grad()
247
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
 
290
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
291
  """
292
  base = config.rope_theta
293
+ partial_rotary_factor = getattr(
294
+ config, "partial_rotary_factor", 1.0
295
+ )
296
+ head_dim = (
297
  getattr(config, "head_dim", None)
298
  or config.hidden_size // config.num_attention_heads
299
  )
300
+ dim = int(head_dim * partial_rotary_factor)
301
 
302
  attention_factor = 1.0 # Unused in this type of RoPE
303
 
 
974
  mask_kwargs = {
975
  "config": self.config,
976
  "attention_mask": attention_mask,
 
977
  "past_key_values": past_key_values,
978
  "position_ids": position_ids,
979
  }
 
1386
  **kwargs,
1387
  )
1388
 
1389
+ generation_cache_position = model_inputs.get("cache_position", cache_position)
1390
+ is_prefill = past_key_values is None
1391
+ if generation_cache_position is not None and generation_cache_position.numel() > 0:
1392
+ is_prefill = generation_cache_position[0].item() == 0
1393
+
1394
+ if is_prefill:
1395
  # During cached decoding, input ids no longer contain image tokens,
1396
  # so pixel values should only be passed at the first step.
1397
  model_inputs["pixel_values"] = pixel_values
 
1402
  if key.startswith("language_model."):
1403
  return key[len("language_model.") :], True
1404
 
1405
+ return key, False
processing_step3.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from transformers.tokenization_utils_tokenizers import TokenizersBackend
20
+ from math import ceil
21
+ from itertools import product
22
+
23
+
24
+
25
+ MAX_IMAGE_SIZE: int = 3024
26
+
27
+ class Step3VLImagePixelInputs(TypedDict):
28
+ type: Literal["pixel_values"]
29
+ pixel_values: torch.Tensor
30
+ patch_pixel_values: Optional[torch.Tensor]
31
+ num_patches: list[int]
32
+
33
+
34
+ class Step3VLImageEmbeddingInputs(TypedDict):
35
+ type: Literal["image_embeds"]
36
+ image_embeds: torch.Tensor
37
+
38
+
39
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
40
+
41
+
42
+ class GPUToTensor(torch.nn.Module):
43
+
44
+ def forward(self, raw_image: Union[np.ndarray,
45
+ Image.Image]) -> torch.Tensor:
46
+ if isinstance(raw_image, Image.Image):
47
+ return transforms.ToTensor()(raw_image)
48
+ if raw_image.ndim == 2:
49
+ raw_image = raw_image[:, :, None].repeat(3, -1)
50
+ if torch.cuda.is_available():
51
+ device = torch.device("cuda")
52
+ else:
53
+ device = torch.device("cpu")
54
+ image_tensor = torch.from_numpy(raw_image).to(device)
55
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
56
+ if image_tensor.dtype == torch.uint8:
57
+ image_tensor = image_tensor.to(torch.float32).div(255)
58
+ return image_tensor
59
+
60
+ class Step3VisionProcessor(BaseImageProcessor):
61
+
62
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
63
+ mean = [0.48145466, 0.4578275, 0.40821073]
64
+ std = [0.26862954, 0.26130258, 0.27577711]
65
+ patch_size = patch_size if patch_size is not None else size
66
+
67
+ self.transform = transforms.Compose([
68
+ GPUToTensor(),
69
+ transforms.Normalize(mean, std),
70
+ transforms.Resize(
71
+ (size, size),
72
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
73
+ == "bicubic" else InterpolationMode.BILINEAR,
74
+ antialias=True),
75
+ ])
76
+
77
+ self.patch_transform = transforms.Compose([
78
+ GPUToTensor(),
79
+ transforms.Normalize(mean, std),
80
+ transforms.Resize(
81
+ (patch_size, patch_size),
82
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
83
+ == "bicubic" else InterpolationMode.BILINEAR,
84
+ antialias=True),
85
+ ]) if patch_size is not None else None
86
+
87
+ def __call__(self, image, is_patch=False):
88
+ if is_patch:
89
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
90
+ else:
91
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
92
+
93
+ class ImagePatcher:
94
+ def determine_window_size(self, long: int, short: int) -> int:
95
+ if long <= 728:
96
+ return short if long / short > 1.5 else 0
97
+ return min(short, 504) if long / short > 4 else 504
98
+ def slide_window(
99
+ self,
100
+ width: int,
101
+ height: int,
102
+ sizes: list[tuple[int, int]],
103
+ steps: list[tuple[int, int]],
104
+ img_rate_thr: float = 0.6,
105
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
106
+ assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
107
+ windows = []
108
+ # Sliding windows.
109
+ for size, step in zip(sizes, steps):
110
+ size_w, size_h = size
111
+ step_w, step_h = step
112
+
113
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
114
+ 1)
115
+ x_start = [step_w * i for i in range(x_num)]
116
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
117
+ x_start[-1] = width - size_w
118
+
119
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
120
+ step_h + 1)
121
+ y_start = [step_h * i for i in range(y_num)]
122
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
123
+ y_start[-1] = height - size_h
124
+
125
+ start = np.array(list(product(y_start, x_start)), dtype=int)
126
+ start[:, [0, 1]] = start[:, [1, 0]]
127
+ windows.append(np.concatenate([start, start + size], axis=1))
128
+ windows = np.concatenate(windows, axis=0)
129
+
130
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
131
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
132
+
133
+ def square_pad(self, img: Image.Image) -> Image.Image:
134
+ w, h = img.size
135
+ if w == h:
136
+ return img
137
+ size = max(w, h)
138
+ padded = Image.new(img.mode, (size, size), 0)
139
+ padded.paste(img, (0, 0))
140
+ return padded
141
+
142
+ def get_image_size_for_padding(self, img_width: int,
143
+ img_height: int) -> tuple[int, int]:
144
+ ratio = img_width / img_height
145
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
146
+ new_size = max(img_height, img_width)
147
+ return new_size, new_size
148
+ return img_width, img_height
149
+
150
+ def get_image_size_for_preprocess(self, img_width: int,
151
+ img_height: int) -> tuple[int, int]:
152
+
153
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
154
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
155
+ img_width = int(img_width * scale_factor)
156
+ img_height = int(img_height * scale_factor)
157
+ return img_width, img_height
158
+
159
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
160
+ window_size: int):
161
+ w_ratio = img_width / window_size
162
+ h_ratio = img_height / window_size
163
+
164
+ if w_ratio < 1:
165
+ width_new = img_width
166
+ else:
167
+ decimal_w = w_ratio - img_width // window_size
168
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
169
+ width_new = window_size * w_ratio
170
+ if h_ratio < 1:
171
+ height_new = img_height
172
+ else:
173
+ decimal_h = h_ratio - img_height // window_size
174
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
175
+ height_new = window_size * h_ratio
176
+ return int(width_new), int(height_new)
177
+
178
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
179
+ target = img.crop((j, i, j + tw, i + th))
180
+ return target
181
+
182
+ def get_num_patches(self, img_width: int,
183
+ img_height: int) -> tuple[int, int]:
184
+ img_width, img_height = self.get_image_size_for_padding(
185
+ img_width, img_height)
186
+ img_width, img_height = self.get_image_size_for_preprocess(
187
+ img_width, img_height)
188
+ window_size = self.determine_window_size(max(img_height, img_width),
189
+ min(img_height, img_width))
190
+ if window_size == 0:
191
+ return 0, 0
192
+ else:
193
+ img_width, img_height = self.get_image_size_for_crop(
194
+ img_width, img_height, window_size)
195
+ center_list, (x_num, y_num) = self.slide_window(
196
+ img_width, img_height, [(window_size, window_size)],
197
+ [(window_size, window_size)])
198
+ full_rows = (len(center_list) - 1) // x_num + 1
199
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
200
+ full_rows -= 1
201
+ return len(center_list), full_rows
202
+
203
+ def __call__(
204
+ self, img: Image.Image
205
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
206
+ img_width, img_height = img.size
207
+ new_img_width, new_img_height = self.get_image_size_for_padding(
208
+ img_width, img_height)
209
+ if new_img_width != img_width or new_img_height != img_height:
210
+ img = self.square_pad(img)
211
+ img_width, img_height = img.size
212
+
213
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
214
+ img_width, img_height)
215
+ img = img.resize((new_img_width, new_img_height),
216
+ Image.Resampling.BILINEAR)
217
+ window_size = self.determine_window_size(
218
+ max(new_img_height, new_img_width),
219
+ min(new_img_height, new_img_width))
220
+ # return img, [], None
221
+ if window_size == 0:
222
+ return img, [], None
223
+ else:
224
+ new_img_width, new_img_height = self.get_image_size_for_crop(
225
+ new_img_width, new_img_height, window_size)
226
+ if (new_img_width, new_img_height) != (img_width, img_height):
227
+ img_for_crop = img.resize((new_img_width, new_img_height),
228
+ Image.Resampling.BILINEAR)
229
+ else:
230
+ img_for_crop = img
231
+
232
+ patches = []
233
+ newlines = []
234
+ center_list, (x_num, y_num) = self.slide_window(
235
+ new_img_width, new_img_height, [(window_size, window_size)],
236
+ [(window_size, window_size)])
237
+ for patch_id, center_lf_point in enumerate(center_list):
238
+ x, y, patch_w, patch_h = center_lf_point
239
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
240
+ patch_w)
241
+ patches.append(big_patch)
242
+ if (patch_id + 1) % x_num == 0:
243
+ newlines.append(patch_id)
244
+
245
+ if newlines and newlines[-1] == len(patches) - 1:
246
+ newlines.pop()
247
+
248
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
249
+
250
+
251
+
252
+
253
+ class Step3VLProcessor(ProcessorMixin):
254
+ # Align ProcessorMixin with our custom components.
255
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
256
+ attributes = ["tokenizer"]
257
+ tokenizer_class = "AutoTokenizer"
258
+
259
+ @classmethod
260
+ def _load_tokenizer_from_pretrained(
261
+ cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
262
+ ):
263
+ return TokenizersBackend.from_pretrained(
264
+ pretrained_model_name_or_path,
265
+ subfolder=subfolder,
266
+ **kwargs,
267
+ )
268
+
269
+ def __init__(
270
+ self,
271
+ tokenizer=None,
272
+ chat_template=None,
273
+ **kwargs
274
+ ) -> None:
275
+ self.image_size = 728
276
+ self.patch_size = 504
277
+
278
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
279
+ "bilinear",
280
+ self.patch_size)
281
+
282
+ self.num_image_feature_size = 169
283
+ self.num_patch_feature_size = 81
284
+ self.image_token = "<im_patch>"
285
+ self.image_feature_placeholder = (self.image_token *
286
+ self.num_image_feature_size)
287
+ self.patch_feature_placeholder = (self.image_token *
288
+ self.num_patch_feature_size)
289
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
290
+ self.patcher = ImagePatcher()
291
+
292
+ @property
293
+ def image_token_id(self) -> int:
294
+ return self.tokenizer.get_vocab()[self.image_token]
295
+
296
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
297
+ num_patches, num_newlines = self.patcher.get_num_patches(
298
+ img_width, img_height)
299
+
300
+ return num_patches * (
301
+ self.num_patch_feature_size +
302
+ 2) + self.num_image_feature_size + 2 + num_newlines
303
+
304
+ def _split_images(self,
305
+ images: list[Image.Image]) -> list[ImageWithPatches]:
306
+ result = []
307
+ for img in images:
308
+ result.append(self.patcher(img))
309
+ return result
310
+
311
+ def _convert_images_to_pixel_values(
312
+ self,
313
+ images: list[Image.Image],
314
+ is_patch: bool = False,
315
+ ) -> list[torch.Tensor]:
316
+ return [
317
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
318
+ for img in images
319
+ ]
320
+
321
+ def _get_patch_repl(
322
+ self,
323
+ num_patches: int,
324
+ patch_newline_mask: list[bool] | None,
325
+ ) -> tuple[str, list[int]]:
326
+ text = ""
327
+ token_ids = []
328
+ for i in range(num_patches):
329
+ assert len(patch_newline_mask) == num_patches
330
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
331
+ token_ids.extend(
332
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
333
+ [self.image_token_id] * self.num_patch_feature_size +
334
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
335
+ if patch_newline_mask and patch_newline_mask[i]:
336
+ text += "<patch_newline>"
337
+ token_ids.append(
338
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
339
+ return text, token_ids
340
+
341
+ def _get_image_repl(
342
+ self,
343
+ num_images: int,
344
+ ) -> tuple[str, list[int]]:
345
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
346
+ token_ids = [
347
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
348
+ ] + [self.image_token_id] * self.num_image_feature_size + [
349
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
350
+ ]
351
+ return text * num_images, token_ids * num_images
352
+
353
+ def _get_image_repl_features(
354
+ self,
355
+ num_images: int,
356
+ num_patches: int,
357
+ patch_new_line_idx: Optional[list[bool]],
358
+ ) -> tuple[str, list[int]]:
359
+ if num_patches > 0:
360
+ patch_repl, patch_repl_ids = self._get_patch_repl(
361
+ num_patches, patch_new_line_idx)
362
+ else:
363
+ patch_repl = ""
364
+ patch_repl_ids = []
365
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
366
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
367
+
368
+ def replace_placeholder(self, text: str, placeholder: str,
369
+ repls: list[str]) -> str:
370
+ parts = text.split(placeholder)
371
+
372
+ if len(parts) - 1 != len(repls):
373
+ raise ValueError(
374
+ "The number of placeholders does not match the number of replacements." # noqa: E501
375
+ )
376
+
377
+ result = [parts[0]]
378
+ for i, repl in enumerate(repls):
379
+ result.append(repl)
380
+ result.append(parts[i + 1])
381
+
382
+ return "".join(result)
383
+
384
+ def __call__(
385
+ self,
386
+ text: Optional[Union[str, list[str]]] = None,
387
+ images: ImageInput | None = None,
388
+ return_tensors: Optional[Union[str, TensorType]] = None,
389
+ **kwargs,
390
+ ) -> BatchFeature:
391
+
392
+ if images is not None:
393
+ images = self.image_preprocessor.fetch_images(images)
394
+ if text is None:
395
+ text = []
396
+ if not isinstance(text, list):
397
+ text = [text]
398
+ if images is None:
399
+ images = []
400
+ elif not isinstance(images, list):
401
+ images = [images]
402
+ elif isinstance(images[0], list):
403
+ images = images[0]
404
+
405
+ if len(images) == 0:
406
+ image_inputs = {}
407
+ text_inputs = self.tokenizer(text)
408
+ else:
409
+ splitted_images_data = self._split_images(images)
410
+ pixel_values_lst = []
411
+ patch_pixel_values_lst = []
412
+ patch_newline_mask_lst = []
413
+ image_repl_str_lst = []
414
+ image_repl_ids_lst = []
415
+ num_patches = []
416
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
417
+ pixel_values_lst.extend(
418
+ self._convert_images_to_pixel_values([raw_img]))
419
+
420
+ if len(img_patches) > 0:
421
+ patch_pixel_values_lst.extend(
422
+ self._convert_images_to_pixel_values(img_patches,
423
+ is_patch=True))
424
+ num_patches.append(len(img_patches))
425
+
426
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
427
+ 1, len(img_patches), patch_newline_mask)
428
+ image_repl_str_lst.append(image_repl_str)
429
+ image_repl_ids_lst.extend(image_repl_ids)
430
+
431
+ if patch_newline_mask is not None:
432
+ patch_newline_mask_lst.extend(patch_newline_mask)
433
+
434
+ image_inputs = {
435
+ "pixel_values": torch.cat(pixel_values_lst),
436
+ "num_patches": num_patches,
437
+ }
438
+ if patch_pixel_values_lst:
439
+ image_inputs["patch_pixel_values"] = torch.cat(
440
+ patch_pixel_values_lst)
441
+ if patch_newline_mask_lst:
442
+ image_inputs["patch_newline_mask"] = torch.tensor(
443
+ patch_newline_mask_lst, dtype=torch.bool)
444
+
445
+ text = [
446
+ self.replace_placeholder(t, self.image_token,
447
+ image_repl_str_lst) for t in text
448
+ ]
449
+ text_inputs = self.tokenizer(text)
450
+
451
+ return BatchFeature(
452
+ {
453
+ **text_inputs,
454
+ **image_inputs,
455
+ },
456
+ tensor_type=return_tensors,
457
+ )
458
+
459
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
460
+ def batch_decode(self, *args, **kwargs):
461
+ """
462
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
463
+ refer to the docstring of this method for more information.
464
+ """
465
+ return self.tokenizer.batch_decode(*args, **kwargs)
466
+
467
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
468
+ def decode(self, *args, **kwargs):
469
+ """
470
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
471
+ the docstring of this method for more information.
472
+ """
473
+ return self.tokenizer.decode(*args, **kwargs)
474
+
475
+ __all__ = ["Step3VLProcessor"]