Vidit2003 commited on
Commit
017543a
·
verified ·
1 Parent(s): 96db00c

Update vision_transformer.py

Browse files
Files changed (1) hide show
  1. vision_transformer.py +37 -0
vision_transformer.py CHANGED
@@ -1,3 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  from functools import partial
3
  import torch
@@ -304,6 +323,24 @@ class VisionTransformer(PreTrainedModel):
304
  attentions=None,
305
  )
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def get_last_selfattention(self, x):
308
  """Get attention from last block"""
309
  x = self.prepare_tokens(x)
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mostly copy-paste from timm library.
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+
18
+ Modified to work with Hugging Face Transformers
19
+ """
20
  import math
21
  from functools import partial
22
  import torch
 
323
  attentions=None,
324
  )
325
 
326
+ def forward_features(self, x):
327
+ """
328
+ Feature extraction method - EXACT match to local vision_transformer.py
329
+ This ensures HuggingFace and local models give identical results
330
+ """
331
+ x = self.prepare_tokens(x) # Tokenize input
332
+
333
+ for blk in self.blocks:
334
+ x = blk(x)
335
+
336
+ x_norm = self.norm(x) # Normalize tokens
337
+
338
+ return {
339
+ "x_norm_clstoken": x_norm[:, 0], # CLS token
340
+ "x_norm_patchtokens": x_norm[:, 1:], # Patch tokens
341
+ "x_prenorm": x, # Before norm
342
+ }
343
+
344
  def get_last_selfattention(self, x):
345
  """Get attention from last block"""
346
  x = self.prepare_tokens(x)