SajayR commited on
Commit
ce081c7
·
verified ·
1 Parent(s): 4d81923

Added device shift for image

Browse files
Files changed (1) hide show
  1. hf_model.py +2 -0
hf_model.py CHANGED
@@ -194,6 +194,8 @@ class Triad(nn.Module):
194
  std=[0.229, 0.224, 0.225])
195
  ])
196
  image = transform(image)
 
 
197
  embeddings = {}
198
  if image is not None:
199
  embeddings['visual_feats'] = self.visual_embedder(image)
 
194
  std=[0.229, 0.224, 0.225])
195
  ])
196
  image = transform(image)
197
+ device = next(self.parameters()).device
198
+ image = image.to(device)
199
  embeddings = {}
200
  if image is not None:
201
  embeddings['visual_feats'] = self.visual_embedder(image)