# File size: 7,297 Bytes (revision 15063d0)
import torch
import torch.nn
import torch.nn.functional as F
from .multi_head_attention import MultiHeadAttention, AttentionMask
from typing import Optional, Callable, Dict
from dataclasses import dataclass
# This file is based on PyTorch's internal implementation
# Type of the nonlinearity accepted by the layer classes in this file:
# any callable that takes one tensor and returns one tensor (e.g. F.relu).
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
class TransformerEncoderLayer(torch.nn.Module):
    """One encoder block: self-attention followed by a two-layer feed-forward network.

    Both sub-blocks are wrapped the same way: the sub-block output goes through
    dropout, is added back onto the sub-block input, and the sum is normalised
    with a LayerNorm.

    Args:
        d_model: Feature size used by the attention module, both LayerNorms and
            the outer side of the feed-forward network.
        nhead: Passed to ``MultiHeadAttention`` as its second argument.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability for both residual branches and for the
            hidden activation of the feed-forward network.
        activation: Callable applied between the two linear layers.
        attention_dropout: Passed to ``MultiHeadAttention`` as ``dropout``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu,
                 attention_dropout=0):
        super().__init__()

        # Attention sub-block.
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout)

        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-block.
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = activation
        self.reset_parameters()

    def forward(self, src: torch.Tensor, mask: Optional[AttentionMask] = None) -> torch.Tensor:
        """Run the block on ``src`` and return a tensor produced by the final LayerNorm.

        ``src`` is handed to the attention module as both its first and second
        argument; ``mask`` is forwarded to it unchanged.
        """
        attended = self.self_attn(src, src, mask)
        src = self.norm1(src + self.dropout1(attended))

        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(projected))

    def reset_parameters(self):
        """Re-initialise the two feed-forward weight matrices with Xavier-uniform.

        The first matrix uses the ReLU gain only when the activation is exactly
        ``F.relu``; any other activation falls back to a gain of 1.0. Biases and
        the other sub-modules are left untouched.
        """
        if self.activation is F.relu:
            first_gain = torch.nn.init.calculate_gain('relu')
        else:
            first_gain = 1.0
        torch.nn.init.xavier_uniform_(self.linear1.weight, gain=first_gain)
        torch.nn.init.xavier_uniform_(self.linear2.weight)
class TransformerDecoderLayer(torch.nn.Module):
    """One decoder block: self-attention, attention over ``memory``, then feed-forward.

    Each of the three sub-blocks is wrapped the same way: the sub-block output
    goes through dropout, is added back onto the sub-block input, and the sum
    is normalised with a LayerNorm.

    Args:
        d_model: Feature size used by both attention modules, all LayerNorms
            and the outer side of the feed-forward network.
        nhead: Passed to both ``MultiHeadAttention`` modules as their second argument.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability for the three residual branches and for
            the hidden activation of the feed-forward network.
        activation: Callable applied between the two linear layers.
        attention_dropout: Passed to both ``MultiHeadAttention`` modules as ``dropout``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu,
                 attention_dropout=0):
        super().__init__()

        # Two attention sub-blocks: one over the target itself, one over the memory.
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout)

        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-block.
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)

        self.activation = activation
        self.reset_parameters()

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None,
                full_target: Optional[torch.Tensor] = None, pos_offset: int = 0) -> torch.Tensor:
        """Run the block on ``tgt`` and return a tensor produced by the final LayerNorm.

        Args:
            tgt: Input of the block.
            memory: Second argument of the memory-attention module.
            tgt_mask: Wrapped as ``AttentionMask(None, tgt_mask)`` for the self-attention.
            memory_key_padding_mask: Wrapped as
                ``AttentionMask(memory_key_padding_mask, None)`` for the memory attention.
            full_target: When given, replaces ``tgt`` as the second argument of
                the self-attention (``TransformerDecoderBase.one_step_forward``
                passes the inputs cached so far here).
            pos_offset: Only checked here; a non-zero offset may not be combined
                with ``tgt_mask``.
        """
        assert pos_offset == 0 or tgt_mask is None

        if full_target is None:
            self_attn_source = tgt
        else:
            self_attn_source = full_target
        attended = self.self_attn(tgt, self_attn_source, mask=AttentionMask(None, tgt_mask))
        tgt = self.norm1(tgt + self.dropout1(attended))

        attended = self.multihead_attn(tgt, memory, mask=AttentionMask(memory_key_padding_mask, None))
        tgt = self.norm2(tgt + self.dropout2(attended))

        hidden = self.activation(self.linear1(tgt))
        projected = self.linear2(self.dropout(hidden))
        return self.norm3(tgt + self.dropout3(projected))

    def reset_parameters(self):
        """Re-initialise the two feed-forward weight matrices with Xavier-uniform.

        The first matrix uses the ReLU gain only when the activation is exactly
        ``F.relu``; any other activation falls back to a gain of 1.0. Biases and
        the other sub-modules are left untouched.
        """
        if self.activation is F.relu:
            first_gain = torch.nn.init.calculate_gain('relu')
        else:
            first_gain = 1.0
        torch.nn.init.xavier_uniform_(self.linear1.weight, gain=first_gain)
        torch.nn.init.xavier_uniform_(self.linear2.weight)
class TransformerDecoderBase(torch.nn.Module):
    """Shared step-by-step decoding logic for decoder stacks.

    This base class only stores ``d_model``. Its methods read ``self.layers``,
    which it never sets itself, so a subclass has to provide that attribute.
    """

    @dataclass
    class State:
        # Number of time steps consumed so far; also the slot written next.
        step: int
        # Layer index -> buffer holding the inputs that layer has received.
        state: Dict[int, torch.Tensor]

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

    def create_state(self, batch_size: int, max_length: int, device: torch.device) -> State:
        """Allocate a fresh decoding state positioned at step 0.

        One ``[batch_size, max_length, d_model]`` buffer is created per layer.
        The buffers come from ``torch.empty`` and are therefore uninitialised
        until ``one_step_forward`` writes into them.
        """
        buffers = {}
        for layer_index in range(len(self.layers)):
            buffers[layer_index] = torch.empty([batch_size, max_length, self.d_model], device=device)
        return self.State(0, buffers)

    def one_step_forward(self, state: State, data: torch.Tensor, *args, **kwargs):
        """Push a single time step through every layer, updating ``state`` in place.

        For each layer the incoming ``data`` is stored in that layer's buffer at
        the current step, and the layer is called with the filled prefix of the
        buffer as ``full_target`` and the current step as ``pos_offset``. Extra
        positional and keyword arguments are forwarded to every layer.

        Returns:
            The output of the last layer. ``state.step`` is advanced by one.

        Raises:
            AssertionError: If ``data`` does not have size 1 along dimension 1,
                or if the buffers are already full.
        """
        assert data.shape[1] == 1, f"For one-step forward should have one timesteps, but shape is {data.shape}"
        assert state.step < state.state[0].shape[1]

        step = state.step
        for layer_index, layer in enumerate(self.layers):
            cache = state.state[layer_index]
            # Remember this layer's input so later steps can use it again.
            cache[:, step:step + 1] = data
            data = layer(data, *args, **kwargs, full_target=cache[:, :step + 1], pos_offset=step)

        state.step += 1
        return data
class TransformerEncoder(torch.nn.Module):
    """A stack of ``n_layers`` independently constructed layers applied one after another.

    Args:
        layer: Callable that builds one layer; invoked once per layer as
            ``layer(*args, **kwargs)``.
        n_layers: Number of layers to build.
        *args, **kwargs: Forwarded to every ``layer`` call.
    """

    def __init__(self, layer, n_layers: int, *args, **kwargs):
        super().__init__()
        stack = []
        for _ in range(n_layers):
            stack.append(layer(*args, **kwargs))
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, data: torch.Tensor, *args, **kwargs):
        """Feed ``data`` through the layers in order and return the final output.

        Every layer receives the previous layer's output plus the same extra
        positional and keyword arguments.
        """
        hidden = data
        for block in self.layers:
            hidden = block(hidden, *args, **kwargs)
        return hidden
class TransformerDecoder(TransformerDecoderBase):
    """A stack of ``n_layers`` decoder layers applied one after another.

    Args:
        layer: Callable that builds one layer; invoked once per layer as
            ``layer(d_model, *args, **kwargs)``.
        n_layers: Number of layers to build.
        d_model: Passed to the base class and, as first argument, to every
            ``layer`` call.
        *args, **kwargs: Forwarded to every ``layer`` call after ``d_model``.
    """

    def __init__(self, layer, n_layers: int, d_model: int, *args, **kwargs):
        super().__init__(d_model)
        stack = []
        for _ in range(n_layers):
            stack.append(layer(d_model, *args, **kwargs))
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, data: torch.Tensor, *args, **kwargs):
        """Feed ``data`` through the layers in order and return the final output.

        Every layer receives the previous layer's output plus the same extra
        positional and keyword arguments.
        """
        hidden = data
        for block in self.layers:
            hidden = block(hidden, *args, **kwargs)
        return hidden
def TransformerEncoderWithLayer(layer=TransformerEncoderLayer):
    """Return a factory that builds a ``TransformerEncoder`` stack out of ``layer``.

    Args:
        layer: Callable that constructs a single encoder layer. Defaults to
            ``TransformerEncoderLayer``.

    Returns:
        A callable; calling it with ``(n_layers, *layer_args, **layer_kwargs)``
        returns ``TransformerEncoder(layer, n_layers, *layer_args, **layer_kwargs)``.
    """
    # The default has to be the per-layer class, not the stack class: with the
    # stack as default, the stack would try to build its layers by calling
    # TransformerEncoder(d_model, nhead, ...), which treats the integer d_model
    # as the layer constructor and fails with a TypeError.
    return lambda *args, **kwargs: TransformerEncoder(layer, *args, **kwargs)
def TransformerDecoderWithLayer(layer=TransformerDecoderLayer):
    """Return a factory that builds a ``TransformerDecoder`` stack out of ``layer``.

    Args:
        layer: Callable that constructs a single decoder layer. Defaults to
            ``TransformerDecoderLayer``.

    Returns:
        A callable; calling it with ``(n_layers, d_model, *layer_args, **layer_kwargs)``
        returns ``TransformerDecoder(layer, n_layers, d_model, *layer_args, **layer_kwargs)``.
    """
    # The default has to be the per-layer class, not the stack class: with the
    # stack as default, the stack would try to build its layers by calling
    # TransformerDecoder(d_model, nhead, ...), which treats the integer d_model
    # as the layer constructor and fails with a TypeError.
    return lambda *args, **kwargs: TransformerDecoder(layer, *args, **kwargs)
class Transformer(torch.nn.Module):
    """Encoder-decoder model assembled from two stack factories.

    Args:
        d_model, nhead, dim_feedforward, dropout, activation, attention_dropout:
            Passed positionally, in that order, to both factories after the
            layer count.
        num_encoder_layers: First argument given to ``encoder_layer``.
        num_decoder_layers: First argument given to ``decoder_layer``.
        encoder_layer: Factory that returns the encoder stack.
        decoder_layer: Factory that returns the decoder stack.
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: ActivationFunction = F.relu, encoder_layer=TransformerEncoderWithLayer(),
                 decoder_layer=TransformerDecoderWithLayer(), attention_dropout: float = 0):
        super().__init__()

        # Both stacks share the same per-layer hyperparameters.
        layer_args = (d_model, nhead, dim_feedforward, dropout, activation, attention_dropout)
        self.encoder = encoder_layer(num_encoder_layers, *layer_args)
        self.decoder = decoder_layer(num_decoder_layers, *layer_args)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                src_mask: Optional[AttentionMask] = None):
        """Encode ``src`` and decode ``tgt`` against the encoder output.

        The encoder receives ``src_mask`` as is. The decoder receives, in
        order, ``tgt``, the encoder output, ``tgt_mask`` and the
        ``src_length_mask`` attribute of ``src_mask`` (``None`` when no
        ``src_mask`` was given).
        """
        memory = self.encoder(src, src_mask)

        if src_mask is None:
            src_length_mask = None
        else:
            src_length_mask = src_mask.src_length_mask

        return self.decoder(tgt, memory, tgt_mask, src_length_mask)

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """Return an ``sz`` x ``sz`` boolean tensor that is True strictly above the main diagonal.

        NOTE(review): presumably True marks positions that must not be attended
        to - confirm against ``MultiHeadAttention``.
        """
        all_true = torch.ones(sz, sz, dtype=torch.bool, device=device)
        return torch.triu(all_true, diagonal=1)
|