# NOTE: Hanning, Transformer Decoder
import copy

import numpy as np
import torch
import torch.nn.functional as F


def _get_clones(module, N):
    """Return a ModuleList of N independent (deep-copied) instances of `module`."""
    return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Map an activation name to the corresponding functional implementation."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class TransformerDecoder(torch.nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        text_memory,
        tgt_mask=None,
        memory_mask=None,
        text_memory_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        output = tgt

        intermediate = []

        # Each layer returns an (output, memory) pair; both streams are threaded
        # through to the next layer.
        for layer in self.layers:
            output, memory = layer(
                output,
                memory,
                text_memory=text_memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                text_memory_key_padding_mask=text_memory_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            memory = self.norm(memory)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            # Per-layer outputs stacked along a new leading dimension
            # (note: `memory` is not returned on this path).
            return torch.stack(intermediate)

        return output, memory
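# Shape note (added for clarity; these layouts are assumptions, based on
# torch.nn.MultiheadAttention's default batch_first=False convention):
#   tgt         : (num_queries, batch, d_model)  image-side features / queries
#   memory      : (text_len,    batch, d_model)  text memory refined by each layer
#   text_memory : same layout as `memory`
#   pos         : positional embedding added to `tgt` (same shape as `tgt`)
#   query_pos   : positional embedding added to `memory` (same shape as `memory`)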
class TransformerDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn_text = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_attn_text = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of the feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model)

        self.norm1 = torch.nn.LayerNorm(d_model)
        # self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)
        self.norm4 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        # self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)
        self.dropout4 = torch.nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    # For now, trying one version where it's self-attn -> cross-attn (text) -> cross-attn (image) -> FFN
    def forward_post(
        self,
        tgt,
        memory,
        text_memory,
        tgt_mask=None,
        memory_mask=None,
        text_memory_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        # NOTE: memory2 is None in some cases; still needs to be figured out.
        # Self-attention over the text memory.
        q_text = self.with_pos_embed(memory, query_pos)
        k_text = self.with_pos_embed(memory, query_pos)
        memory2 = self.self_attn_text(
            q_text, k_text, value=memory, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        memory = memory + self.dropout1(memory2)
        memory = self.norm1(memory)

        # Cross attention to image
        memory2 = self.cross_attn_text(
            query=self.with_pos_embed(memory, query_pos),
            key=self.with_pos_embed(tgt, pos),
            value=tgt,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        memory = memory + self.dropout3(memory2)
        memory = self.norm3(memory)

        # FFN
        memory2 = self.linear2(self.dropout(self.activation(self.linear1(memory))))
        memory = memory + self.dropout4(memory2)
        memory = self.norm4(memory)

        return tgt, memory

    def forward(
        self,
        tgt,
        memory,
        text_memory,
        tgt_mask=None,
        memory_mask=None,
        text_memory_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        # `text_memory` and `text_memory_key_padding_mask` are accepted for interface
        # compatibility but are not used inside forward_post.
        return self.forward_post(
            tgt,
            memory,
            text_memory,
            tgt_mask,
            memory_mask,
            text_memory_key_padding_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
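

# Minimal usage sketch (not part of the original module): builds a small decoder
# and runs one forward pass on random tensors. All shapes and hyperparameters
# below are illustrative assumptions, following the sequence-first layout noted above.
if __name__ == "__main__":
    d_model, nhead = 256, 8
    num_queries, text_len, batch = 100, 20, 2

    layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1)
    decoder = TransformerDecoder(layer, num_layers=6, norm=torch.nn.LayerNorm(d_model))
    decoder.eval()  # disable dropout for a deterministic smoke test

    tgt = torch.rand(num_queries, batch, d_model)     # image-side features / queries
    memory = torch.rand(text_len, batch, d_model)     # text memory
    text_memory = memory                              # assumption: same tensor passed for both arguments
    pos = torch.rand(num_queries, batch, d_model)     # positional embedding added to `tgt`
    query_pos = torch.rand(text_len, batch, d_model)  # positional embedding added to `memory`

    with torch.no_grad():
        output, memory_out = decoder(
            tgt,
            memory,
            text_memory=text_memory,
            pos=pos,
            query_pos=query_pos,
        )

    # Expected: torch.Size([100, 2, 256]) torch.Size([20, 2, 256])
    print(output.shape, memory_out.shape)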