import torch
import torch.nn as nn
import math


class DualStreamTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_head: int = 8,
        d_hid: int = 768,
        num_encoder_layers: int = 5,
        num_decoder_layers: int = 8,
        dino_dim: int = 768,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_head = n_head
        self.d_hid = d_hid
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dino_dim = dino_dim
        self.dropout = dropout

        self.text_embedding = self.SimpleTextEmbedding(vocab_size, d_model)
        self.image_embedding = self.DinoImageEmbedding(dino_dim, d_model)
        self.image_encoder = self.Encoder(
            d_model, n_head, d_hid, num_encoder_layers, dropout
        )
        self.decoder = self.MultimodalDecoder(
            d_model, n_head, d_hid, num_decoder_layers, dropout
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(
        self, input_ids, dino_embedding=None, padding_mask=None, use_image: bool = False
    ):
        embedded = self.text_embedding(input_ids)

        # Only run the image stream when a non-trivial DINO embedding is supplied.
        if (
            use_image
            and dino_embedding is not None
            and not torch.all(dino_embedding == 0)
        ):
            image_embedded = self.image_embedding(dino_embedding)
            image_encoded = self.image_encoder(image_embedded)
        else:
            image_encoded = None

        seq_len = embedded.size(1)
        tgt_mask = self.decoder.generate_square_subsequent_mask(seq_len).to(
            embedded.device
        )
        decoder_output = self.decoder(
            tgt=embedded,
            image_memory=image_encoded,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=padding_mask,
        )
        output = self.output_layer(decoder_output)
        return output

    class SimpleTextEmbedding(nn.Module):
        def __init__(self, vocab_size, d_model, max_len=128, dropout=0.1):
            super().__init__()
            self.token_embedding = nn.Embedding(vocab_size, d_model)
            self.position_embedding = nn.Embedding(max_len, d_model)
            self.layer_norm = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(p=dropout)
            self.d_model = d_model

        def forward(self, x):
            batch_size, seq_len = x.size()
            positions = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(batch_size, seq_len)
            )
            scale = math.sqrt(self.d_model)
            token_emb = self.token_embedding(x) * scale
            pos_emb = self.position_embedding(positions)
            embeddings = self.dropout(token_emb + pos_emb)
            return self.layer_norm(embeddings)

    class DinoImageEmbedding(nn.Module):
        def __init__(self, dino_dim, d_model):
            super().__init__()
            self.projection_layer = nn.Linear(dino_dim, d_model)

        def forward(self, x):
            # (batch, dino_dim) -> (batch, 1, d_model): a single-token image memory.
            return self.projection_layer(x.unsqueeze(1))

    class Encoder(nn.Module):
        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            encoder_layer = nn.TransformerEncoderLayer(
                d_model, n_head, d_hid, dropout, activation="gelu", batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, n_layers)

        def forward(self, src, src_mask=None, src_key_padding_mask=None):
            return self.encoder(src, src_mask, src_key_padding_mask)

    class DynamicGating(nn.Module):
        def __init__(self, d_model: int, dropout: float = 0.1):
            super().__init__()
            self.gate_fc = nn.Linear(d_model * 2, d_model)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(d_model)

        def forward(self, text_features, image_features):
            if image_features is None:
                return text_features
            combined = torch.cat([text_features, image_features], dim=-1)
            gate = torch.sigmoid(self.gate_fc(combined))
            fused = gate * text_features + (1 - gate) * image_features
            fused = self.layer_norm(self.dropout(fused))
            return fused

    class MultimodalDecoderLayer(nn.Module):
        def __init__(self, d_model: int, n_head: int, d_hid: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )
            self.cross_attn_txt_image = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)
            self.gate = DualStreamTransformer.DynamicGating(d_model, dropout)
            self.ff = nn.Sequential(
                nn.Linear(d_model, d_hid),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_hid, d_model),
                nn.Dropout(dropout),
            )

        def forward(self, tgt, image_memory, tgt_mask=None, tgt_key_padding_mask=None):
            # Pre-norm causal self-attention over the text stream.
            tgt_norm = self.norm1(tgt)
            self_attn_output, _ = self.self_attn(
                tgt_norm,
                tgt_norm,
                tgt_norm,
                key_padding_mask=tgt_key_padding_mask,
                attn_mask=tgt_mask,
                is_causal=True,
            )
            tgt = tgt + self.dropout(self_attn_output)

            # Cross-attend to the image memory and fuse it through the learned gate.
            if image_memory is not None:
                tgt_norm = self.norm2(tgt)
                cross_attn_output, _ = self.cross_attn_txt_image(
                    tgt_norm, image_memory, image_memory
                )
                cross_attn_output = self.dropout(cross_attn_output)
                fused = self.gate(tgt_norm, cross_attn_output)
                tgt = tgt + fused

            tgt_norm = self.norm3(tgt)
            ff_output = self.ff(tgt_norm)
            tgt = tgt + self.dropout(ff_output)
            return tgt

    class MultimodalDecoder(nn.Module):
        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            self.layers = nn.ModuleList(
                [
                    DualStreamTransformer.MultimodalDecoderLayer(
                        d_model, n_head, d_hid, dropout
                    )
                    for _ in range(n_layers)
                ]
            )

        def generate_square_subsequent_mask(self, size):
            # Boolean mask with True above the diagonal, i.e. positions that may not be attended to.
            mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
            return mask

        def forward(self, tgt, image_memory, tgt_mask, tgt_key_padding_mask=None):
            output = tgt
            for layer in self.layers:
                output = layer(
                    output,
                    image_memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )
            return output
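

# A minimal smoke-test sketch, not part of the original module: the vocabulary size,
# batch size, sequence length, and the 768-dim stand-in for a DINO CLS embedding below
# are assumed values for illustration, not taken from any accompanying training code.
if __name__ == "__main__":
    model = DualStreamTransformer(vocab_size=1000)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))
    # True marks padded positions; all-False means no padding in this dummy batch.
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    dino_embedding = torch.randn(batch_size, 768)  # stand-in for a DINO image feature

    # Text-only path: the image embedding/encoder is skipped entirely.
    logits_text = model(input_ids, padding_mask=padding_mask, use_image=False)

    # Multimodal path: the DINO vector is projected, encoded, and fused via gated cross-attention.
    logits_multi = model(
        input_ids,
        dino_embedding=dino_embedding,
        padding_mask=padding_mask,
        use_image=True,
    )
    print(logits_text.shape, logits_multi.shape)  # both: (2, 16, 1000)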