Added batch support for images
Browse files — hf_model.py: +33 −9
hf_model.py
CHANGED
|
@@ -186,16 +186,40 @@ class Triad(nn.Module):
|
|
| 186 |
if text_list is not None:
|
| 187 |
assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
|
| 188 |
if image is not None:
|
| 189 |
-
image = Image.open(image).convert('RGB')
|
| 190 |
-
transform = transforms.Compose([
|
| 191 |
-
transforms.Resize((224, 224)),
|
| 192 |
-
transforms.ToTensor(),
|
| 193 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 194 |
-
std=[0.229, 0.224, 0.225])
|
| 195 |
-
])
|
| 196 |
-
image = transform(image)
|
| 197 |
device = next(self.parameters()).device
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
embeddings = {}
|
| 200 |
if image is not None:
|
| 201 |
embeddings['visual_feats'] = self.visual_embedder(image)
|
|
|
|
if text_list is not None:
    # NOTE(review): batch support below applies to images only; text input is
    # still restricted to exactly one string — confirm whether text batching
    # is planned.
    assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"
if image is not None:
    device = next(self.parameters()).device

    # Standard 224x224 resize + ImageNet mean/std normalization.
    # Built ONCE here and shared by both file-path branches — the previous
    # version rebuilt this Compose pipeline on every loop iteration and
    # duplicated it verbatim in the single-path branch.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    if isinstance(image, list):
        # Batch of image file paths -> stacked tensor [B, 3, 224, 224].
        image = torch.stack(
            [transform(Image.open(p).convert('RGB')).to(device) for p in image],
            dim=0,
        )
    elif isinstance(image, str):
        # Single file path -> add batch dimension: [1, 3, 224, 224].
        image = transform(Image.open(image).convert('RGB')).to(device).unsqueeze(0)
    elif isinstance(image, torch.Tensor):
        # Assumed already preprocessed; only normalize the batch dimension
        # and make sure the tensor lives on the model's device.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)
    else:
        # Previously an unsupported type fell through untransformed into the
        # visual embedder and failed with a confusing downstream error;
        # fail fast with a clear message instead.
        raise TypeError(
            f"image must be a file path, list of file paths, or torch.Tensor, got {type(image).__name__}"
        )

embeddings = {}
if image is not None:
    # Visual features for the (possibly batched) image tensor.
    embeddings['visual_feats'] = self.visual_embedder(image)
|