ZephyrCode committed on
Commit
76d74d2
·
verified ·
1 Parent(s): 5a8c97f

Update vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +4 -4
vision_encoder.py CHANGED
@@ -96,7 +96,7 @@ class VisionTransformer(nn.Module):
96
  super().__init__()
97
 
98
  embed_len = 729
99
- embed_dim = 1152
100
 
101
  self.patch_embed = LinearPatchEmbedding()
102
  self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
@@ -288,7 +288,7 @@ class VisionEncoder(nn.Module):
288
 
289
  full_img_features = combined_features[: len(im_list)]
290
  patch_features = (
291
- combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
292
  )
293
 
294
  # Reshape patch features back to their original structure
@@ -297,7 +297,7 @@ class VisionEncoder(nn.Module):
297
  for i, patch_set in enumerate(patches):
298
  if len(patch_set) == 0:
299
  reshaped_patch_features.append(
300
- full_img_features[i].transpose(0, 1).view(1152, 27, 27)
301
  )
302
  else:
303
  sample_features = []
@@ -317,7 +317,7 @@ class VisionEncoder(nn.Module):
317
  ).squeeze(0)
318
  reshaped_patch_features.append(sample_features)
319
  reshaped_patch_features = (
320
- torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
321
  )
322
 
323
  final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
 
96
  super().__init__()
97
 
98
  embed_len = 729
99
+ embed_dim = 4608
100
 
101
  self.patch_embed = LinearPatchEmbedding()
102
  self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
 
288
 
289
  full_img_features = combined_features[: len(im_list)]
290
  patch_features = (
291
+ combined_features[len(im_list) :].transpose(1, 2).view(-1, 4608, 27, 27)
292
  )
293
 
294
  # Reshape patch features back to their original structure
 
297
  for i, patch_set in enumerate(patches):
298
  if len(patch_set) == 0:
299
  reshaped_patch_features.append(
300
+ full_img_features[i].transpose(0, 1).view(4608, 27, 27)
301
  )
302
  else:
303
  sample_features = []
 
317
  ).squeeze(0)
318
  reshaped_patch_features.append(sample_features)
319
  reshaped_patch_features = (
320
+ torch.stack(reshaped_patch_features).view(-1, 4608, 729).transpose(1, 2)
321
  )
322
 
323
  final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)