Commit b170a99 (verified) by ZephyrCode · 1 parent: 95e08a8

Update vision_encoder.py

Files changed (1):
  1. vision_encoder.py  +5 -205
vision_encoder.py CHANGED
@@ -26,7 +26,6 @@ except ImportError:
 
 
 class Attention(nn.Module):
-
     def __init__(self, dim, num_heads=16, use_flash_attn=False):
         super().__init__()
         assert dim % num_heads == 0, "dim should be divisible by num_heads"
@@ -76,11 +75,10 @@ class Attention(nn.Module):
 
 
 class VitBlock(nn.Module):
-
     def __init__(self, embed_dim, use_flash_attn=False):
         super().__init__()
         self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
-        self.mlp = MLP(embed_dim, 4304)
+        self.mlp = MLP(embed_dim, 2304)
         self.norm1 = nn.LayerNorm(embed_dim)
         self.norm2 = nn.LayerNorm(embed_dim)
 
@@ -91,14 +89,13 @@ class VitBlock(nn.Module):
 
 
 class VisionTransformer(nn.Module):
-
     def __init__(self, use_flash_attn=False):
         super().__init__()
 
         embed_len = 729
-        embed_dim = 4608
+        embed_dim = 1152  # Updated to match checkpoint
 
-        self.patch_embed = LinearPatchEmbedding()
+        self.patch_embed = LinearPatchEmbedding(embed_dim)
         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
         self.blocks = nn.Sequential(
             *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
@@ -113,21 +110,10 @@ class VisionTransformer(nn.Module):
         return self.norm(x)
 
 
-class EncoderWrapper(nn.Module):
-
-    def __init__(self, use_flash_attn=False):
-        super().__init__()
-        self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
-
-    def forward(self, x):
-        return self.model["visual"](x)
-
-
 class LinearPatchEmbedding(nn.Module):
-
-    def __init__(self):
+    def __init__(self, embed_dim=1152):  # Updated default to match checkpoint
         super().__init__()
-        self.linear = nn.Linear(588, 4608)
+        self.linear = nn.Linear(588, embed_dim)  # Match saved model
 
     def forward(self, x):
         b, c, hp1, wp2 = x.shape
@@ -136,190 +122,4 @@ class LinearPatchEmbedding(nn.Module):
         x = x.reshape(b, c, h, p1, w, p2)
         x = x.permute(0, 2, 4, 1, 3, 5)
         x = x.reshape(b, h * w, c * p1 * p2)
-
         return self.linear(x)
-
-
-class MLP(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        hidden_features: int = None,
-        out_features: int = None,
-    ) -> None:
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = nn.GELU(approximate="tanh")
-        self.fc2 = nn.Linear(hidden_features, out_features)
-
-        torch.nn.init.kaiming_normal_(
-            self.fc1.weight, mode="fan_in", nonlinearity="relu"
-        )
-        torch.nn.init.kaiming_normal_(
-            self.fc2.weight, mode="fan_in", nonlinearity="relu"
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.fc2(x)
-        return x
-
-
-class VisionProjection(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        image_embedding_dim = 2304
-        model_dim = 2048
-        hidden_dim = model_dim * 4
-
-        self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)
-
-    @property
-    def device(self):
-        return self.mlp.fc1.weight.device
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-def create_patches(image, patch_size=(378, 378)):
-    assert image.dim() == 3, "Image must be in CHW format"
-
-    _, height, width = image.shape  # Channels, Height, Width
-    patch_height, patch_width = patch_size
-
-    if height == patch_height and width == patch_width:
-        return []
-
-    # Iterate over the image and create patches
-    patches = []
-    for i in range(0, height, patch_height):
-        row_patches = []
-        for j in range(0, width, patch_width):
-            patch = image[:, i : i + patch_height, j : j + patch_width]
-            row_patches.append(patch)
-        patches.append(torch.stack(row_patches))
-    return patches
-
-
-class VisionEncoder(nn.Module):
-
-    def __init__(self, use_flash_attn=False):
-        super().__init__()
-
-        self.encoder = EncoderWrapper(use_flash_attn)
-        self.projection = VisionProjection()
-        self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]
-
-    @property
-    def device(self):
-        return self.projection.mlp.fc1.weight.device
-
-    @property
-    def dtype(self):
-        return self.projection.mlp.fc1.weight.dtype
-
-    def preprocess(self, image: PIL.Image.Image):
-        width, height = image.size
-        max_dim = max(width, height)
-        if max_dim < 512:
-            im_size = (378, 378)
-        else:
-            aspect_ratio = width / height
-            im_size = min(
-                self.supported_sizes,
-                key=lambda size: (
-                    abs((size[1] / size[0]) - aspect_ratio),
-                    abs(size[0] - width) + abs(size[1] - height),
-                ),
-            )
-
-        return Compose(
-            [
-                Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
-                ToImage(),
-                ToDtype(torch.float32, scale=True),
-                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-            ]
-        )(image)
-
-    def forward(
-        self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
-    ) -> torch.Tensor:
-        im_list = None
-        if isinstance(images, torch.Tensor):
-            # Input must have dimensions (B, C, H, W)
-            assert (
-                len(images.shape) == 4
-            ), "Tensor input must have dimensions (B, C, H, W)"
-            im_list = list(images)
-        elif isinstance(images, PIL.Image.Image):
-            im_list = [images]
-        elif isinstance(images, list):
-            im_list = images
-        else:
-            raise ValueError(
-                "Input must be a PIL image, list of PIL images, or a tensor"
-            )
-
-        # Preprocess unless the images are already tensors (indicating that
-        # they have already been preprocessed)
-        if not isinstance(im_list[0], torch.Tensor):
-            im_list = [self.preprocess(im.convert("RGB")) for im in im_list]
-
-        patches = [create_patches(im) for im in im_list]
-        flat_patches = [patch for image_patches in patches for patch in image_patches]
-
-        # Images may be variable size, and need to be resized to a common size after
-        # creating patches.
-        resized_images = [
-            F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
-            for im in im_list
-        ]
-
-        combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
-        combined_images = combined_images.to(self.device, dtype=self.dtype)
-
-        combined_features = self.encoder(combined_images)
-
-        full_img_features = combined_features[: len(im_list)]
-        patch_features = (
-            combined_features[len(im_list) :].transpose(1, 2).view(-1, 4608, 27, 27)
-        )
-
-        # Reshape patch features back to their original structure
-        reshaped_patch_features = []
-        patch_idx = 0
-        for i, patch_set in enumerate(patches):
-            if len(patch_set) == 0:
-                reshaped_patch_features.append(
-                    full_img_features[i].transpose(0, 1).view(4608, 27, 27)
-                )
-            else:
-                sample_features = []
-                for row_patches in patch_set:
-                    row_len = len(row_patches)
-                    row_features = patch_features[
-                        patch_idx : patch_idx + row_len
-                    ]  # row_len, T, C
-                    row_features = torch.cat(
-                        list(row_features), dim=2
-                    )  # T, C * row_len
-                    patch_idx += row_len
-                    sample_features.append(row_features)
-                sample_features = torch.cat(sample_features, dim=1)
-                sample_features = F.interpolate(
-                    sample_features.unsqueeze(0), size=(27, 27), mode="bilinear"
-                ).squeeze(0)
-                reshaped_patch_features.append(sample_features)
-        reshaped_patch_features = (
-            torch.stack(reshaped_patch_features).view(-1, 4608, 729).transpose(1, 2)
-        )
-
-        final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
-
-        return self.projection(final_features)
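
Shape check for the new dimensions (a minimal sketch, not part of the commit): the updated LinearPatchEmbedding maps 588 input features to embed_dim=1152, and VisionTransformer keeps embed_len=729. These numbers are consistent with 378x378 inputs split into 14x14 patches (588 = 3*14*14, 729 = 27*27); the 14x14 patch size is an assumption inferred from those constants, since the lines defining p1 and p2 are not shown in this diff.

import torch
import torch.nn as nn

# Assumed patch size; consistent with 588 = 3 * 14 * 14 and 729 = (378 / 14) ** 2.
patch = 14
embed_dim = 1152
linear = nn.Linear(588, embed_dim)

x = torch.randn(1, 3, 378, 378)              # B, C, H, W
b, c, hp1, wp2 = x.shape
h, w = hp1 // patch, wp2 // patch            # 27 x 27 patch grid
x = x.reshape(b, c, h, patch, w, patch)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(b, h * w, c * patch * patch)   # (1, 729, 588)
tokens = linear(x)                           # (1, 729, 1152)

pos_embed = torch.randn(1, 729, embed_dim) * 0.02
assert tokens.shape == pos_embed.shape       # matches VisionTransformer.pos_embed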