luotingdan committed on
Commit
f7bf7c1
·
1 Parent(s): 82465dc

remove some unused code

Browse files
Files changed (2) hide show
  1. modeling_step_vl.py +0 -4
  2. vision_encoder.py +8 -24
modeling_step_vl.py CHANGED
@@ -290,10 +290,6 @@ class StepRoboticsModel(StepRoboticsPreTrainedModel, GenerationMixin):
290
  def _get_vision_model_output(self,
291
  input_tensor: torch.Tensor) -> torch.Tensor:
292
  return self.vision_model(input_tensor)
293
-
294
- def _get_pooled_vision_model_output(
295
- self, input_tensor: torch.Tensor) -> torch.Tensor:
296
- return self.vision_model.pool(input_tensor)
297
 
298
  def _process_image_input(
299
  self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
 
290
  def _get_vision_model_output(self,
291
  input_tensor: torch.Tensor) -> torch.Tensor:
292
  return self.vision_model(input_tensor)
 
 
 
 
293
 
294
  def _process_image_input(
295
  self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
vision_encoder.py CHANGED
@@ -53,7 +53,6 @@ class EncoderRope2D(nn.Module):
53
  max_grid_height: int,
54
  max_grid_width: int,
55
  use_cls_token: bool = False,
56
- freqs_for: Literal["lang", "pixel", "constant"] = "lang",
57
  theta: Union[int, float] = 10000,
58
  max_freq: int = 10,
59
  num_freqs: int = 1,
@@ -65,7 +64,6 @@ class EncoderRope2D(nn.Module):
65
  self.max_grid_width = max_grid_width
66
  self.use_cls_token = use_cls_token
67
  self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
68
- self.freqs_for = freqs_for
69
  self.max_freq = max_freq
70
  self.num_freqs = num_freqs
71
  cache = self._compute_2d_freqs()
@@ -73,15 +71,9 @@ class EncoderRope2D(nn.Module):
73
 
74
  def _compute_inv_freq(self, base: Union[int, float],
75
  dim: int) -> torch.Tensor:
76
- if self.freqs_for == "lang":
77
- freqs = 1.0 / (base**(
78
- torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
79
- elif self.freqs_for == "pixel":
80
- freqs = torch.linspace(1.0, self.max_freq / 2, dim // 2) * torch.pi
81
- elif self.freqs_for == "constant":
82
- freqs = torch.ones(self.num_freqs).float()
83
- else:
84
- raise ValueError(f"Unsupported freqs_for value: {self.freqs_for}")
85
  return freqs
86
 
87
  def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
@@ -309,14 +301,9 @@ class EncoderVisionTransformer(nn.Module):
309
 
310
  def forward(self,
311
  hidden_states: torch.Tensor,
312
- grid_hw: tuple[int, int],
313
- layer_idx: int = -1) -> torch.Tensor:
314
-
315
- stop_idx = (self.layers + layer_idx) % self.layers
316
- for idx, block in enumerate(self.resblocks):
317
  hidden_states = block(hidden_states, grid_hw=grid_hw)
318
- if idx == stop_idx:
319
- break
320
  return hidden_states
321
 
322
 
@@ -432,10 +419,7 @@ class StepRoboticsVisionEncoder(nn.Module):
432
 
433
  return pos_embed[None, ...]
434
 
435
- def forward(self,
436
- pixel_values: torch.Tensor,
437
- layer_idx: int = -1,
438
- strip_cls_token: bool = False) -> torch.Tensor:
439
  """
440
  Args:
441
  pixel_values: Image tensor of shape (B, C, H, W).
@@ -457,12 +441,12 @@ class StepRoboticsVisionEncoder(nn.Module):
457
  pos_emb = self.sample_abs_posemb(grid_h, grid_w)
458
  hidden_state = hidden_state + pos_emb
459
  hidden_state = self.ln_pre(hidden_state)
460
- hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w), layer_idx=layer_idx)
461
 
462
  if self.use_ln_post:
463
  hidden_state = self.ln_post(hidden_state)
464
 
465
- if strip_cls_token and self.use_cls_token:
466
  hidden_state = hidden_state[:, 1:, :]
467
 
468
  return hidden_state
 
53
  max_grid_height: int,
54
  max_grid_width: int,
55
  use_cls_token: bool = False,
 
56
  theta: Union[int, float] = 10000,
57
  max_freq: int = 10,
58
  num_freqs: int = 1,
 
64
  self.max_grid_width = max_grid_width
65
  self.use_cls_token = use_cls_token
66
  self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
 
67
  self.max_freq = max_freq
68
  self.num_freqs = num_freqs
69
  cache = self._compute_2d_freqs()
 
71
 
72
  def _compute_inv_freq(self, base: Union[int, float],
73
  dim: int) -> torch.Tensor:
74
+
75
+ freqs = 1.0 / (base**(
76
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
 
 
 
 
 
 
77
  return freqs
78
 
79
  def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
 
301
 
302
  def forward(self,
303
  hidden_states: torch.Tensor,
304
+ grid_hw: tuple[int, int]) -> torch.Tensor:
305
+ for block in self.resblocks:
 
 
 
306
  hidden_states = block(hidden_states, grid_hw=grid_hw)
 
 
307
  return hidden_states
308
 
309
 
 
419
 
420
  return pos_embed[None, ...]
421
 
422
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 
 
 
423
  """
424
  Args:
425
  pixel_values: Image tensor of shape (B, C, H, W).
 
441
  pos_emb = self.sample_abs_posemb(grid_h, grid_w)
442
  hidden_state = hidden_state + pos_emb
443
  hidden_state = self.ln_pre(hidden_state)
444
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
445
 
446
  if self.use_ln_post:
447
  hidden_state = self.ln_post(hidden_state)
448
 
449
+ if self.use_cls_token:
450
  hidden_state = hidden_state[:, 1:, :]
451
 
452
  return hidden_state