Update modeling_moonvit.py
Browse files- modeling_moonvit.py +1 -1
modeling_moonvit.py
CHANGED
|
@@ -180,7 +180,7 @@ class Learnable2DInterpPosEmb(nn.Module):
|
|
| 180 |
|
| 181 |
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
| 182 |
pos_embs = []
|
| 183 |
-
for shape in grid_hws.tolist():
|
| 184 |
if shape == self.weight.shape[:-1]:
|
| 185 |
pos_embs.append(self.weight.flatten(end_dim=1))
|
| 186 |
else:
|
|
|
|
| 180 |
|
| 181 |
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
| 182 |
pos_embs = []
|
| 183 |
+
for shape in grid_hws[:, 1:].tolist():
|
| 184 |
if shape == self.weight.shape[:-1]:
|
| 185 |
pos_embs.append(self.weight.flatten(end_dim=1))
|
| 186 |
else:
|