Update hf_model.py
Browse files- hf_model.py +36 -9
hf_model.py
CHANGED
|
@@ -205,18 +205,44 @@ class Triad(nn.Module):
|
|
| 205 |
assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
|
| 206 |
if image is not None: assert image is not str, "Frames should be a path to an image"
|
| 207 |
if audio is not None:
|
| 208 |
-
assert isinstance(audio, torch.Tensor) and
|
| 209 |
if text_list is not None:
|
| 210 |
assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
|
| 211 |
if image is not None:
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
embeddings = {}
|
| 221 |
if image is not None:
|
| 222 |
embeddings['visual_feats'] = self.visual_embedder(image)
|
|
@@ -233,3 +259,4 @@ class Triad(nn.Module):
|
|
| 233 |
embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
|
| 234 |
return embeddings
|
| 235 |
|
|
|
|
|
|
| 205 |
assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
|
| 206 |
if image is not None: assert image is not str, "Frames should be a path to an image"
|
| 207 |
if audio is not None:
|
| 208 |
+
assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)"
|
| 209 |
if text_list is not None:
|
| 210 |
assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
|
| 211 |
if image is not None:
|
| 212 |
+
device = next(self.parameters()).device
|
| 213 |
+
|
| 214 |
+
# Handle batch of file paths
|
| 215 |
+
if isinstance(image, list):
|
| 216 |
+
# Process a list of image paths
|
| 217 |
+
processed_images = []
|
| 218 |
+
for img_path in image:
|
| 219 |
+
img = Image.open(img_path).convert('RGB')
|
| 220 |
+
transform = transforms.Compose([
|
| 221 |
+
transforms.Resize((224, 224)),
|
| 222 |
+
transforms.ToTensor(),
|
| 223 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 224 |
+
])
|
| 225 |
+
processed_img = transform(img).to(device)
|
| 226 |
+
processed_images.append(processed_img)
|
| 227 |
+
image = torch.stack(processed_images, dim=0) # [B, 3, 224, 224]
|
| 228 |
+
|
| 229 |
+
# Handle single file path
|
| 230 |
+
elif isinstance(image, str):
|
| 231 |
+
img = Image.open(image).convert('RGB')
|
| 232 |
+
transform = transforms.Compose([
|
| 233 |
+
transforms.Resize((224, 224)),
|
| 234 |
+
transforms.ToTensor(),
|
| 235 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 236 |
+
])
|
| 237 |
+
image = transform(img).to(device).unsqueeze(0) # Add batch dimension [1, 3, 224, 224]
|
| 238 |
+
|
| 239 |
+
# Handle tensor input (assume it's already processed but may need device transfer)
|
| 240 |
+
elif isinstance(image, torch.Tensor):
|
| 241 |
+
# If single image without batch dimension
|
| 242 |
+
if image.dim() == 3:
|
| 243 |
+
image = image.unsqueeze(0) # Add batch dimension
|
| 244 |
+
image = image.to(device)
|
| 245 |
+
|
| 246 |
embeddings = {}
|
| 247 |
if image is not None:
|
| 248 |
embeddings['visual_feats'] = self.visual_embedder(image)
|
|
|
|
| 259 |
embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
|
| 260 |
return embeddings
|
| 261 |
|
| 262 |
+
|