vjawa_move_preprocessing_to_device

#5
nemotron_table_structure_v1/model.py CHANGED
@@ -146,6 +146,7 @@ class YoloXWrapper(nn.Module):
146
  """
147
  if not isinstance(image, torch.Tensor):
148
  image = torch.from_numpy(image)
 
149
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
150
  image = resize_pad(image, self.img_size)
151
  return image.float()
 
146
  """
147
  if not isinstance(image, torch.Tensor):
148
  image = torch.from_numpy(image)
149
+ image = image.to(self.device)
150
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
151
  image = resize_pad(image, self.img_size)
152
  return image.float()