luotingdan
committed on
Commit
·
f7bf7c1
1
Parent(s):
82465dc
remove some unuse code
Browse files- modeling_step_vl.py +0 -4
- vision_encoder.py +8 -24
modeling_step_vl.py
CHANGED
|
@@ -290,10 +290,6 @@ class StepRoboticsModel(StepRoboticsPreTrainedModel, GenerationMixin):
|
|
| 290 |
def _get_vision_model_output(self,
|
| 291 |
input_tensor: torch.Tensor) -> torch.Tensor:
|
| 292 |
return self.vision_model(input_tensor)
|
| 293 |
-
|
| 294 |
-
def _get_pooled_vision_model_output(
|
| 295 |
-
self, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 296 |
-
return self.vision_model.pool(input_tensor)
|
| 297 |
|
| 298 |
def _process_image_input(
|
| 299 |
self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
|
|
|
|
| 290 |
def _get_vision_model_output(self,
|
| 291 |
input_tensor: torch.Tensor) -> torch.Tensor:
|
| 292 |
return self.vision_model(input_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
def _process_image_input(
|
| 295 |
self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
|
vision_encoder.py
CHANGED
|
@@ -53,7 +53,6 @@ class EncoderRope2D(nn.Module):
|
|
| 53 |
max_grid_height: int,
|
| 54 |
max_grid_width: int,
|
| 55 |
use_cls_token: bool = False,
|
| 56 |
-
freqs_for: Literal["lang", "pixel", "constant"] = "lang",
|
| 57 |
theta: Union[int, float] = 10000,
|
| 58 |
max_freq: int = 10,
|
| 59 |
num_freqs: int = 1,
|
|
@@ -65,7 +64,6 @@ class EncoderRope2D(nn.Module):
|
|
| 65 |
self.max_grid_width = max_grid_width
|
| 66 |
self.use_cls_token = use_cls_token
|
| 67 |
self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
|
| 68 |
-
self.freqs_for = freqs_for
|
| 69 |
self.max_freq = max_freq
|
| 70 |
self.num_freqs = num_freqs
|
| 71 |
cache = self._compute_2d_freqs()
|
|
@@ -73,15 +71,9 @@ class EncoderRope2D(nn.Module):
|
|
| 73 |
|
| 74 |
def _compute_inv_freq(self, base: Union[int, float],
|
| 75 |
dim: int) -> torch.Tensor:
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
elif self.freqs_for == "pixel":
|
| 80 |
-
freqs = torch.linspace(1.0, self.max_freq / 2, dim // 2) * torch.pi
|
| 81 |
-
elif self.freqs_for == "constant":
|
| 82 |
-
freqs = torch.ones(self.num_freqs).float()
|
| 83 |
-
else:
|
| 84 |
-
raise ValueError(f"Unsupported freqs_for value: {self.freqs_for}")
|
| 85 |
return freqs
|
| 86 |
|
| 87 |
def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
|
|
@@ -309,14 +301,9 @@ class EncoderVisionTransformer(nn.Module):
|
|
| 309 |
|
| 310 |
def forward(self,
|
| 311 |
hidden_states: torch.Tensor,
|
| 312 |
-
grid_hw: tuple[int, int]
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
stop_idx = (self.layers + layer_idx) % self.layers
|
| 316 |
-
for idx, block in enumerate(self.resblocks):
|
| 317 |
hidden_states = block(hidden_states, grid_hw=grid_hw)
|
| 318 |
-
if idx == stop_idx:
|
| 319 |
-
break
|
| 320 |
return hidden_states
|
| 321 |
|
| 322 |
|
|
@@ -432,10 +419,7 @@ class StepRoboticsVisionEncoder(nn.Module):
|
|
| 432 |
|
| 433 |
return pos_embed[None, ...]
|
| 434 |
|
| 435 |
-
def forward(self,
|
| 436 |
-
pixel_values: torch.Tensor,
|
| 437 |
-
layer_idx: int = -1,
|
| 438 |
-
strip_cls_token: bool = False) -> torch.Tensor:
|
| 439 |
"""
|
| 440 |
Args:
|
| 441 |
pixel_values: Image tensor of shape (B, C, H, W).
|
|
@@ -457,12 +441,12 @@ class StepRoboticsVisionEncoder(nn.Module):
|
|
| 457 |
pos_emb = self.sample_abs_posemb(grid_h, grid_w)
|
| 458 |
hidden_state = hidden_state + pos_emb
|
| 459 |
hidden_state = self.ln_pre(hidden_state)
|
| 460 |
-
hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w)
|
| 461 |
|
| 462 |
if self.use_ln_post:
|
| 463 |
hidden_state = self.ln_post(hidden_state)
|
| 464 |
|
| 465 |
-
if
|
| 466 |
hidden_state = hidden_state[:, 1:, :]
|
| 467 |
|
| 468 |
return hidden_state
|
|
|
|
| 53 |
max_grid_height: int,
|
| 54 |
max_grid_width: int,
|
| 55 |
use_cls_token: bool = False,
|
|
|
|
| 56 |
theta: Union[int, float] = 10000,
|
| 57 |
max_freq: int = 10,
|
| 58 |
num_freqs: int = 1,
|
|
|
|
| 64 |
self.max_grid_width = max_grid_width
|
| 65 |
self.use_cls_token = use_cls_token
|
| 66 |
self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
|
|
|
|
| 67 |
self.max_freq = max_freq
|
| 68 |
self.num_freqs = num_freqs
|
| 69 |
cache = self._compute_2d_freqs()
|
|
|
|
| 71 |
|
| 72 |
def _compute_inv_freq(self, base: Union[int, float],
|
| 73 |
dim: int) -> torch.Tensor:
|
| 74 |
+
|
| 75 |
+
freqs = 1.0 / (base**(
|
| 76 |
+
torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
return freqs
|
| 78 |
|
| 79 |
def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
|
|
|
|
| 301 |
|
| 302 |
def forward(self,
|
| 303 |
hidden_states: torch.Tensor,
|
| 304 |
+
grid_hw: tuple[int, int]) -> torch.Tensor:
|
| 305 |
+
for block in self.resblocks:
|
|
|
|
|
|
|
|
|
|
| 306 |
hidden_states = block(hidden_states, grid_hw=grid_hw)
|
|
|
|
|
|
|
| 307 |
return hidden_states
|
| 308 |
|
| 309 |
|
|
|
|
| 419 |
|
| 420 |
return pos_embed[None, ...]
|
| 421 |
|
| 422 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
| 423 |
"""
|
| 424 |
Args:
|
| 425 |
pixel_values: Image tensor of shape (B, C, H, W).
|
|
|
|
| 441 |
pos_emb = self.sample_abs_posemb(grid_h, grid_w)
|
| 442 |
hidden_state = hidden_state + pos_emb
|
| 443 |
hidden_state = self.ln_pre(hidden_state)
|
| 444 |
+
hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
|
| 445 |
|
| 446 |
if self.use_ln_post:
|
| 447 |
hidden_state = self.ln_post(hidden_state)
|
| 448 |
|
| 449 |
+
if self.use_cls_token:
|
| 450 |
hidden_state = hidden_state[:, 1:, :]
|
| 451 |
|
| 452 |
return hidden_state
|