SajayR committed on
Commit
8948365
·
verified ·
1 Parent(s): ce081c7

Added batch support for images

Browse files
Files changed (1) hide show
  1. hf_model.py +33 -9
hf_model.py CHANGED
@@ -186,16 +186,40 @@ class Triad(nn.Module):
186
  if text_list is not None:
187
  assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
188
  if image is not None:
189
- image = Image.open(image).convert('RGB')
190
- transform = transforms.Compose([
191
- transforms.Resize((224, 224)),
192
- transforms.ToTensor(),
193
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
194
- std=[0.229, 0.224, 0.225])
195
- ])
196
- image = transform(image)
197
  device = next(self.parameters()).device
198
- image = image.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  embeddings = {}
200
  if image is not None:
201
  embeddings['visual_feats'] = self.visual_embedder(image)
 
186
  if text_list is not None:
187
  assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
188
  if image is not None:
 
 
 
 
 
 
 
 
189
  device = next(self.parameters()).device
190
+
191
+ # Handle batch of file paths
192
+ if isinstance(image, list):
193
+ # Process a list of image paths
194
+ processed_images = []
195
+ for img_path in image:
196
+ img = Image.open(img_path).convert('RGB')
197
+ transform = transforms.Compose([
198
+ transforms.Resize((224, 224)),
199
+ transforms.ToTensor(),
200
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
201
+ ])
202
+ processed_img = transform(img).to(device)
203
+ processed_images.append(processed_img)
204
+ image = torch.stack(processed_images, dim=0) # [B, 3, 224, 224]
205
+
206
+ # Handle single file path
207
+ elif isinstance(image, str):
208
+ img = Image.open(image).convert('RGB')
209
+ transform = transforms.Compose([
210
+ transforms.Resize((224, 224)),
211
+ transforms.ToTensor(),
212
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
213
+ ])
214
+ image = transform(img).to(device).unsqueeze(0) # Add batch dimension [1, 3, 224, 224]
215
+
216
+ # Handle tensor input (assume it's already processed but may need device transfer)
217
+ elif isinstance(image, torch.Tensor):
218
+ # If single image without batch dimension
219
+ if image.dim() == 3:
220
+ image = image.unsqueeze(0) # Add batch dimension
221
+ image = image.to(device)
222
+
223
  embeddings = {}
224
  if image is not None:
225
  embeddings['visual_feats'] = self.visual_embedder(image)