Upload modeling_moment.py
Browse files- modeling_moment.py +1 -4
modeling_moment.py
CHANGED
|
@@ -449,15 +449,12 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
| 450 |
enc_out = outputs.last_hidden_state
|
| 451 |
|
| 452 |
-
# For Mists model
|
| 453 |
-
hidden_states = outputs.last_hidden_state
|
| 454 |
-
|
| 455 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
| 456 |
# [batch_size x n_channels x n_patches x d_model]
|
| 457 |
|
| 458 |
# For Mists model
|
| 459 |
# [batch_size, n_channels x n_patches, d_model]
|
| 460 |
-
|
| 461 |
|
| 462 |
if reduction == "mean":
|
| 463 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|
|
|
|
| 449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
| 450 |
enc_out = outputs.last_hidden_state
|
| 451 |
|
|
|
|
|
|
|
|
|
|
| 452 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
| 453 |
# [batch_size x n_channels x n_patches x d_model]
|
| 454 |
|
| 455 |
# For Mists model
|
| 456 |
# [batch_size, n_channels x n_patches, d_model]
|
| 457 |
+
hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
| 458 |
|
| 459 |
if reduction == "mean":
|
| 460 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|