Update vision_encoder.py
Browse files- vision_encoder.py +4 -4
vision_encoder.py
CHANGED
|
@@ -96,7 +96,7 @@ class VisionTransformer(nn.Module):
|
|
| 96 |
super().__init__()
|
| 97 |
|
| 98 |
embed_len = 729
|
| 99 |
-
embed_dim =
|
| 100 |
|
| 101 |
self.patch_embed = LinearPatchEmbedding()
|
| 102 |
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
|
@@ -288,7 +288,7 @@ class VisionEncoder(nn.Module):
|
|
| 288 |
|
| 289 |
full_img_features = combined_features[: len(im_list)]
|
| 290 |
patch_features = (
|
| 291 |
-
combined_features[len(im_list) :].transpose(1, 2).view(-1,
|
| 292 |
)
|
| 293 |
|
| 294 |
# Reshape patch features back to their original structure
|
|
@@ -297,7 +297,7 @@ class VisionEncoder(nn.Module):
|
|
| 297 |
for i, patch_set in enumerate(patches):
|
| 298 |
if len(patch_set) == 0:
|
| 299 |
reshaped_patch_features.append(
|
| 300 |
-
full_img_features[i].transpose(0, 1).view(
|
| 301 |
)
|
| 302 |
else:
|
| 303 |
sample_features = []
|
|
@@ -317,7 +317,7 @@ class VisionEncoder(nn.Module):
|
|
| 317 |
).squeeze(0)
|
| 318 |
reshaped_patch_features.append(sample_features)
|
| 319 |
reshaped_patch_features = (
|
| 320 |
-
torch.stack(reshaped_patch_features).view(-1,
|
| 321 |
)
|
| 322 |
|
| 323 |
final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
|
|
|
|
| 96 |
super().__init__()
|
| 97 |
|
| 98 |
embed_len = 729
|
| 99 |
+
embed_dim = 4608
|
| 100 |
|
| 101 |
self.patch_embed = LinearPatchEmbedding()
|
| 102 |
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
|
|
|
| 288 |
|
| 289 |
full_img_features = combined_features[: len(im_list)]
|
| 290 |
patch_features = (
|
| 291 |
+
combined_features[len(im_list) :].transpose(1, 2).view(-1, 4608, 27, 27)
|
| 292 |
)
|
| 293 |
|
| 294 |
# Reshape patch features back to their original structure
|
|
|
|
| 297 |
for i, patch_set in enumerate(patches):
|
| 298 |
if len(patch_set) == 0:
|
| 299 |
reshaped_patch_features.append(
|
| 300 |
+
full_img_features[i].transpose(0, 1).view(4608, 27, 27)
|
| 301 |
)
|
| 302 |
else:
|
| 303 |
sample_features = []
|
|
|
|
| 317 |
).squeeze(0)
|
| 318 |
reshaped_patch_features.append(sample_features)
|
| 319 |
reshaped_patch_features = (
|
| 320 |
+
torch.stack(reshaped_patch_features).view(-1, 4608, 729).transpose(1, 2)
|
| 321 |
)
|
| 322 |
|
| 323 |
final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
|