Upload modeling_moment.py
Browse files- modeling_moment.py +2 -3
modeling_moment.py
CHANGED
|
@@ -470,15 +470,14 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 470 |
# [batch_size, n_channels x n_patches, d_model]
|
| 471 |
# Ensure hidden_states are consistent for both short and long inputs with input_mask specified
|
| 472 |
# hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
| 473 |
-
# [batch_size x n_channels x n_patches x d_model]
|
| 474 |
-
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
|
| 475 |
# [batch_size x n_patches]
|
| 476 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
| 477 |
# [batch_size x n_channels x n_patches x d_model]
|
| 478 |
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
|
| 479 |
1, n_channels, 1, self.config.d_model
|
| 480 |
)
|
| 481 |
-
# [batch_size
|
|
|
|
| 482 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
| 483 |
# [batch_size, n_channels x n_patches, d_model]
|
| 484 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
|
|
|
| 470 |
# [batch_size, n_channels x n_patches, d_model]
|
| 471 |
# Ensure hidden_states are consistent for both short and long inputs with input_mask specified
|
| 472 |
# hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
|
|
|
|
|
|
| 473 |
# [batch_size x n_patches]
|
| 474 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
| 475 |
# [batch_size x n_channels x n_patches x d_model]
|
| 476 |
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
|
| 477 |
1, n_channels, 1, self.config.d_model
|
| 478 |
)
|
| 479 |
+
# [batch_size x n_channels x n_patches x d_model]
|
| 480 |
+
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
|
| 481 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
| 482 |
# [batch_size, n_channels x n_patches, d_model]
|
| 483 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|