import math

import torch
import torch.nn as nn
class DualStreamTransformer(nn.Module):
    """Text decoder with an optional DINO-image stream.

    Text tokens are always decoded causally; when an image embedding is
    supplied (and enabled), it is projected, passed through a transformer
    encoder, and fused into the decoder via gated cross-attention.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_head: int = 8,
        d_hid: int = 768,
        num_encoder_layers: int = 5,
        num_decoder_layers: int = 8,
        dino_dim: int = 768,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Keep the hyperparameters around for inspection / checkpointing.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_head = n_head
        self.d_hid = d_hid
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dino_dim = dino_dim
        self.dropout = dropout
        # Per-modality embedding front-ends.
        self.text_embedding = self.SimpleTextEmbedding(vocab_size, d_model)
        self.image_embedding = self.DinoImageEmbedding(dino_dim, d_model)
        # Image-side encoder stack and the multimodal decoder stack.
        self.image_encoder = self.Encoder(d_model, n_head, d_hid, num_encoder_layers, dropout)
        self.decoder = self.MultimodalDecoder(d_model, n_head, d_hid, num_decoder_layers, dropout)
        # Projects decoder states onto vocabulary logits.
        self.output_layer = nn.Linear(d_model, vocab_size)
| def forward( | |
| self, input_ids, dino_embedding=None, padding_mask=None, use_image: bool = False | |
| ): | |
| embedded = self.text_embedding(input_ids) | |
| if ( | |
| use_image | |
| and dino_embedding is not None | |
| and not torch.all(dino_embedding == 0) | |
| ): | |
| image_embedded = self.image_embedding(dino_embedding) | |
| image_encoded = self.image_encoder(image_embedded) | |
| else: | |
| image_encoded = None | |
| seq_len = embedded.size(1) | |
| tgt_mask = self.decoder.generate_square_subsequent_mask(seq_len).to( | |
| embedded.device | |
| ) | |
| decoder_output = self.decoder( | |
| tgt=embedded, | |
| image_memory=image_encoded, | |
| tgt_mask=tgt_mask, | |
| tgt_key_padding_mask=padding_mask, | |
| ) | |
| output = self.output_layer(decoder_output) | |
| return output | |
class SimpleTextEmbedding(nn.Module):
    """Token + learned positional embedding, scaled, with dropout and LayerNorm.

    Args:
        vocab_size: size of the token embedding table.
        d_model: embedding width.
        max_len: maximum supported sequence length (positional table size).
        dropout: dropout probability applied before the final LayerNorm.
    """

    def __init__(self, vocab_size, d_model, max_len=128, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        # Kept so forward() can validate sequence length explicitly.
        self.max_len = max_len

    def forward(self, x):
        """Embed a (batch, seq_len) LongTensor of token ids to (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds max_len. Previously this surfaced
            as an opaque index error from the positional embedding lookup.
        """
        batch_size, seq_len = x.size()
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        positions = (
            torch.arange(seq_len, device=x.device)
            .unsqueeze(0)
            .expand(batch_size, seq_len)
        )
        # Scale token embeddings by sqrt(d_model), as in the original Transformer.
        token_emb = self.token_embedding(x) * math.sqrt(self.d_model)
        pos_emb = self.position_embedding(positions)
        return self.layer_norm(self.dropout(token_emb + pos_emb))
class DinoImageEmbedding(nn.Module):
    """Project a pooled DINO feature vector into the decoder width.

    Input is (batch, dino_dim); output is (batch, 1, d_model) — a
    one-token memory sequence for cross-attention.
    """

    def __init__(self, dino_dim, d_model):
        super().__init__()
        self.projection_layer = nn.Linear(dino_dim, d_model)

    def forward(self, x):
        # Linear acts on the last dim, so projecting first and then adding
        # the sequence axis is equivalent to unsqueezing first.
        return self.projection_layer(x).unsqueeze(1)
class Encoder(nn.Module):
    """Stack of batch-first TransformerEncoder layers with GELU activation."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        d_hid: int,
        n_layers: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=d_hid,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode (batch, seq, d_model) -> (batch, seq, d_model)."""
        return self.encoder(
            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
class DynamicGating(nn.Module):
    """Sigmoid-gated convex mix of text and image features.

    gate = sigmoid(W [text; image]); output = LN(dropout(gate*text + (1-gate)*image)).
    When no image features are given, the text features pass through unchanged.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.gate_fc = nn.Linear(d_model * 2, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, text_features, image_features):
        # Text-only path: identity, no norm or dropout applied.
        if image_features is None:
            return text_features
        stacked = torch.cat((text_features, image_features), dim=-1)
        gate = torch.sigmoid(self.gate_fc(stacked))
        # lerp(image, text, gate) == gate*text + (1-gate)*image
        mixed = torch.lerp(image_features, text_features, gate)
        return self.layer_norm(self.dropout(mixed))
class MultimodalDecoderLayer(nn.Module):
    """Pre-norm decoder layer: causal self-attention, optional gated
    cross-attention over the image memory, then a feed-forward sublayer.

    Fix: both attention calls previously computed attention weights that
    were immediately discarded; `need_weights=False` skips that averaging
    and lets PyTorch use its fused attention fast path. Outputs are
    unchanged.
    """

    def __init__(self, d_model: int, n_head: int, d_hid: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_head, dropout=dropout, batch_first=True
        )
        self.cross_attn_txt_image = nn.MultiheadAttention(
            d_model, n_head, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Gating module declared as a sibling nested class on the model.
        self.gate = DualStreamTransformer.DynamicGating(d_model, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_hid),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hid, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, tgt, image_memory, tgt_mask=None, tgt_key_padding_mask=None):
        """One decoder step over (batch, seq, d_model) text states.

        `tgt_mask` must be supplied when using the `is_causal=True` hint
        (the model's forward always passes the causal mask).
        """
        # Causal self-attention sublayer (pre-norm residual).
        tgt_norm = self.norm1(tgt)
        self_attn_output, _ = self.self_attn(
            tgt_norm,
            tgt_norm,
            tgt_norm,
            key_padding_mask=tgt_key_padding_mask,
            attn_mask=tgt_mask,
            is_causal=True,
            need_weights=False,
        )
        tgt = tgt + self.dropout(self_attn_output)
        if image_memory is not None:
            # Cross-attend normed text states to the image memory, then fuse
            # via the gate; the fused (already layer-normed) tensor is the
            # residual branch.
            tgt_norm = self.norm2(tgt)
            cross_attn_output, _ = self.cross_attn_txt_image(
                tgt_norm, image_memory, image_memory, need_weights=False
            )
            cross_attn_output = self.dropout(cross_attn_output)
            fused = self.gate(tgt_norm, cross_attn_output)
            tgt = tgt + fused
        # Feed-forward sublayer (pre-norm residual).
        tgt_norm = self.norm3(tgt)
        ff_output = self.ff(tgt_norm)
        # NOTE(review): self.dropout here stacks on the Dropout already at the
        # end of self.ff, i.e. dropout is applied twice on the FF path —
        # preserved as-is; confirm it is intended.
        tgt = tgt + self.dropout(ff_output)
        return tgt
class MultimodalDecoder(nn.Module):
    """Sequential stack of MultimodalDecoderLayer blocks."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        d_hid: int,
        n_layers: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            DualStreamTransformer.MultimodalDecoderLayer(
                d_model, n_head, d_hid, dropout
            )
            for _ in range(n_layers)
        )

    def generate_square_subsequent_mask(self, size):
        """Boolean causal mask: entry [i, j] is True (masked) iff j > i."""
        idx = torch.arange(size)
        return idx.unsqueeze(0) > idx.unsqueeze(1)

    def forward(self, tgt, image_memory, tgt_mask, tgt_key_padding_mask=None):
        """Run the text states through every decoder layer in order."""
        states = tgt
        for layer in self.layers:
            states = layer(
                states,
                image_memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )
        return states