Upload modeling_moment.py
Browse files- modeling_moment.py +1 -1
modeling_moment.py
CHANGED
|
@@ -455,7 +455,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 455 |
|
| 456 |
# For Mists model
|
| 457 |
# [batch_size, n_channels x n_patches, d_model]
|
| 458 |
-
# hidden_states
|
| 459 |
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
| 460 |
|
| 461 |
if reduction == "mean":
|
|
|
|
| 455 |
|
| 456 |
# For Mists model
|
| 457 |
# [batch_size, n_channels x n_patches, d_model]
|
| 458 |
+
# Ensure hidden_states are consistent for both short and long inputs with input_mask specified
|
| 459 |
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
| 460 |
|
| 461 |
if reduction == "mean":
|