Upload modeling_moment.py
Browse files- modeling_moment.py +4 -2
modeling_moment.py
CHANGED
|
@@ -481,10 +481,12 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 481 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
| 482 |
# [batch_size, n_channels x n_patches, d_model]
|
| 483 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
| 484 |
-
|
|
|
|
|
|
|
| 485 |
|
| 486 |
return TimeseriesOutputs(
|
| 487 |
-
embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=
|
| 488 |
)
|
| 489 |
|
| 490 |
def forward(
|
|
|
|
| 481 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
| 482 |
# [batch_size, n_channels x n_patches, d_model]
|
| 483 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
| 484 |
+
|
| 485 |
+
# [batch_size x n_patches]
|
| 486 |
+
input_mask_patch_view_for_mists = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
| 487 |
|
| 488 |
return TimeseriesOutputs(
|
| 489 |
+
embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=input_mask_patch_view_for_mists
|
| 490 |
)
|
| 491 |
|
| 492 |
def forward(
|