# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from .position_encoding import PositionEmbeddingSine
class lang_tf_enc(nn.Module):
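    """Cross-modal attention block: vision features act as queries over the
    language features (keys/values), followed by a residual feed-forward
    sublayer with post-normalization."""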
def __init__(self, input_1, input_2, hidden_dim, head_num, dropout=0.1):
super(lang_tf_enc, self).__init__()
        # Sine positional embeddings; instantiated here but not applied in forward().
        self.pos_embedding_1 = PositionEmbeddingSine(input_2, normalize=True)
        self.pos_embedding_2 = PositionEmbeddingSine(input_1, normalize=True)
self.dense_q = nn.Linear(input_1, hidden_dim)
self.dense_k = nn.Linear(input_2, hidden_dim)
self.dense_v = nn.Linear(input_2, hidden_dim)
self.self_attn = nn.MultiheadAttention(hidden_dim, head_num, dropout=dropout)
self.forward_dim = 2048
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.linear1 = nn.Linear(hidden_dim, self.forward_dim)
self.linear2 = nn.Linear(self.forward_dim, hidden_dim)
self.activation = _get_activation("relu")
self.dropout = nn.Dropout(dropout)
# @get_local("weights")
def forward(self, vision_input, lang_input):
decoder_embed_lang = lang_input
decoder_embed_vis = vision_input
        # Project to the shared hidden size and permute to (seq_len, batch, dim)
        # as expected by nn.MultiheadAttention: vision features form the queries,
        # language features the keys and values.
        q_inp = F.relu(self.dense_q(decoder_embed_vis).permute(1, 0, 2))
        k_inp = F.relu(self.dense_k(decoder_embed_lang).permute(1, 0, 2))
        v_inp = F.relu(self.dense_v(decoder_embed_lang).permute(1, 0, 2))
        lang_input = lang_input.permute(1, 0, 2)  # permuted copy, not used again in this block
        decoded_layer, weights = self.self_attn(q_inp, k_inp, v_inp)
        # Back to (batch, seq_len, dim) before the residual connection with the vision input.
        decoded_layer = decoded_layer.permute(1, 0, 2)
add_layer = decoded_layer + vision_input
add_layer = self.norm1(add_layer)
add_layer2 = self.linear2(self.dropout(self.activation(self.linear1(add_layer))))
add_layer = add_layer + self.dropout(add_layer2)
add_layer = self.norm2(add_layer)
return add_layer
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation(activation):
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation shuld be relu/gelu, not {activation}.")
class TransformerEncoder(nn.Module):
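    """Stack of `num_layers` deep-copied encoder layers with an optional final norm."""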
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src, pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(output, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
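    """Stack of decoder layers; with `return_intermediate=True` the (normalized)
    output of every layer is returned as a single stacked tensor."""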
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(output, memory, pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output
class TransformerEncoderLayer(nn.Module):
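    """Encoder layer (self-attention + FFN); `normalize_before` switches between
    pre-norm (`forward_pre`) and post-norm (`forward_post`)."""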
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
# @get_local("weights")
def forward_post(self, src, pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
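        # need_weights=False, so `weights` is None here; switch it to True if the
        # @get_local("weights") hook above is re-enabled for attention visualization.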
src2, weights = self.self_attn(q, k, value=src, need_weights=False)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src, pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2, weights = self.self_attn(q, k, value=src2)
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src, pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, pos)
return self.forward_post(src, pos)
class TransformerDecoderLayer(nn.Module):
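    """Decoder layer (self-attention, cross-attention over the encoder memory, FFN);
    `normalize_before` switches between pre-norm and post-norm variants."""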
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2, weights = self.self_attn(q, k, value=tgt)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2, weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory, pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2, weights = self.self_attn(q, k, value=tgt2)
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
        tgt2, weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                            key=self.with_pos_embed(memory, pos),
                                            value=memory)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory, pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, pos, query_pos)
return self.forward_post(tgt, memory, pos, query_pos)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")