damerajee commited on
Commit
96e263a
·
verified ·
1 Parent(s): fedec36

Update vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +1 -1
vision_encoder.py CHANGED
@@ -18,7 +18,7 @@ class VisionEncoder(nn.Module):
18
  ])
19
 
20
  def forward(self, image, device):
21
- processed_image = self.image_transform(image).unsqueeze(0).to(device)
22
  with torch.no_grad():
23
  pixel_values = self.vision_model(processed_image)
24
  image_features = pixel_values.last_hidden_state
 
18
  ])
19
 
20
  def forward(self, image, device):
21
+ processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
22
  with torch.no_grad():
23
  pixel_values = self.vision_model(processed_image)
24
  image_features = pixel_values.last_hidden_state