SajayR committed on
Commit
50d0a45
·
verified ·
1 Parent(s): 29c8bf8

Update hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +36 -9
hf_model.py CHANGED
@@ -205,18 +205,44 @@ class Triad(nn.Module):
205
  assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
206
  if image is not None: assert image is not str, "Frames should be a path to an image"
207
  if audio is not None:
208
- assert isinstance(audio, torch.Tensor) and audio.shape[0] == 1 and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (1, T)"
209
  if text_list is not None:
210
  assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
211
  if image is not None:
212
- image = Image.open(image).convert('RGB')
213
- transform = transforms.Compose([
214
- transforms.Resize((224, 224)),
215
- transforms.ToTensor(),
216
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
217
- std=[0.229, 0.224, 0.225])
218
- ])
219
- image = transform(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  embeddings = {}
221
  if image is not None:
222
  embeddings['visual_feats'] = self.visual_embedder(image)
@@ -233,3 +259,4 @@ class Triad(nn.Module):
233
  embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
234
  return embeddings
235
 
 
 
205
  assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
206
  if image is not None: assert image is not str, "Frames should be a path to an image"
207
  if audio is not None:
208
+ assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)"
209
  if text_list is not None:
210
  assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
211
  if image is not None:
212
+ device = next(self.parameters()).device
213
+
214
+ # Handle batch of file paths
215
+ if isinstance(image, list):
216
+ # Process a list of image paths
217
+ processed_images = []
218
+ for img_path in image:
219
+ img = Image.open(img_path).convert('RGB')
220
+ transform = transforms.Compose([
221
+ transforms.Resize((224, 224)),
222
+ transforms.ToTensor(),
223
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
224
+ ])
225
+ processed_img = transform(img).to(device)
226
+ processed_images.append(processed_img)
227
+ image = torch.stack(processed_images, dim=0) # [B, 3, 224, 224]
228
+
229
+ # Handle single file path
230
+ elif isinstance(image, str):
231
+ img = Image.open(image).convert('RGB')
232
+ transform = transforms.Compose([
233
+ transforms.Resize((224, 224)),
234
+ transforms.ToTensor(),
235
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
236
+ ])
237
+ image = transform(img).to(device).unsqueeze(0) # Add batch dimension [1, 3, 224, 224]
238
+
239
+ # Handle tensor input (assume it's already processed but may need device transfer)
240
+ elif isinstance(image, torch.Tensor):
241
+ # If single image without batch dimension
242
+ if image.dim() == 3:
243
+ image = image.unsqueeze(0) # Add batch dimension
244
+ image = image.to(device)
245
+
246
  embeddings = {}
247
  if image is not None:
248
  embeddings['visual_feats'] = self.visual_embedder(image)
 
259
  embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
260
  return embeddings
261
 
262
+