Update vision_encoder.py
Browse files- vision_encoder.py +1 -1
vision_encoder.py
CHANGED
|
@@ -18,7 +18,7 @@ class VisionEncoder(nn.Module):
|
|
| 18 |
])
|
| 19 |
|
| 20 |
def forward(self, image, device):
|
| 21 |
-
|
| 22 |
with torch.no_grad():
|
| 23 |
pixel_values = self.vision_model(processed_image)
|
| 24 |
image_features = pixel_values.last_hidden_state
|
|
|
|
| 18 |
])
|
| 19 |
|
| 20 |
def forward(self, image, device):
|
| 21 |
+
processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
|
| 22 |
with torch.no_grad():
|
| 23 |
pixel_values = self.vision_model(processed_image)
|
| 24 |
image_features = pixel_values.last_hidden_state
|