WinstonDeng committed on
Commit
c171c6a
·
verified ·
1 Parent(s): 457483d

add step-3.7-flash bf16 model libs

Browse files
configuration_step3p7.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Sequence, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
class StepRoboticsVisionEncoderConfig(PretrainedConfig):
    """Hyper-parameter container for the vision encoder.

    Each constructor argument is stored unchanged on the instance under the
    same name. The only derived value is the CLS-token flag:
    ``ues_cls_token`` is a misspelled legacy spelling of ``use_cls_token``
    and is consulted only when ``use_cls_token`` is left as ``None``. Both
    attributes always end up holding the same resolved value.
    """

    model_type = "perception_encoder"

    def __init__(
        self,
        width=1536,
        layers=47,
        heads=16,
        num_channels=3,
        image_size=728,
        mlp_ratio=8960 / 1536,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        ues_cls_token=False,
        use_cls_token: Optional[bool] = None,
        use_ln_pre=True,
        use_ln_post=False,
        use_abs_posemb=True,
        use_rope2d=True,
        ls_init_value=0.1,
        **kwargs,
    ):
        # The correctly spelled flag wins whenever it was given explicitly;
        # otherwise fall back to the legacy misspelled one.
        cls_flag = ues_cls_token if use_cls_token is None else use_cls_token

        # Transformer geometry.
        self.width = width
        self.layers = layers
        self.heads = heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

        # Keep both spellings in sync so either attribute can be read.
        self.ues_cls_token = cls_flag
        self.use_cls_token = cls_flag

        # Optional architectural switches.
        self.use_ln_pre = use_ln_pre
        self.ls_init_value = ls_init_value
        self.use_ln_post = use_ln_post
        self.use_abs_posemb = use_abs_posemb
        self.use_rope2d = use_rope2d

        # Remaining keyword arguments are handed to PretrainedConfig.
        super().__init__(**kwargs)
47
+
48
+
49
class Step3p7TextConfig(PretrainedConfig):
    """Configuration of the text (language-model) part of the checkpoint.

    Besides storing its arguments as attributes, the constructor:

    * pads or truncates the per-layer sequences (``layer_types``,
      ``swiglu_limits``, ``swiglu_limits_shared``, ``use_rope_layers``, a
      list-valued ``rope_theta`` and the ``partial_rotary_factors`` keyword)
      to ``num_hidden_layers`` entries via ``_normalize_per_layer_values``;
    * shallow-copies a dict ``rope_scaling`` so the caller's dict is not
      shared with the config;
    * resolves ``share_expert_dim`` from the legacy ``share_expert_dims``
      argument when the former is ``None`` (only ``share_expert_dim`` is
      stored);
    * re-applies an explicitly passed ``torch_dtype`` after the base-class
      initializer has run.
    """

    model_type = "step3p5"
    architectures = ["Step3p5ForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 4096,
        intermediate_size: int = 11264,
        num_attention_heads: int = 64,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 45,
        max_seq_len: int = 128000,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 1280,
        moe_num_experts: int = 288,
        moe_top_k: int = 8,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 128000,
        share_expert_dims: int = 1280,
        share_expert_dim: Optional[int] = None,
        head_dim: int = 128,
        norm_expert_weight: bool = True,
        layer_types: list[str] = None,
        sliding_window: Optional[int] = None,
        pad_token_id: int = 1,
        attention_dropout: float = 0.0,
        use_head_wise_attn_gate: bool = False,
        use_moe_router_bias: bool = False,
        moe_router_activation: str = "softmax",
        moe_router_scaling_factor: float = 1.0,
        need_fp32_gate: bool = False,
        attention_other_setting: Optional[dict[str, Any]] = None,
        swiglu_limits: Optional[list[Optional[float]]] = None,
        swiglu_limits_shared: Optional[list[Optional[float]]] = None,
        use_rope_layers: Optional[list[bool]] = None,
        yarn_only_types: Optional[list[str]] = None,
        moe_layers_enum: tuple[int] = (
            3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
            15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
        ),
        **kwargs,
    ) -> None:
        def per_layer(values):
            # Pad/truncate a per-layer sequence to the decoder depth.
            return _normalize_per_layer_values(values, num_hidden_layers)

        # Remember the requested dtype so it can be re-applied after the
        # base-class initializer (NOTE(review): presumably the base class may
        # rename or drop it - confirm against the installed transformers).
        torch_dtype = kwargs.get("torch_dtype")

        layer_types = per_layer(layer_types)
        swiglu_limits = per_layer(swiglu_limits)
        swiglu_limits_shared = per_layer(swiglu_limits_shared)
        # The key is always written back, so it is present (possibly None)
        # in the kwargs forwarded to the base class.
        kwargs["partial_rotary_factors"] = per_layer(
            kwargs.get("partial_rotary_factors"))
        if isinstance(rope_theta, list):
            rope_theta = per_layer(rope_theta)
        if isinstance(rope_scaling, dict):
            rope_scaling = dict(rope_scaling)
        if use_rope_layers:
            use_rope_layers = per_layer(use_rope_layers)
        share_expert_dim = (share_expert_dims
                            if share_expert_dim is None else share_expert_dim)

        # Dense transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts and rotary settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings
        self.share_expert_dim = share_expert_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        # Attention layout and routing switches.
        self.layer_types = layer_types
        self.sliding_window = sliding_window
        self.pad_token_id = pad_token_id
        self.attention_dropout = attention_dropout
        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        self.use_moe_router_bias = use_moe_router_bias
        self.moe_router_activation = moe_router_activation
        self.moe_router_scaling_factor = moe_router_scaling_factor
        self.need_fp32_gate = need_fp32_gate
        self.attention_other_setting = attention_other_setting
        self.swiglu_limits = swiglu_limits
        self.swiglu_limits_shared = swiglu_limits_shared
        self.use_rope_layers = use_rope_layers
        self.yarn_only_types = yarn_only_types

        super().__init__(**kwargs)
        if torch_dtype is not None:
            self.torch_dtype = torch_dtype

    def to_dict(self):
        """Return the base-class dict, with ``torch_dtype`` added when set."""
        serialized = super().to_dict()
        dtype = getattr(self, "torch_dtype", None)
        if dtype is not None:
            serialized["torch_dtype"] = dtype
        return serialized
155
+
156
+
157
+ def _normalize_per_layer_values(
158
+ values: Optional[Sequence[Any]],
159
+ num_hidden_layers: int,
160
+ ) -> Optional[list[Any]]:
161
+ if values is None:
162
+ return None
163
+ normalized = list(values)
164
+ if not normalized:
165
+ return normalized
166
+ if len(normalized) < num_hidden_layers:
167
+ normalized.extend([normalized[-1]] *
168
+ (num_hidden_layers - len(normalized)))
169
+ # Some checkpoints keep MTP/spec layer entries after the decoder layers.
170
+ # This config only builds num_hidden_layers decoder layers, and HF strict
171
+ # validation requires per-layer fields to match that decoder count.
172
+ return normalized[:num_hidden_layers]
173
+
174
class Step3p7Config(PretrainedConfig):
    """Top-level config combining a vision config and a text config.

    This loader is a compatibility shim for original Step VL checkpoints
    whose top-level config ``model_type`` is ``step3p7``.

    A ``rope_scaling`` entry passed at the top level is propagated to the
    text config when the text config does not carry its own value.
    ``hidden_size`` and ``max_position_embeddings`` are mirrored from the
    text config onto this object.
    """

    model_type = "step3p7"

    def __init__(
        self,
        vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
        understand_projector_stride: int = 2,
        projector_bias: bool = False,
        image_token_id: int = 151679,
        **kwargs,
    ) -> None:
        # Private copy of the top-level rope_scaling used to seed the text
        # config; copying keeps the two configs from sharing one dict.
        shared_rope_scaling = kwargs.get("rope_scaling")
        if isinstance(shared_rope_scaling, dict):
            shared_rope_scaling = dict(shared_rope_scaling)

        # Vision side: accept a ready config, a plain dict, or nothing.
        if isinstance(vision_config, dict):
            vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = StepRoboticsVisionEncoderConfig()
        self.vision_config = vision_config

        # Text side: same three cases, plus rope_scaling inheritance.
        if isinstance(text_config, dict):
            text_kwargs = dict(text_config)
            inherits = (shared_rope_scaling is not None
                        and "rope_scaling" not in text_kwargs)
            if inherits:
                text_kwargs["rope_scaling"] = shared_rope_scaling
            text_config = Step3p7TextConfig(**text_kwargs)
        elif text_config is None:
            text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
        elif shared_rope_scaling is not None and text_config.rope_scaling is None:
            # An existing config object is updated in place.
            text_config.rope_scaling = dict(shared_rope_scaling)
        self.text_config = text_config

        # The value forwarded to the base class gets its own copy as well.
        top_level_scaling = kwargs.get("rope_scaling")
        if isinstance(top_level_scaling, dict):
            kwargs["rope_scaling"] = dict(top_level_scaling)

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        self.hidden_size = text_config.hidden_size
        self.max_position_embeddings = text_config.max_position_embeddings
        self.image_token_id = image_token_id
        super().__init__(**kwargs)
modeling_step3p7.py ADDED
File without changes
processing_step3.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from math import ceil
20
+ from itertools import product
21
+
22
+
23
+
24
# Upper bound, in pixels, for the longer image side; larger images are
# scaled down to this size before any cropping happens.
MAX_IMAGE_SIZE: int = 3024
25
+
26
class Step3VLImagePixelInputs(TypedDict):
    """Image inputs supplied as raw pixel tensors."""

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]
    # Pixel tensor of the whole-image views.
    pixel_values: torch.Tensor
    # Pixel tensor of the cropped patches; None when there are no patches.
    patch_pixel_values: Optional[torch.Tensor]
    # Patch counts (NOTE(review): presumably one entry per image - confirm
    # against the model code that consumes this dict).
    num_patches: list[int]
31
+
32
+
33
class Step3VLImageEmbeddingInputs(TypedDict):
    """Image inputs supplied as an embedding tensor instead of pixels."""

    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]
    # Image embedding tensor.
    image_embeds: torch.Tensor
36
+
37
+
38
# Result of splitting one image: (resized full image, list of cropped
# patches, per-patch "a newline follows this patch" flags or None when the
# image produced no patches). The flags are booleans, matching what
# ImagePatcher.__call__ returns.
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
39
+
40
+
41
class GPUToTensor(torch.nn.Module):
    """Convert a PIL image or an HWC numpy array into a CHW tensor."""

    def forward(self, raw_image: Union[np.ndarray,
                                       Image.Image]) -> torch.Tensor:
        """Return ``raw_image`` as a channel-first tensor.

        PIL images are delegated to torchvision's ``ToTensor``. Numpy arrays
        are moved to CUDA when it is available (CPU otherwise), permuted from
        (H, W, C) to (C, H, W), and - only when the dtype is ``uint8`` -
        converted to ``float32`` and divided by 255. A 2-D array is first
        expanded to three identical channels.
        """
        if isinstance(raw_image, Image.Image):
            return transforms.ToTensor()(raw_image)

        array = raw_image
        if array.ndim == 2:
            # Grayscale: add a channel axis and replicate it three times.
            array = array[:, :, None].repeat(3, -1)

        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor = torch.from_numpy(array).to(target)
        tensor = torch.permute(tensor, (2, 0, 1)).contiguous()
        if tensor.dtype == torch.uint8:
            tensor = tensor.to(torch.float32).div(255)
        return tensor
58
+
59
class Step3VisionProcessor(BaseImageProcessor):
    """Tensor conversion, normalization and square resize for images.

    Two pipelines are built: ``transform`` resizes to ``size`` x ``size``
    and is used for whole images, ``patch_transform`` resizes to
    ``patch_size`` x ``patch_size`` and is used for cropped patches. Both
    normalize before resizing.
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        # Patches share the full-image resolution unless told otherwise.
        patch_size = size if patch_size is None else patch_size

        # Anything other than "bicubic" selects bilinear interpolation.
        if interpolation_mode == "bicubic":
            resample = InterpolationMode.BICUBIC
        else:
            resample = InterpolationMode.BILINEAR

        def build_pipeline(edge):
            # to-tensor -> normalize -> resize to an (edge, edge) square
            return transforms.Compose([
                GPUToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize((edge, edge),
                                  interpolation=resample,
                                  antialias=True),
            ])

        self.transform = build_pipeline(size)
        self.patch_transform = (build_pipeline(patch_size)
                                if patch_size is not None else None)

    def __call__(self, image, is_patch=False):
        """Return ``{"pixel_values": tensor}`` with a leading batch axis.

        ``is_patch`` selects the patch pipeline instead of the full-image
        pipeline.
        """
        pipeline = self.patch_transform if is_patch else self.transform
        return {"pixel_values": pipeline(image).unsqueeze(0)}
91
+
92
class ImagePatcher:
    """Split an image into a resized full view plus square window crops.

    The pipeline implemented by ``__call__`` (and mirrored, without touching
    pixels, by ``get_num_patches``) is:

    1. square-pad very small, very elongated images;
    2. scale the image down so its longer side is at most ``MAX_IMAGE_SIZE``;
    3. pick a window size from the side lengths (0 means "no patches");
    4. snap the image size to the window size and cut non-overlapping
       windows in row-major order.
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Return the crop window edge for the given side lengths.

        Returns 0 when no patches should be produced: the longer side is at
        most 728 and the aspect ratio is at most 1.5.
        """
        if long <= 728:
            return short if long / short > 1.5 else 0
        return min(short, 504) if long / short > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate sliding windows over a ``width`` x ``height`` area.

        Args:
            width: Area width in pixels.
            height: Area height in pixels.
            sizes: ``(window_w, window_h)`` pairs, one per scale.
            steps: ``(step_w, step_h)`` pairs matching ``sizes``.
            img_rate_thr: Only range-checked; not otherwise used here.

        Returns:
            ``(boxes, (x_num, y_num))``. ``boxes`` holds ``(x, y, w, h)``
            tuples in row-major order for every scale; ``(x_num, y_num)`` is
            the column/row count of the last scale.

        Raises:
            AssertionError: If ``img_rate_thr`` is outside ``[0, 1]``.
            ValueError: If ``sizes``/``steps`` yield no pair at all.
        """
        assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"

        boxes: list[tuple[int, int, int, int]] = []
        grid = None
        for (size_w, size_h), (step_w, step_h) in zip(sizes, steps):
            x_num = (1 if width <= size_w else
                     ceil((width - size_w) / step_w + 1))
            x_start = [step_w * i for i in range(x_num)]
            # Pull the last column back so it ends at the right border.
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            y_num = (1 if height <= size_h else
                     ceil((height - size_h) / step_h + 1))
            y_start = [step_h * i for i in range(y_num)]
            # Same adjustment for the last row and the bottom border.
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            # Row-major: every x position of a row before the next row.
            boxes.extend((int(x), int(y), int(size_w), int(size_h))
                         for y, x in product(y_start, x_start))
            grid = (x_num, y_num)

        if grid is None:
            raise ValueError("need at least one (size, step) pair")
        return boxes, grid

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` with zeros to a square, keeping it in the top-left."""
        w, h = img.size
        if w == h:
            return img
        size = max(w, h)
        padded = Image.new(img.mode, (size, size), 0)
        padded.paste(img, (0, 0))
        return padded

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        """Return the size after optional square padding.

        Padding applies only when the shorter side is below 32 pixels and
        the aspect ratio exceeds 4:1 in either direction.
        """
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            new_size = max(img_height, img_width)
            return new_size, new_size
        return img_width, img_height

    def get_image_size_for_preprocess(
            self,
            img_width: int,
            img_height: int,
            max_size: int | None = None) -> tuple[int, int]:
        """Return the size after limiting the longer side.

        Args:
            img_width: Current width in pixels.
            img_height: Current height in pixels.
            max_size: Limit for the longer side; defaults to the module
                constant ``MAX_IMAGE_SIZE``.

        Returns:
            ``(width, height)``, scaled down proportionally (truncating to
            integers) when the longer side exceeds the limit. Each side is
            kept at 1 pixel or more so the result is always a valid image
            size.
        """
        limit = MAX_IMAGE_SIZE if max_size is None else max_size
        longest = max(img_height, img_width)
        if longest > limit:
            scale_factor = limit / longest
            # Truncation could otherwise yield 0 for extremely thin images,
            # which cannot be resized to.
            img_width = max(1, int(img_width * scale_factor))
            img_height = max(1, int(img_height * scale_factor))
        return img_width, img_height

    @staticmethod
    def _snap_to_window(length: int, window_size: int) -> int:
        """Round ``length`` to a multiple of ``window_size``.

        Lengths shorter than one window are returned unchanged; otherwise
        the count is rounded up when the fractional part exceeds 0.2 and
        down otherwise.
        """
        ratio = length / window_size
        if ratio < 1:
            return length
        fraction = ratio - length // window_size
        count = int(ratio) + 1 if fraction > 0.2 else int(ratio)
        return window_size * count

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Return ``(width, height)`` with each side snapped to the window."""
        width_new = self._snap_to_window(img_width, window_size)
        height_new = self._snap_to_window(img_height, window_size)
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` region whose top-left is row ``i``, column ``j``."""
        target = img.crop((j, i, j + tw, i + th))
        return target

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Return ``(patch count, newline count)`` for an image of this size.

        Follows the same sizing steps as ``__call__`` without touching pixel
        data. The newline count is the number of patch rows minus one.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            img_width, img_height, [(window_size, window_size)],
            [(window_size, window_size)])
        full_rows = (len(center_list) - 1) // x_num + 1
        # No newline is emitted after the final row.
        if len(center_list) > 0 and len(center_list) % x_num == 0:
            full_rows -= 1
        return len(center_list), full_rows

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        """Split ``img`` into a full view and window patches.

        Returns:
            ``(image, patches, newline_mask)``. ``image`` is the padded and
            size-limited full view. ``patches`` is the row-major list of
            crops (empty when the window size is 0). ``newline_mask`` has
            one flag per patch that is True for the last patch of every row
            except the final row; it is ``None`` when there are no patches.
        """
        img_width, img_height = img.size
        padded_size = self.get_image_size_for_padding(img_width, img_height)
        if padded_size != (img_width, img_height):
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))
        if window_size == 0:
            return img, [], None

        crop_width, crop_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        # Compare against the image as it is now (after the size limit),
        # not against the dimensions it had before that resize.
        if (crop_width, crop_height) != img.size:
            img_for_crop = img.resize((crop_width, crop_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        center_list, (x_num, y_num) = self.slide_window(
            crop_width, crop_height, [(window_size, window_size)],
            [(window_size, window_size)])
        patches = [
            self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
            for x, y, patch_w, patch_h in center_list
        ]

        # A newline follows the last patch of each row, except after the
        # very last patch.
        last_index = len(patches) - 1
        newline_mask = [(index + 1) % x_num == 0 and index != last_index
                        for index in range(len(patches))]
        return img, patches, newline_mask if patches else None
248
+
249
+
250
+
251
+
252
class Step3VLProcessor(ProcessorMixin):
    """Turn text plus images into tokenizer output and pixel tensors.

    Every ``<im_patch>`` placeholder in the text is expanded into the token
    layout of one image: the image's patches first (each wrapped in
    ``<patch_start>``/``<patch_end>``, with ``<patch_newline>`` after the
    last patch of each row but the final one), followed by the full-image
    view wrapped in ``<im_start>``/``<im_end>``.
    """

    # Align ProcessorMixin with our custom components.
    # Only the tokenizer is registered with ProcessorMixin; the image
    # preprocessor and the patcher are plain attributes built in __init__.
    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,
        chat_template=None,
        **kwargs
    ) -> None:
        # Output resolutions of the full-image view and of each patch.
        self.image_size = 728
        self.patch_size = 504

        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Placeholder tokens emitted per full image / per patch.
        # NOTE(review): presumably these match the vision tower's output
        # token counts for the two resolutions - confirm against the model.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)
        super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the image placeholder token.

        Raises ``KeyError`` when the token is missing from the vocabulary.
        """
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Return how many tokens one image of this size expands to.

        Counts the patch tokens (each patch plus its two wrapper tokens),
        the full-image tokens plus their two wrapper tokens, and the
        newline tokens between patch rows.
        """
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)

        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Run the patcher on every image."""
        return [self.patcher(img) for img in images]

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Preprocess each image into a batched pixel tensor."""
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Build the replacement text and token ids for an image's patches.

        Args:
            num_patches: Number of patches of the image.
            patch_newline_mask: One flag per patch; True appends a
                ``<patch_newline>`` after that patch.

        Returns:
            ``(text, token_ids)``; both empty when ``num_patches`` is not
            positive.
        """
        if num_patches <= 0:
            return "", []
        assert (patch_newline_mask is not None
                and len(patch_newline_mask) == num_patches), (
                    "patch_newline_mask must provide one flag per patch")

        # Resolve the ids once: image_token_id rebuilds the vocabulary dict
        # on every access, so it must not be evaluated per patch.
        to_id = self.tokenizer.convert_tokens_to_ids
        newline_id = to_id("<patch_newline>")
        patch_text = f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
        patch_ids = ([to_id("<patch_start>")] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [to_id("<patch_end>")])

        text_parts: list[str] = []
        token_ids: list[int] = []
        for newline_follows in patch_newline_mask:
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if newline_follows:
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Build the full-image replacement, repeated ``num_images`` times."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        """Concatenate the patch replacement and the full-image replacement."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th occurrence of ``placeholder`` with ``repls[i]``.

        Raises:
            ValueError: If the number of occurrences differs from
                ``len(repls)``.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."  # noqa: E501
            )

        result = [parts[0]]
        for repl, tail in zip(repls, parts[1:]):
            result.append(repl)
            result.append(tail)

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: ImageInput | None = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images``.

        Args:
            text: One prompt or a list of prompts. Each prompt must contain
                one image placeholder per image, because every prompt is
                expanded with the full list of images.
            images: One image, a list of images, or a list whose first
                element is a list of images (only that first list is used).
                ``None`` or an empty list means text-only input.
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            A ``BatchFeature`` with the tokenizer output and, when images
            are present, ``pixel_values`` and ``num_patches`` plus - when
            any patches exist - ``patch_pixel_values`` and
            ``patch_newline_mask``.
        """
        if images is not None:
            images = self.image_preprocessor.fetch_images(images)
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]
        elif images and isinstance(images[0], list):
            # Guarded with `images and` so an empty list is treated as
            # "no images" instead of failing on images[0].
            images = images[0]

        if len(images) == 0:
            return BatchFeature(
                {**self.tokenizer(text)},
                tensor_type=return_tensors,
            )

        pixel_values_lst = []
        patch_pixel_values_lst = []
        patch_newline_mask_lst = []
        image_repl_str_lst = []
        num_patches = []
        for raw_img, img_patches, patch_newline_mask in self._split_images(images):  # noqa: E501
            pixel_values_lst.extend(
                self._convert_images_to_pixel_values([raw_img]))

            if len(img_patches) > 0:
                patch_pixel_values_lst.extend(
                    self._convert_images_to_pixel_values(img_patches,
                                                         is_patch=True))
            # Recorded for every image, including images without patches.
            num_patches.append(len(img_patches))

            image_repl_str, _ = self._get_image_repl_features(
                1, len(img_patches), patch_newline_mask)
            image_repl_str_lst.append(image_repl_str)

            if patch_newline_mask is not None:
                patch_newline_mask_lst.extend(patch_newline_mask)

        image_inputs = {
            "pixel_values": torch.cat(pixel_values_lst),
            "num_patches": num_patches,
        }
        if patch_pixel_values_lst:
            image_inputs["patch_pixel_values"] = torch.cat(
                patch_pixel_values_lst)
        if patch_newline_mask_lst:
            image_inputs["patch_newline_mask"] = torch.tensor(
                patch_newline_mask_lst, dtype=torch.bool)

        text = [
            self.replace_placeholder(t, self.image_token,
                                     image_repl_str_lst) for t in text
        ]
        text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )

    def batch_decode(self, *args, **kwargs):
        """
        Forward all arguments to the wrapped tokenizer's ``batch_decode`` method and return its result.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Forward all arguments to the wrapped tokenizer's ``decode`` method and return its result.
        """
        return self.tokenizer.decode(*args, **kwargs)
463
+
464
# Public API of this module: only the processor class is exported.
__all__ = ["Step3VLProcessor"]
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent pairs of the last dimension (used by RoPE).

    The last dimension is read as consecutive pairs ``(a, b)`` and each pair
    is replaced by ``(-b, a)``; the output has the same shape as ``x``.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pair, 2) axes back into one feature axis.
    return rotated.flatten(start_dim=-2)
19
+
20
+
21
def apply_rotary_emb(freqs: torch.Tensor,
                     t: torch.Tensor,
                     start_index: int = 0,
                     scale: float = 1.0,
                     seq_dim: int = -2) -> torch.Tensor:
    """Apply rotary embeddings to a slice of the feature dimension of ``t``.

    Args:
        freqs: Rotation angles; the size of its last dimension is the number
            of features that get rotated.
        t: Queries or keys.
        start_index: First feature index of the rotated slice.
        scale: Multiplier applied to both the cosine and the sine term.
        seq_dim: Sequence axis of ``t``; only consulted for 3-D inputs,
            where ``freqs`` is cut to its last ``seq_len`` entries.

    Returns:
        A tensor shaped like ``t`` and cast back to the dtype of ``t``;
        features outside the rotated slice are passed through.

    Raises:
        AssertionError: If ``t`` has fewer features than ``freqs``.
    """
    input_dtype = t.dtype

    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")

    untouched_head = t[..., :start_index]
    middle = t[..., start_index:end_index]
    untouched_tail = t[..., end_index:]

    cos_term = middle * freqs.cos() * scale
    sin_term = rotate_half(middle) * freqs.sin() * scale
    rotated = cos_term + sin_term

    merged = torch.cat((untouched_head, rotated, untouched_tail), dim=-1)
    return merged.type(input_dtype)
46
+
47
+
48
class EncoderRope2D(nn.Module):
    """Cacheable 2D rotary positional embedding.

    The rotation angles for the full ``max_grid_height`` x ``max_grid_width``
    grid are computed once and kept in the non-persistent buffer
    ``freqs_cache`` of shape ``(1, 1, positions, dim)``. With
    ``use_cls_token`` an all-zero row (no rotation) is prepended for the CLS
    token. ``max_freq`` and ``num_freqs`` are stored but not used by the
    computations in this class.
    """

    def __init__(
        self,
        dim: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        theta: Union[int, float] = 10000,
        max_freq: int = 10,
        num_freqs: int = 1,
        theta_rescale_factor: float = 1.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_grid_height = max_grid_height
        self.max_grid_width = max_grid_width
        self.use_cls_token = use_cls_token
        self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
        self.max_freq = max_freq
        self.num_freqs = num_freqs
        cache = self._compute_2d_freqs()
        self.register_buffer("freqs_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float],
                          dim: int) -> torch.Tensor:
        """Return ``dim // 2`` inverse frequencies ``base**(-2i/dim)``."""
        exponents = torch.arange(0, dim, 2)[:(dim // 2)].float() / dim
        return 1.0 / (base**exponents)

    def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
        """Outer product of positions and frequencies, each angle doubled.

        Every angle is repeated twice along the last axis so it lines up
        with the adjacent feature pairs rotated by ``rotate_half``.
        """
        freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
                             inv_freq)
        return freqs.repeat_interleave(2, dim=-1)

    def _compute_2d_freqs(self) -> torch.Tensor:
        """Build the angle table for the full grid (plus the CLS row)."""
        grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
        grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
        if self.use_cls_token:
            # Grid coordinates start at 1 when a CLS token is present.
            grid_h_range += 1
            grid_w_range += 1
        # Half of the head dimension is given to each spatial axis.
        inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
        freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
            self.max_grid_height, self.max_grid_width, -1)
        freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
            self.max_grid_height, self.max_grid_width, -1)
        # Width angles occupy the first half of the features, height angles
        # the second; rows are flattened in row-major order.
        freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
            self.max_grid_height * self.max_grid_width, -1)
        if self.use_cls_token:
            freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
        return freqs[None, None, ...]

    def _select_freqs(self, grid_hw: tuple[int, int],
                      device: torch.device) -> torch.Tensor:
        """Return the cached angles for a ``grid_hw`` sub-grid.

        The cache is returned as-is when the grid equals the maximum grid;
        otherwise the rows belonging to the top-left ``grid_hw`` region (and
        the CLS row, when enabled) are gathered from it.
        """
        grid_h, grid_w = grid_hw[0], grid_hw[1]
        if grid_h == self.max_grid_height and grid_w == self.max_grid_width:
            return self.freqs_cache

        rows = torch.arange(grid_h, device=device).view(-1, 1)
        cols = torch.arange(grid_w, device=device).view(1, -1)
        positions = (rows * self.max_grid_width + cols).reshape(-1).to(
            torch.long)
        if self.use_cls_token:
            # The CLS index must be created as int64: a default (float)
            # zeros tensor would promote the whole index to float in
            # torch.cat, and index_select only accepts integer indices.
            cls_position = torch.zeros(1, dtype=torch.long, device=device)
            positions = torch.cat([cls_position, positions + 1], dim=0)
        return self.freqs_cache.index_select(2, positions)

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                grid_hw: tuple[int, int]):
        """Rotate ``q`` and ``k`` with the angles of a ``grid_hw`` grid.

        Args:
            q: Query tensor.
            k: Key tensor.
            grid_hw: ``(height, width)`` of the patch grid, in patches.

        Returns:
            The rotated ``(q, k)`` pair.
        """
        freqs = self._select_freqs(grid_hw, q.device)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)
        return q, k
121
+
122
+
123
class EncoderLayerScale(nn.Module):
    """Learnable per-channel scale, used when ``ls_init_value`` is set."""

    def __init__(self, dim: int, init_values: float):
        super().__init__()
        # One scale per channel, all starting at ``init_values``.
        initial_scale = torch.full((dim,), init_values)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  # (B, L, D)
        """Multiply the last dimension of ``hidden_states`` by ``gamma``."""
        return torch.mul(hidden_states, self.gamma)
132
+
133
+
134
class EncoderMLP(nn.Module):
    """Two-layer feed-forward network: ``c_fc`` -> activation -> ``c_proj``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size between the two linear layers.
        hidden_act: Key into ``ACT2FN`` selecting the activation.
    """

    def __init__(self, hidden_size: int, intermediate_size: int,
                 hidden_act: str):
        super().__init__()
        # Both projections carry a bias (the nn.Linear default).
        self.c_fc = nn.Linear(in_features=hidden_size,
                              out_features=intermediate_size)
        self.act_fn = ACT2FN[hidden_act]
        self.c_proj = nn.Linear(in_features=intermediate_size,
                                out_features=hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand, activate, then project back to ``hidden_size``."""
        expanded = self.c_fc(hidden_states)
        activated = self.act_fn(expanded)
        return self.c_proj(activated)
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
class EncoderVisionBlock(nn.Module):
    """Pre-norm Vision Transformer block: self-attention followed by an MLP.

    Each sub-layer is applied as ``x + scale(sublayer(norm(x)))``. ``scale``
    is an ``EncoderLayerScale`` when ``ls_init_value`` is given and a no-op
    (``nn.Identity``) when it is ``None``.

    Args:
        hidden_size: Token feature size.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP intermediate size is ``int(hidden_size * mlp_ratio)``.
        hidden_act: Activation name passed to ``EncoderMLP``.
        layer_norm_eps: Epsilon for both layer norms.
        ls_init_value: Initial layer-scale value, or ``None`` to disable
            layer scaling.
        max_grid_height: Forwarded unchanged to ``EncoderVisionAttention``.
        max_grid_width: Forwarded unchanged to ``EncoderVisionAttention``.
        use_cls_token: Forwarded unchanged to ``EncoderVisionAttention``.
        use_rope2d: Forwarded unchanged to ``EncoderVisionAttention``.
        rope_kwargs: Extra keyword arguments for ``EncoderVisionAttention``;
            ``None`` is treated as an empty dict.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        hidden_act: str,
        layer_norm_eps: float,
        ls_init_value: Optional[float] = None,
        max_grid_height: Optional[int] = None,
        max_grid_width: Optional[int] = None,
        use_cls_token: bool = False,
        use_rope2d: bool = True,
        rope_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        rope_kwargs = rope_kwargs or {}
        self.attn = EncoderVisionAttention(
            hidden_size,
            num_heads,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            use_rope2d=use_rope2d,
            **rope_kwargs,
        )
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        intermediate = int(hidden_size * mlp_ratio)
        self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)

        # ``ls_init_value`` defaults to None; building EncoderLayerScale with
        # None would fail inside torch.full, so fall back to a pass-through.
        if ls_init_value is not None:
            self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
            self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)
        else:
            self.ls_1 = nn.Identity()
            self.ls_2 = nn.Identity()

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: tuple[int, int]) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Token features; the result has the same shape.
            grid_hw: Patch grid ``(height, width)`` passed to the attention.

        Returns:
            The updated token features.
        """
        # Attention sub-layer.
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
        hidden_states = residual + self.ls_1(hidden_states)

        # Feed-forward sub-layer.
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.ls_2(hidden_states)
        return hidden_states
267
+
268
+
269
class EncoderVisionTransformer(nn.Module):
    """Sequential stack of ``depth`` identically configured ``EncoderVisionBlock`` layers.

    Args:
        embed_dim: Token feature size of every block.
        depth: Number of blocks; also stored as ``self.layers``.
        num_heads: Attention heads per block.
        mlp_ratio: Passed to each block.
        hidden_act: Passed to each block.
        layer_norm_eps: Passed to each block.
        ls_init_value: Passed to each block.
        max_grid_height: Passed to each block.
        max_grid_width: Passed to each block.
        use_cls_token: Passed to each block.
        use_rope2d: Passed to each block.
        rope_kwargs: Passed to each block; ``None`` becomes an empty dict.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        hidden_act: str,
        layer_norm_eps: float,
        ls_init_value: Optional[float] = None,
        max_grid_height: Optional[int] = None,
        max_grid_width: Optional[int] = None,
        use_cls_token: bool = False,
        use_rope2d: bool = True,
        rope_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        self.layers = depth

        # Options shared by every block in the stack.
        block_options = dict(
            ls_init_value=ls_init_value,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            use_rope2d=use_rope2d,
            rope_kwargs=rope_kwargs or {},
        )
        self.resblocks = nn.ModuleList(
            EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
                               layer_norm_eps, **block_options)
            for _ in range(depth)
        )

    def forward(self,
                hidden_states: torch.Tensor,
                grid_hw: tuple[int, int]) -> torch.Tensor:
        """Feed ``hidden_states`` through every block in order.

        Args:
            hidden_states: Token features.
            grid_hw: Patch grid ``(height, width)`` handed to each block.

        Returns:
            The output of the last block (the input itself when the stack
            is empty).
        """
        output = hidden_states
        for layer in self.resblocks:
            output = layer(output, grid_hw=grid_hw)
        return output
308
+
309
+
310
+ class StepRoboticsVisionEncoder(nn.Module):
311
+ """
312
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
313
+
314
+ The encoder performs patch embedding followed by a stack of transformer
315
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
316
+ StepRoboticVLConfig.vision_config) are expected.
317
+ """
318
+
319
def __init__(self, config: StepRoboticsVisionEncoderConfig):
    """Build patch embedding, positional parameters, transformer and downsamplers.

    Args:
        config: Vision configuration. ``width``, ``heads``, ``layers``,
            ``patch_size``, ``image_size``, ``num_channels``,
            ``layer_norm_eps`` and ``hidden_act`` are read directly; all
            other options are read with ``getattr`` and a fallback value.
    """
    super().__init__()
    self.config = config

    # Mirror the config under the attribute names downstream code
    # (e.g. StepRoboticVL) expects, so no renaming is needed there.
    self.hidden_size = config.width
    self.num_heads = config.heads
    self.num_hidden_layers = config.layers
    self.patch_size = config.patch_size
    self.image_size = config.image_size
    self.layer_norm_eps = config.layer_norm_eps
    self.hidden_act = config.hidden_act

    # Optional fields with fallbacks.
    # NOTE(review): the fallbacks for use_ln_pre / use_ln_post are the
    # opposite of the defaults declared in StepRoboticsVisionEncoderConfig;
    # they only apply when the config object lacks the attribute.
    self.use_cls_token = getattr(config, "use_cls_token", False)
    self.use_rope2d = getattr(config, "use_rope2d", True)
    self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
    self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
    self.ls_init_value = getattr(config, "ls_init_value", None)
    self.use_ln_pre = getattr(config, "use_ln_pre", False)
    self.use_ln_post = getattr(config, "use_ln_post", True)

    # Patch embedding: one non-overlapping patch per output position.
    self.conv1 = nn.Conv2d(
        in_channels=config.num_channels,
        out_channels=self.hidden_size,
        kernel_size=self.patch_size,
        stride=self.patch_size,
        bias=False,
    )

    # Optional layer norms before / after the transformer.
    if self.use_ln_pre:
        self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
    else:
        self.ln_pre = nn.Identity()
    if self.use_ln_post:
        self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
    else:
        self.ln_post = nn.Identity()

    patches_per_side = self.image_size // self.patch_size
    self.base_grid = (patches_per_side, patches_per_side)

    self.class_embedding = (
        nn.Parameter(torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
        if self.use_cls_token else None
    )

    if self.use_abs_posemb:
        # One learned embedding per base-grid position, plus one for the
        # class token when it is enabled.
        self.posemb_grid_size = self.image_size // self.patch_size
        num_positions = int(self.use_cls_token) + self.posemb_grid_size**2
        self.positional_embedding = nn.Parameter(
            (self.hidden_size**-0.5)
            * torch.randn(num_positions, self.hidden_size))

    # Rotary options are optional on the config; fall back to defaults.
    rope_defaults = {
        "rope_theta": 10000,
        "rope_max_freq": 10,
        "rope_num_freqs": 1,
        "rope_theta_rescale_factor": 1.0,
        "rope_freqs_for": "lang",
    }
    rope_kwargs = {
        option: getattr(config, option, fallback)
        for option, fallback in rope_defaults.items()
    }
    base_h, base_w = self.base_grid
    self.transformer = EncoderVisionTransformer(
        embed_dim=self.hidden_size,
        depth=self.num_hidden_layers,
        num_heads=self.num_heads,
        mlp_ratio=self.mlp_ratio,
        hidden_act=self.hidden_act,
        layer_norm_eps=self.layer_norm_eps,
        ls_init_value=self.ls_init_value,
        max_grid_height=base_h,
        max_grid_width=base_w,
        use_cls_token=self.use_cls_token,
        use_rope2d=self.use_rope2d,
        rope_kwargs=rope_kwargs,
    )

    # Two stride-2 convolutions that double the channel count each time.
    # NOTE(review): not called by this class's forward as visible here;
    # presumably applied by the caller - confirm.
    self.vit_downsampler1 = nn.Conv2d(
        in_channels=self.hidden_size,
        out_channels=self.hidden_size * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    self.vit_downsampler2 = nn.Conv2d(
        in_channels=self.hidden_size * 2,
        out_channels=self.hidden_size * 4,
        kernel_size=3,
        stride=2,
        padding=1,
    )
398
+
399
+
400
def sample_abs_posemb(self, grid_h: int, grid_w: int):
    """Return the absolute positional embedding for a ``grid_h`` x ``grid_w`` grid.

    When the requested grid equals the learned ``posemb_grid_size`` square
    the stored embedding is returned unchanged. Otherwise the patch part is
    bilinearly resampled to the requested grid; a class-token embedding,
    when present, is kept as the first row and is not interpolated.

    Args:
        grid_h: Number of patch rows.
        grid_w: Number of patch columns.

    Returns:
        Tensor with a leading singleton axis, followed by one row per token
        and ``hidden_size`` features.
    """
    base = self.posemb_grid_size
    table = self.positional_embedding

    if grid_h == base and grid_w == base:
        return table[None, ...]

    cls_row = None
    if self.use_cls_token:
        cls_row = table[:1]
        table = table[1:]

    # (base*base, D) -> (1, D, base, base), the layout F.interpolate expects.
    as_image = table.reshape(base, base, -1).permute(2, 0, 1)
    as_image = as_image.unsqueeze(0).contiguous()
    resized = F.interpolate(as_image,
                            size=(grid_h, grid_w),
                            mode="bilinear",
                            align_corners=False)

    # (1, D, grid_h, grid_w) -> (grid_h*grid_w, D), row-major over the grid.
    table = resized.flatten(2).transpose(1, 2).reshape(-1, self.hidden_size)

    if cls_row is not None:
        table = torch.cat([cls_row, table], dim=0)

    return table[None, ...]
421
+
422
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Encode a batch of images into per-patch features.

    Args:
        pixel_values: Image tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, Gh*Gw, D) with one row per patch, where
        ``Gh = H // patch_size`` and ``Gw = W // patch_size``. When a class
        token is used it takes part in the transformer but is removed
        from the returned sequence.
    """
    batch_size, _, img_h, img_w = pixel_values.shape
    grid_hw = (img_h // self.patch_size, img_w // self.patch_size)

    # Patch embedding: (B, D, Gh, Gw) -> (B, Gh*Gw, D).
    tokens = self.conv1(pixel_values).flatten(2).transpose(1, 2)

    if self.use_cls_token:
        # Prepend the shared class embedding to every sequence.
        cls_tokens = self.class_embedding.view(1, 1, -1)
        cls_tokens = cls_tokens.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)

    if self.use_abs_posemb:
        tokens = tokens + self.sample_abs_posemb(*grid_hw)

    tokens = self.transformer(self.ln_pre(tokens), grid_hw=grid_hw)

    if self.use_ln_post:
        tokens = self.ln_post(tokens)

    # Drop the class token so only patch tokens are returned.
    return tokens[:, 1:, :] if self.use_cls_token else tokens