WUBIAO committed
Commit 1ed22f6 · verified · 1 parent: 99420be

Upload clip_encoder.py with huggingface_hub
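For reference, an upload like this one is typically done through the huggingface_hub API. A minimal sketch (the repo id below is a placeholder, not taken from this page):

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token from `huggingface-cli login` by default
    api.upload_file(
        path_or_fileobj="clip_encoder.py",
        path_in_repo="clip_encoder.py",
        repo_id="WUBIAO/your-repo",  # placeholder repo id, not from this page
        commit_message="Upload clip_encoder.py with huggingface_hub",
    )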

Files changed (1):
  clip_encoder.py (+135, −0)
clip_encoder.py ADDED
@@ -0,0 +1,135 @@
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+
+ class CLIPVisionTower(nn.Module):
+     def __init__(self, vision_tower, args, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+
+         self.vision_tower_name = vision_tower
+         self.select_layer = args.mm_vision_select_layer
+         self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+         # ##########################################################################
+         # print(delay_load)
+         # print(getattr(args, 'unfreeze_mm_vision_tower', False))
+         # ##########################################################################
+
+         if not delay_load:
+             self.load_model()
+         elif getattr(args, 'unfreeze_mm_vision_tower', False):
+             self.load_model()
+         else:
+             # Defer loading: keep only the config so properties such as
+             # hidden_size and num_patches work before load_model() is called.
+             self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+     def load_model(self, device_map=None):
+         if self.is_loaded:
+             print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+             return
+
+         self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+         # self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+
+         # ##########################################################################
+         # NOTE: the checkpoint is hardcoded to CLIP-ViT-bigG here, overriding
+         # self.vision_tower_name (which is still used for the image processor above).
+         self.vision_tower = CLIPVisionModel.from_pretrained('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', device_map=device_map)
+
+         # Optionally overwrite the bigG weights with a local SeeClick vision
+         # encoder checkpoint; disabled by default.
+         SEEC = False
+         if SEEC:
+             print('++++++++++++++++++++++++++ SeeClick Used ++++++++++++++++++++++++++++++')
+             new_state_dict = torch.load('vision_encoder.pth')
+             self.vision_tower.load_state_dict(new_state_dict, strict=False)
+         else:
+             print('++++++++++++++++++++++++++ BigG Used ++++++++++++++++++++++++++++++')
+
+         # Earlier experiment: reuse the visual branch of a SeeClick (Qwen-VL style)
+         # causal LM as the vision tower. Kept commented out for reference.
+         # from transformers import AutoModelForCausalLM, AutoTokenizer
+         # model = AutoModelForCausalLM.from_pretrained("/home/kyr/BiaoWu/SeeClick/SeeClick", device_map="cuda", trust_remote_code=True, bf16=True).eval()
+         # model.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # set device
+         # model.transformer.visual.hidden_size = 1664
+         # self.vision_tower = model.transformer.visual
+         # # self.vision_tower.hidden_size = 1664
+         # print(dir(self.vision_tower))
+
+         # import pdb; pdb.set_trace()
+         # print(self.vision_tower)
+         # print(self.vision_tower.hidden_size)
+         print('==========================================================')
+         # ##########################################################################
+
+         # Freeze the tower; note that forward() below is also wrapped in
+         # torch.no_grad(), so no gradients flow through it either way.
+         self.vision_tower.requires_grad_(False)
+         self.is_loaded = True
+
+     def feature_select(self, image_forward_outs):
+         # Pick one hidden layer, then keep either patch tokens only or
+         # CLS + patch tokens, depending on select_feature.
+         image_features = image_forward_outs.hidden_states[self.select_layer]
+         if self.select_feature == 'patch':
+             image_features = image_features[:, 1:]
+         elif self.select_feature == 'cls_patch':
+             image_features = image_features
+         else:
+             raise ValueError(f'Unexpected select feature: {self.select_feature}')
+         return image_features
+
+     @torch.no_grad()
+     def forward(self, images):
+         if type(images) is list:
+             # Variable-sized inputs: encode one image at a time.
+             image_features = []
+             for image in images:
+                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                 image_features.append(image_feature)
+         else:
+             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+             image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+         return image_features
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         return self.vision_tower.dtype
+
+     @property
+     def device(self):
+         return self.vision_tower.device
+
+     @property
+     def config(self):
+         if self.is_loaded:
+             return self.vision_tower.config
+         else:
+             return self.cfg_only
+
+     @property
+     def hidden_size(self):
+         return self.config.hidden_size
+
+     @property
+     def num_patches_per_side(self):
+         return self.config.image_size // self.config.patch_size
+
+     @property
+     def num_patches(self):
+         return (self.config.image_size // self.config.patch_size) ** 2
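For context, a minimal usage sketch of the class above (the args namespace and its values are assumptions based on common LLaVA-style defaults, not taken from this commit):

    import types
    import torch
    from clip_encoder import CLIPVisionTower

    # Only the attributes read in __init__ are needed; -2 and 'patch' are
    # assumed defaults, not values from this repo.
    args = types.SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
    tower = CLIPVisionTower('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', args)

    # bigG-14 at 224px: 16 x 16 = 256 patch tokens, hidden size 1664.
    images = torch.randn(2, 3, 224, 224).to(device=tower.device, dtype=tower.dtype)
    features = tower(images)  # shape: (2, tower.num_patches, tower.hidden_size)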