biancaganescu commited on
Commit
2f9bc55
·
verified ·
1 Parent(s): 542f9dd

Upload model_channel_attention.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_channel_attention.py +236 -0
model_channel_attention.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from torchvision.ops import SqueezeExcitation
6
+
7
+
8
class DualStreamTransformer(nn.Module):
    """Text decoder with an optional DINO-image stream fused via cross-attention."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_head: int = 6,
        d_hid: int = 768,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dino_dim: int = 768,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Keep the hyper-parameters on the instance for later inspection.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_head = n_head
        self.d_hid = d_hid
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dino_dim = dino_dim
        self.dropout = dropout

        self.text_embedding = SimpleTextEmbedding(vocab_size, d_model)
        self.image_embedding = DinoImageEmbedding(dino_dim, d_model)
        self.image_encoder = Encoder(
            d_model, n_head, d_hid, num_encoder_layers, dropout
        )
        self.decoder = MultimodalDecoder(
            d_model, n_head, d_hid, num_decoder_layers, dropout
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(
        self, input_ids, dino_embedding=None, padding_mask=None, use_image: bool = False
    ):
        """Return vocabulary logits for each position of ``input_ids``.

        The image stream is consulted only when ``use_image`` is set, an
        embedding was supplied, and that embedding is not all-zero (an
        all-zero vector is treated as "no image").
        """
        text_states = self.text_embedding(input_ids)

        image_states = None
        image_given = use_image and dino_embedding is not None
        if image_given and not torch.all(dino_embedding == 0):
            image_states = self.image_encoder(self.image_embedding(dino_embedding))

        # Causal mask over the text sequence, moved onto the working device.
        causal_mask = self.decoder.generate_square_subsequent_mask(
            text_states.size(1)
        ).to(text_states.device)

        hidden = self.decoder(
            tgt=text_states,
            image_memory=image_states,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=padding_mask,
        )
        return self.output_layer(hidden)
74
+
75
+
76
class SimpleTextEmbedding(nn.Module):
    """Learned token + absolute-position embeddings with dropout and LayerNorm."""

    def __init__(self, vocab_size, d_model, max_len=128, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

    def forward(self, x):
        """Map (batch, seq) token ids to (batch, seq, d_model) embeddings."""
        n_batch, n_pos = x.shape
        # Position ids 0..n_pos-1, broadcast across the batch.
        pos_ids = (
            torch.arange(n_pos, device=x.device)
            .unsqueeze(0)
            .expand(n_batch, n_pos)
        )
        # Token embeddings are scaled by sqrt(d_model), as in the original
        # Transformer paper, before the positional signal is added.
        summed = (
            self.token_embedding(x) * math.sqrt(self.d_model)
            + self.position_embedding(pos_ids)
        )
        # NOTE: dropout is applied before the LayerNorm in this design.
        return self.layer_norm(self.dropout(summed))
101
+
102
+
103
class DinoImageEmbedding(nn.Module):
    """Project a pooled DINO feature vector into the model dimension.

    The (batch, dino_dim) input gains a singleton sequence axis so it can
    act as a one-token memory: output shape is (batch, 1, d_model).
    """

    def __init__(self, dino_dim, d_model):
        super().__init__()
        self.projection_layer = nn.Linear(dino_dim, d_model)

    def forward(self, x):
        as_sequence = x.unsqueeze(1)  # (batch, 1, dino_dim)
        return self.projection_layer(as_sequence)
110
+
111
+
112
class Encoder(nn.Module):
    """Transformer encoder whose output is re-weighted per channel by a
    squeeze-and-excitation block (channel attention)."""

    def __init__(
        self, d_model: int, n_head: int, d_hid: int, n_layers: int, dropout: float = 0.1
    ):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model, n_head, d_hid, dropout, activation="gelu", batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        # Excitation MLP bottleneck uses a reduction factor of 16.
        self.se = SqueezeExcitation(d_model, d_model // 16)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        encoded = self.encoder(src, src_mask, src_key_padding_mask)
        # SqueezeExcitation expects NCHW input: view the (B, S, C) sequence
        # as a 1-pixel-tall image with d_model channels -> (B, C, 1, S).
        as_image = encoded.permute(0, 2, 1).unsqueeze(2)
        reweighted = self.se(as_image)
        # Undo the reshape back to (B, S, C).
        return reweighted.squeeze(2).permute(0, 2, 1)
129
+
130
+
131
class DynamicGating(nn.Module):
    """Learned element-wise gate that blends text features with image features."""

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.gate = nn.Linear(d_model * 2, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, text_features, image_features):
        """Return text untouched when there is no image, else the gated blend."""
        if image_features is None:
            return text_features

        # Per-element mixing weight in (0, 1), conditioned on both streams.
        mix = torch.sigmoid(
            self.gate(torch.cat([text_features, image_features], dim=-1))
        )
        blended = mix * text_features + (1 - mix) * image_features
        return self.layer_norm(self.dropout(blended))
151
+
152
+
153
class MultimodalDecoderLayer(nn.Module):
    """Pre-norm decoder layer: causal self-attention, optional gated
    cross-attention over image memory, then a position-wise feed-forward."""

    def __init__(self, d_model: int, n_head: int, d_hid: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_head, dropout=dropout, batch_first=True
        )
        self.cross_attn_txt_image = nn.MultiheadAttention(
            d_model, n_head, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.gate = DynamicGating(d_model, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_hid),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hid, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, tgt, image_memory, tgt_mask=None, tgt_key_padding_mask=None):
        # Causal self-attention with a pre-norm residual connection.
        normed = self.norm1(tgt)
        attended, _ = self.self_attn(
            normed,
            normed,
            normed,
            key_padding_mask=tgt_key_padding_mask,
            attn_mask=tgt_mask,
            is_causal=True,
        )
        tgt = tgt + self.dropout(attended)

        # Cross-attention into the image memory, fused through the gate;
        # skipped entirely when no image is present.
        if image_memory is not None:
            normed = self.norm2(tgt)
            cross, _ = self.cross_attn_txt_image(normed, image_memory, image_memory)
            # Gate blends the normed text states with the image read-out.
            tgt = tgt + self.gate(normed, self.dropout(cross))

        # Feed-forward sublayer, again pre-norm + residual.
        return tgt + self.dropout(self.ff(self.norm3(tgt)))
209
+
210
+
211
class MultimodalDecoder(nn.Module):
    """A sequential stack of MultimodalDecoderLayer modules."""

    def __init__(
        self, d_model: int, n_head: int, d_hid: int, n_layers: int, dropout: float = 0.1
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            MultimodalDecoderLayer(d_model, n_head, d_hid, dropout)
            for _ in range(n_layers)
        )

    def generate_square_subsequent_mask(self, size):
        """Boolean (size, size) mask; True above the diagonal blocks future positions."""
        return torch.ones(size, size).triu(diagonal=1).bool()

    def forward(self, tgt, image_memory, tgt_mask, tgt_key_padding_mask=None):
        hidden = tgt
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden,
                image_memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )
        return hidden