Spaces:
Runtime error
Runtime error
Update modeling_llava_qwen2.py
Browse files- modeling_llava_qwen2.py +5 -5
modeling_llava_qwen2.py
CHANGED
|
@@ -538,13 +538,13 @@ class SigLipVisionTower(nn.Module):
|
|
| 538 |
if type(images) is list:
|
| 539 |
image_features = []
|
| 540 |
for image in images:
|
| 541 |
-
image_forward_out = self.vision_tower(image.to(device=
|
| 542 |
output_hidden_states=True)
|
| 543 |
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
| 544 |
assert image_features.shape[-2] == 729
|
| 545 |
image_features.append(image_feature)
|
| 546 |
else:
|
| 547 |
-
image_forward_outs = self.vision_tower(images.to(device=
|
| 548 |
output_hidden_states=True)
|
| 549 |
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
| 550 |
assert image_features.shape[-2] == 729
|
|
@@ -553,7 +553,7 @@ class SigLipVisionTower(nn.Module):
|
|
| 553 |
|
| 554 |
@property
|
| 555 |
def dummy_feature(self):
|
| 556 |
-
return torch.zeros(1, self.hidden_size, device=
|
| 557 |
|
| 558 |
@property
|
| 559 |
def dtype(self):
|
|
@@ -685,9 +685,9 @@ class LlavaMetaForCausalLM(ABC):
|
|
| 685 |
image_features = self.encode_images(concat_images)
|
| 686 |
split_sizes = [image.shape[0] for image in images]
|
| 687 |
image_features = torch.split(image_features, split_sizes, dim=0)
|
| 688 |
-
image_features = [x.flatten(0, 1).to(
|
| 689 |
else:
|
| 690 |
-
image_features = self.encode_images(images).to(
|
| 691 |
|
| 692 |
# Let's just add dummy tensors if they do not exist,
|
| 693 |
# it is a headache to deal with None all the time.
|
|
|
|
| 538 |
if type(images) is list:
|
| 539 |
image_features = []
|
| 540 |
for image in images:
|
| 541 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
| 542 |
output_hidden_states=True)
|
| 543 |
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
| 544 |
assert image_features.shape[-2] == 729
|
| 545 |
image_features.append(image_feature)
|
| 546 |
else:
|
| 547 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
|
| 548 |
output_hidden_states=True)
|
| 549 |
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
| 550 |
assert image_features.shape[-2] == 729
|
|
|
|
| 553 |
|
| 554 |
@property
def dummy_feature(self):
    """Return a zero-filled placeholder feature of shape (1, hidden_size).

    The tensor is created on the tower's own device and dtype so it can be
    concatenated with real vision features without extra casts.
    """
    placeholder_shape = (1, self.hidden_size)
    return torch.zeros(*placeholder_shape, device=self.device, dtype=self.dtype)
|
| 557 |
|
| 558 |
@property
|
| 559 |
def dtype(self):
|
|
|
|
| 685 |
image_features = self.encode_images(concat_images)
|
| 686 |
split_sizes = [image.shape[0] for image in images]
|
| 687 |
image_features = torch.split(image_features, split_sizes, dim=0)
|
| 688 |
+
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
|
| 689 |
else:
|
| 690 |
+
image_features = self.encode_images(images).to(self.device)
|
| 691 |
|
| 692 |
# Let's just add dummy tensors if they do not exist,
|
| 693 |
# it is a headache to deal with None all the time.
|