VibhuJawa commited on
Commit
0f79b28
·
unverified ·
1 Parent(s): b903fe1

move pre processing to device

Browse files

Signed-off-by: Vibhu Jawa <vjawa@nvidia.com>

nemotron_graphic_elements_v1/model.py CHANGED
@@ -141,6 +141,7 @@ class YoloXWrapper(nn.Module):
141
  """
142
  if not isinstance(image, torch.Tensor):
143
  image = torch.from_numpy(image)
 
144
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
145
  image = resize_pad(image, self.img_size)
146
  return image.float()
 
141
  """
142
  if not isinstance(image, torch.Tensor):
143
  image = torch.from_numpy(image)
144
+ image = image.to(self.device)
145
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
146
  image = resize_pad(image, self.img_size)
147
  return image.float()