Update pi3/models/pi3.py
Browse files- pi3/models/pi3.py +1 -1
pi3/models/pi3.py
CHANGED
|
@@ -189,7 +189,7 @@ class Pi3(nn.Module, PyTorchModelHubMixin):
|
|
| 189 |
conf_hidden = self.conf_decoder(hidden, xpos=pos)
|
| 190 |
camera_hidden = self.camera_decoder(hidden, xpos=pos)
|
| 191 |
|
| 192 |
-
with torch.amp.autocast(device_type='
|
| 193 |
# local points
|
| 194 |
point_hidden = point_hidden.float()
|
| 195 |
ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
|
|
|
|
| 189 |
conf_hidden = self.conf_decoder(hidden, xpos=pos)
|
| 190 |
camera_hidden = self.camera_decoder(hidden, xpos=pos)
|
| 191 |
|
| 192 |
+
with torch.amp.autocast(device_type='cpu', enabled=False):
|
| 193 |
# local points
|
| 194 |
point_hidden = point_hidden.float()
|
| 195 |
ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
|