Added device shift for image
Browse files- hf_model.py +2 -0
hf_model.py
CHANGED
|
@@ -194,6 +194,8 @@ class Triad(nn.Module):
|
|
| 194 |
std=[0.229, 0.224, 0.225])
|
| 195 |
])
|
| 196 |
image = transform(image)
|
|
|
|
|
|
|
| 197 |
embeddings = {}
|
| 198 |
if image is not None:
|
| 199 |
embeddings['visual_feats'] = self.visual_embedder(image)
|
|
|
|
| 194 |
std=[0.229, 0.224, 0.225])
|
| 195 |
])
|
| 196 |
image = transform(image)
|
| 197 |
+
device = next(self.parameters()).device
|
| 198 |
+
image = image.to(device)
|
| 199 |
embeddings = {}
|
| 200 |
if image is not None:
|
| 201 |
embeddings['visual_feats'] = self.visual_embedder(image)
|