shilinxu commited on
Commit
6c73cc4
·
verified ·
1 Parent(s): 5cb7f5a

Update modeling_moonvit.py

Browse files
Files changed (1) hide show
  1. modeling_moonvit.py +4 -4
modeling_moonvit.py CHANGED
@@ -587,7 +587,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
587
  self.multi_modal_projector = MultiModalProjector(config)
588
 
589
  def forward(
590
- self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
591
  ) -> torch.Tensor:
592
  """
593
  Args:
@@ -596,10 +596,10 @@ class MoonVitPretrainedModel(PreTrainedModel):
596
  Returns:
597
  torch.Tensor: The output tokens.
598
  """
599
- hidden_states = self.patch_embed(pixel_values, grid_hws)
600
- hidden_states = self.encoder(hidden_states, grid_hws)
601
  hidden_states = patch_merger(
602
- hidden_states, grid_hws, merge_kernel_size=self.merge_kernel_size
603
  )
604
  hidden_states = self.multi_modal_projector(hidden_states)
605
  return hidden_states
 
587
  self.multi_modal_projector = MultiModalProjector(config)
588
 
589
  def forward(
590
+ self, pixel_values: torch.Tensor, image_grid_hws: torch.Tensor
591
  ) -> torch.Tensor:
592
  """
593
  Args:
 
596
  Returns:
597
  torch.Tensor: The output tokens.
598
  """
599
+ hidden_states = self.patch_embed(pixel_values, image_grid_hws)
600
+ hidden_states = self.encoder(hidden_states, image_grid_hws)
601
  hidden_states = patch_merger(
602
+ hidden_states, image_grid_hws, merge_kernel_size=self.merge_kernel_size
603
  )
604
  hidden_states = self.multi_modal_projector(hidden_states)
605
  return hidden_states