biancaganescu committed (verified)
Commit 1894069 · 1 Parent(s): a9a73ce

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +229 -0
model.py ADDED
@@ -0,0 +1,229 @@
import torch
import torch.nn as nn
import math


class DualStreamTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_head: int = 8,
        d_hid: int = 768,
        num_encoder_layers: int = 5,
        num_decoder_layers: int = 8,
        dino_dim: int = 768,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_head = n_head
        self.d_hid = d_hid
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dino_dim = dino_dim
        self.dropout = dropout

        # Embedding layers for the two streams: token IDs and pooled DINO features.
        self.text_embedding = self.SimpleTextEmbedding(vocab_size, d_model)
        self.image_embedding = self.DinoImageEmbedding(dino_dim, d_model)

        # Transformer encoder over the projected image embedding.
        self.image_encoder = self.Encoder(
            d_model, n_head, d_hid, num_encoder_layers, dropout
        )

        # Decoder that fuses the text stream with the encoded image memory.
        self.decoder = self.MultimodalDecoder(
            d_model, n_head, d_hid, num_decoder_layers, dropout
        )

        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(
        self, input_ids, dino_embedding=None, padding_mask=None, use_image: bool = False
    ):
        embedded = self.text_embedding(input_ids)

        # Encode the image stream only when requested and the DINO embedding is
        # non-trivial; an all-zero embedding is treated as "no image".
        if (
            use_image
            and dino_embedding is not None
            and not torch.all(dino_embedding == 0)
        ):
            image_embedded = self.image_embedding(dino_embedding)
            image_encoded = self.image_encoder(image_embedded)
        else:
            image_encoded = None

        seq_len = embedded.size(1)

        # Causal mask so each position attends only to earlier positions.
        tgt_mask = self.decoder.generate_square_subsequent_mask(seq_len).to(
            embedded.device
        )

        decoder_output = self.decoder(
            tgt=embedded,
            image_memory=image_encoded,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=padding_mask,
        )

        output = self.output_layer(decoder_output)

        return output

    class SimpleTextEmbedding(nn.Module):
        def __init__(self, vocab_size, d_model, max_len=128, dropout=0.1):
            super().__init__()
            self.token_embedding = nn.Embedding(vocab_size, d_model)
            self.position_embedding = nn.Embedding(max_len, d_model)
            self.layer_norm = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(p=dropout)
            self.d_model = d_model

        def forward(self, x):
            batch_size, seq_len = x.size()

            # Learned absolute positions, broadcast across the batch.
            positions = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(batch_size, seq_len)
            )
            scale = math.sqrt(self.d_model)

            token_emb = self.token_embedding(x) * scale
            pos_emb = self.position_embedding(positions)

            embeddings = self.dropout(token_emb + pos_emb)

            return self.layer_norm(embeddings)

    class DinoImageEmbedding(nn.Module):
        def __init__(self, dino_dim, d_model):
            super().__init__()
            self.projection_layer = nn.Linear(dino_dim, d_model)

        def forward(self, x):
            # Add a sequence dimension of length 1 to the pooled DINO vector,
            # then project it to the model dimension.
            return self.projection_layer(x.unsqueeze(1))

    class Encoder(nn.Module):
        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            encoder_layer = nn.TransformerEncoderLayer(
                d_model, n_head, d_hid, dropout, activation="gelu", batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, n_layers)

        def forward(self, src, src_mask=None, src_key_padding_mask=None):
            return self.encoder(src, src_mask, src_key_padding_mask)

    class DynamicGating(nn.Module):
        def __init__(self, d_model: int, dropout: float = 0.1):
            super().__init__()
            self.gate_fc = nn.Linear(d_model * 2, d_model)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(d_model)

        def forward(self, text_features, image_features):
            if image_features is None:
                return text_features

            # Element-wise sigmoid gate decides how much of each stream to keep.
            combined = torch.cat([text_features, image_features], dim=-1)
            gate = torch.sigmoid(self.gate_fc(combined))
            fused = gate * text_features + (1 - gate) * image_features
            fused = self.layer_norm(self.dropout(fused))
            return fused

    class MultimodalDecoderLayer(nn.Module):
        def __init__(self, d_model: int, n_head: int, d_hid: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )
            self.cross_attn_txt_image = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )

            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)

            self.dropout = nn.Dropout(dropout)

            self.gate = DualStreamTransformer.DynamicGating(d_model, dropout)

            self.ff = nn.Sequential(
                nn.Linear(d_model, d_hid),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_hid, d_model),
                nn.Dropout(dropout),
            )

        def forward(self, tgt, image_memory, tgt_mask=None, tgt_key_padding_mask=None):
            # Pre-norm causal self-attention over the text stream.
            tgt_norm = self.norm1(tgt)
            self_attn_output, _ = self.self_attn(
                tgt_norm,
                tgt_norm,
                tgt_norm,
                key_padding_mask=tgt_key_padding_mask,
                attn_mask=tgt_mask,
                is_causal=True,
            )

            tgt = tgt + self.dropout(self_attn_output)

            # Cross-attend to the image memory (when present) and fuse via gating.
            if image_memory is not None:
                tgt_norm = self.norm2(tgt)
                cross_attn_output, _ = self.cross_attn_txt_image(
                    tgt_norm, image_memory, image_memory
                )
                cross_attn_output = self.dropout(cross_attn_output)

                fused = self.gate(tgt_norm, cross_attn_output)
                tgt = tgt + fused

            # Position-wise feed-forward block.
            tgt_norm = self.norm3(tgt)
            ff_output = self.ff(tgt_norm)
            tgt = tgt + self.dropout(ff_output)

            return tgt

    class MultimodalDecoder(nn.Module):
        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            self.layers = nn.ModuleList(
                [
                    DualStreamTransformer.MultimodalDecoderLayer(
                        d_model, n_head, d_hid, dropout
                    )
                    for _ in range(n_layers)
                ]
            )

        def generate_square_subsequent_mask(self, size):
            # Boolean upper-triangular mask: True marks positions that may not be attended to.
            mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
            return mask

        def forward(self, tgt, image_memory, tgt_mask, tgt_key_padding_mask=None):
            output = tgt
            for layer in self.layers:
                output = layer(
                    output,
                    image_memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )
            return output
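
For reference, a minimal usage sketch of the uploaded module. Everything below is illustrative and not part of the commit: the vocabulary size, batch and sequence shapes, and the `from model import DualStreamTransformer` import path (assuming the file sits on the Python path as model.py) are assumptions. The forward pass takes token IDs of shape (batch, seq_len), an optional pooled DINO embedding of shape (batch, dino_dim), and a boolean padding mask where True marks padded positions, and returns logits over the vocabulary.

# Illustrative usage sketch; values and shapes are assumptions, not from the repository.
import torch
from model import DualStreamTransformer

vocab_size = 10000                                   # assumed tokenizer size
model = DualStreamTransformer(vocab_size=vocab_size)

input_ids = torch.randint(0, vocab_size, (2, 16))    # (batch, seq_len) token IDs
dino_embedding = torch.randn(2, 768)                 # pooled DINO features, (batch, dino_dim)
padding_mask = torch.zeros(2, 16, dtype=torch.bool)  # True = padded position

logits = model(
    input_ids,
    dino_embedding=dino_embedding,
    padding_mask=padding_mask,
    use_image=True,
)
print(logits.shape)  # torch.Size([2, 16, 10000]): per-token vocabulary logits

Passing use_image=False, omitting dino_embedding, or passing an all-zero embedding skips the image encoder entirely, so the same module can be run text-only without architectural changes.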