aharley commited on
Commit
3bb4be2
·
verified ·
1 Parent(s): a045514

Upload net34.py

Browse files
Files changed (1) hide show
  1. nets/net34.py +3 -2
nets/net34.py CHANGED
@@ -194,6 +194,7 @@ class Net(nn.Module):
194
 
195
  def get_fmaps(self, images_, B, T, sw, is_training, nograd_backbone):
196
  _, _, H_pad, W_pad = images_.shape # revised HW
 
197
 
198
  C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
199
  if self.no_split:
@@ -205,7 +206,7 @@ class Net(nn.Module):
205
  fmaps = []
206
  for t in range(0, T, fmaps_chunk_size):
207
  images_chunk = images[:, t : t + fmaps_chunk_size]
208
- images_chunk = images_chunk.cuda()
209
  if self.use_basicencoder:
210
  if self.full_split:
211
  fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
@@ -224,7 +225,7 @@ class Net(nn.Module):
224
  else:
225
  if not is_training:
226
  # sometimes we need to move things to cuda here
227
- images_ = images_.cuda()
228
  if self.use_basicencoder:
229
  if self.full_split:
230
  # if self.half_corr:
 
194
 
195
  def get_fmaps(self, images_, B, T, sw, is_training, nograd_backbone):
196
  _, _, H_pad, W_pad = images_.shape # revised HW
197
+ device = images.device
198
 
199
  C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
200
  if self.no_split:
 
206
  fmaps = []
207
  for t in range(0, T, fmaps_chunk_size):
208
  images_chunk = images[:, t : t + fmaps_chunk_size]
209
+ images_chunk = images_chunk.to(device)
210
  if self.use_basicencoder:
211
  if self.full_split:
212
  fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
 
225
  else:
226
  if not is_training:
227
  # sometimes we need to move things to cuda here
228
+ images_ = images_.to(device)
229
  if self.use_basicencoder:
230
  if self.full_split:
231
  # if self.half_corr: