Update modeling_moonvit.py
Browse files
modeling_moonvit.py (+4 -4)
modeling_moonvit.py
CHANGED
|
@@ -587,7 +587,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
|
| 587 |
self.multi_modal_projector = MultiModalProjector(config)
|
| 588 |
|
| 589 |
def forward(
|
| 590 |
-
self, pixel_values: torch.Tensor,
|
| 591 |
) -> torch.Tensor:
|
| 592 |
"""
|
| 593 |
Args:
|
|
@@ -596,10 +596,10 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
|
| 596 |
Returns:
|
| 597 |
torch.Tensor: The output tokens.
|
| 598 |
"""
|
| 599 |
-
hidden_states = self.patch_embed(pixel_values,
|
| 600 |
-
hidden_states = self.encoder(hidden_states,
|
| 601 |
hidden_states = patch_merger(
|
| 602 |
-
hidden_states,
|
| 603 |
)
|
| 604 |
hidden_states = self.multi_modal_projector(hidden_states)
|
| 605 |
return hidden_states
|
|
|
|
| 587 |
self.multi_modal_projector = MultiModalProjector(config)
|
| 588 |
|
| 589 |
def forward(
    self, pixel_values: torch.Tensor, image_grid_hws: torch.Tensor
) -> torch.Tensor:
    """Turn raw image patches into vision tokens for the language model.

    Args:
        pixel_values: Packed patch values for the batch of images.
        image_grid_hws: Per-image patch-grid sizes; threaded through every
            stage so each can delimit the variable-sized images packed into
            ``pixel_values``.
            # NOTE(review): shapes/dtypes assumed from usage — the original
            # docstring's Args section is not visible here; confirm upstream.

    Returns:
        torch.Tensor: The output tokens.
    """
    # Embed patches, then contextualize them with the transformer encoder.
    # Both stages need the per-image grid sizes to handle packed inputs.
    hidden_states = self.patch_embed(pixel_values, image_grid_hws)
    hidden_states = self.encoder(hidden_states, image_grid_hws)

    # Merge neighboring patch tokens (kernel size fixed at construction)
    # before projecting into the language-model embedding space.
    hidden_states = patch_merger(
        hidden_states, image_grid_hws, merge_kernel_size=self.merge_kernel_size
    )
    hidden_states = self.multi_modal_projector(hidden_states)
    return hidden_states
|