Elzzzz commited on
Commit
6ac88c4
·
verified ·
1 Parent(s): d5586a6

Update pi3/models/pi3.py

Browse files
Files changed (1) hide show
  1. pi3/models/pi3.py +1 -1
pi3/models/pi3.py CHANGED
@@ -189,7 +189,7 @@ class Pi3(nn.Module, PyTorchModelHubMixin):
189
  conf_hidden = self.conf_decoder(hidden, xpos=pos)
190
  camera_hidden = self.camera_decoder(hidden, xpos=pos)
191
 
192
- with torch.amp.autocast(device_type='cuda', enabled=False):
193
  # local points
194
  point_hidden = point_hidden.float()
195
  ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
 
189
  conf_hidden = self.conf_decoder(hidden, xpos=pos)
190
  camera_hidden = self.camera_decoder(hidden, xpos=pos)
191
 
192
+ with torch.amp.autocast(device_type='cpu', enabled=False):
193
  # local points
194
  point_hidden = point_hidden.float()
195
  ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)