# pgps-demo/model/encoder/transformer.py
import math

import torch
import torch.nn as nn

from utils.utils import sequence_mask


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the [max_len, d_model] table of sin/cos positional encodings.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x: [B, max_len, d_model]
        pe: [1, max_len, d_model]
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)


class LearnedPositionEncoding(nn.Module):
    """Learned positional embedding applied only at the positions of variable tokens."""

    def __init__(self, d_model, max_len=20):
        super(LearnedPositionEncoding, self).__init__()
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x, var_pos):
        """
        x: [B, max_len, d_model]
        var_pos: [B, var_len]
        """
        # Position-id map over the full sequence (0 everywhere except at variable
        # tokens), allocated on the same device as the input embeddings.
        loc_mat = torch.zeros(x.size(0), x.size(1), dtype=torch.int64, device=x.device)
        # Ids 1..var_len for the variable tokens; padded entries of var_pos (equal to
        # its minimum value) get id 0.
        pos_id = torch.arange(1, var_pos.size(1) + 1, device=x.device).repeat(var_pos.size(0), 1)
        pos_id[var_pos == var_pos.min()] = 0
        # Scatter the position ids to the sequence positions given by var_pos.
        loc_mat.scatter_(1, var_pos, pos_id)
        x = x + self.embedding(loc_mat)
        return x


class TransformerEncoder(nn.Module):
    """Standard Transformer encoder with sinusoidal positional encoding."""

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.2):
        super(TransformerEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        self.position = PositionalEncoding(d_model=d_model)
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Initialize all multi-dimensional parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def forward(self, len_src, emb_src):
        """
        len_src: [B] lengths of the source sequences
        emb_src: [B, max_len, d_model] source token embeddings
        """
        # True at padded positions, as expected by src_key_padding_mask.
        src_key_padding_mask = ~sequence_mask(len_src)
        # Add sinusoidal position encoding.
        emb_src = self.position(emb_src)
        # nn.TransformerEncoder expects [max_len, B, d_model], so permute in and out.
        memory = self.encoder(emb_src.permute(1, 0, 2), src_key_padding_mask=src_key_padding_mask)
        return memory.permute(1, 0, 2)
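

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model). It assumes
# that utils.utils.sequence_mask(lengths) returns a [B, max_len] boolean mask
# that is True at valid (non-padded) positions, and that token embeddings of
# shape [B, max_len, d_model] are produced upstream of the encoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, max_len, d_model = 2, 10, 256
    encoder = TransformerEncoder(d_model=d_model)
    emb_src = torch.randn(batch_size, max_len, d_model)   # dummy token embeddings
    len_src = torch.tensor([10, 7])                        # true lengths per example
    memory = encoder(len_src, emb_src)                     # [B, max_len, d_model]
    print(memory.shape)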