Fabrice-TIERCELIN commited on
Commit
7662e68
·
verified ·
1 Parent(s): 305c29a

Upload clip_encoder.py

Browse files
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ print(f'Loading vision tower: {self.vision_tower_name}')
15
+ self.select_layer = args.mm_vision_select_layer
16
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ else:
21
+ # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(
23
+ self.vision_tower_name)
24
+
25
+ def load_model(self):
26
+ self.image_processor = CLIPImageProcessor.from_pretrained(
27
+ self.vision_tower_name)
28
+ self.vision_tower = CLIPVisionModel.from_pretrained(
29
+ self.vision_tower_name)
30
+ self.vision_tower.requires_grad_(False)
31
+
32
+ self.is_loaded = True
33
+
34
+ def feature_select(self, image_forward_outs):
35
+ image_features = image_forward_outs.hidden_states[self.select_layer]
36
+ if self.select_feature == 'patch':
37
+ image_features = image_features[:, 1:]
38
+ elif self.select_feature == 'cls_patch':
39
+ image_features = image_features
40
+ else:
41
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
42
+ return image_features
43
+
44
+ @torch.no_grad()
45
+ def forward(self, images):
46
+ if type(images) is list:
47
+ image_features = []
48
+ for image in images:
49
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
50
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
51
+ image_features.append(image_feature)
52
+ else:
53
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
54
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
55
+
56
+ return image_features
57
+
58
+ @property
59
+ def dummy_feature(self):
60
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61
+
62
+ @property
63
+ def dtype(self):
64
+ return self.vision_tower.dtype
65
+
66
+ @property
67
+ def device(self):
68
+ return self.vision_tower.device
69
+
70
+ @property
71
+ def config(self):
72
+ if self.is_loaded:
73
+ return self.vision_tower.config
74
+ else:
75
+ return self.cfg_only
76
+
77
+ @property
78
+ def hidden_size(self):
79
+ return self.config.hidden_size
80
+
81
+ @property
82
+ def num_patches(self):
83
+ return (self.config.image_size // self.config.patch_size) ** 2