GulbaharAI committed on
Commit 952da02 · verified · 1 Parent(s): 749cd52

Create model.py

Files changed (1)
  1. model.py +116 -0
model.py ADDED
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from transformers.models.clip.modeling_clip import (
    CLIPTextTransformer,
    CLIPPreTrainedModel,
    CLIPModel,
)


class CLIPImageEncoder(CLIPPreTrainedModel):
    @staticmethod
    def from_pretrained(
        global_model_name_or_path,
        cache_dir,
    ):
        model = CLIPModel.from_pretrained(
            global_model_name_or_path,
            subfolder="image_prompt_encoder",
            cache_dir=cache_dir,
        )
        vision_model = model.vision_model
        visual_projection = model.visual_projection
        # CLIP image normalization statistics (per-channel mean and std).
        vision_processor = T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )
        return CLIPImageEncoder(
            vision_model,
            visual_projection,
            vision_processor,
        )

    def __init__(
        self,
        vision_model,
        visual_projection,
        vision_processor,
    ):
        super().__init__(vision_model.config)
        self.vision_model = vision_model
        self.visual_projection = visual_projection
        self.vision_processor = vision_processor

        self.image_size = vision_model.config.image_size

    def forward(self, object_pixel_values):
        b, c, h, w = object_pixel_values.shape

        # Resize to the vision tower's expected resolution if needed.
        if h != self.image_size or w != self.image_size:
            h, w = self.image_size, self.image_size
            object_pixel_values = F.interpolate(
                object_pixel_values, (h, w), mode="bilinear", antialias=True
            )

        object_pixel_values = self.vision_processor(object_pixel_values)
        # Index 1 is the pooled output of the CLIP vision transformer.
        object_embeds = self.vision_model(object_pixel_values)[1]
        object_embeds = self.visual_projection(object_embeds)
        object_embeds = object_embeds.view(b, 1, -1)
        return object_embeds


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x


class PostfuseModule(nn.Module):
    def __init__(self, embed_dim, embed_dim_img):
        super().__init__()
        self.mlp1 = MLP(embed_dim_img, embed_dim, embed_dim, use_residual=False)
        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
        self.layer_norm = nn.LayerNorm(embed_dim)

    @property
    def dtype(self):
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            return torch.float32

    def fuse_fn(self, object_embeds):
        # Map the projected image embedding into the text embedding space.
        text_object_embeds = self.mlp1(object_embeds)
        text_object_embeds = self.mlp2(text_object_embeds)
        text_object_embeds = self.layer_norm(text_object_embeds)
        return text_object_embeds

    def forward(
        self,
        text_embeds,
        object_embeds,
        fuse_index,
    ) -> torch.Tensor:
        text_object_embed = self.fuse_fn(object_embeds)
        text_embeds_new = text_embeds.clone()
        # Overwrite the embedding at the placeholder token position with the fused image embedding.
        text_embeds_new[:, fuse_index, :] = text_object_embed.squeeze(1)

        return text_embeds_new
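
For reference, a minimal usage sketch of how these modules fit together: CLIPImageEncoder produces one projected embedding per reference image, and PostfuseModule writes that embedding into a chosen token slot of the text-encoder output. The dimensions, sequence length, and fuse_index below are illustrative assumptions, not values taken from this commit, and the random tensors stand in for real CLIP text-encoder and image-encoder outputs.

    import torch
    from model import PostfuseModule  # the file added in this commit

    # Illustrative widths (assumed, not specified by the commit).
    embed_dim, embed_dim_img = 768, 768
    fuser = PostfuseModule(embed_dim, embed_dim_img)

    batch, seq_len = 2, 77
    text_embeds = torch.randn(batch, seq_len, embed_dim)    # stand-in for CLIP text-encoder output
    object_embeds = torch.randn(batch, 1, embed_dim_img)    # stand-in for CLIPImageEncoder output

    # Replace the embedding at a hypothetical placeholder token position (index 5).
    fused = fuser(text_embeds, object_embeds, fuse_index=5)
    print(fused.shape)  # torch.Size([2, 77, 768])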