shilinxu commited on
Commit
f3d2bc2
·
verified ·
1 Parent(s): 00ec663

Update modeling_moonvit.py

Browse files
Files changed (1) hide show
  1. modeling_moonvit.py +1 -1
modeling_moonvit.py CHANGED
@@ -180,7 +180,7 @@ class Learnable2DInterpPosEmb(nn.Module):
180
 
181
  def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
182
  pos_embs = []
183
- for shape in grid_hws.tolist():
184
  if shape == self.weight.shape[:-1]:
185
  pos_embs.append(self.weight.flatten(end_dim=1))
186
  else:
 
180
 
181
  def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
182
  pos_embs = []
183
+ for shape in grid_hws[:, 1:].tolist():
184
  if shape == self.weight.shape[:-1]:
185
  pos_embs.append(self.weight.flatten(end_dim=1))
186
  else: