import torch
from torch import nn


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()
        self.n_token = n_token
        self.d_embed = d_embed
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj**0.5
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            # Single embedding table; project to d_proj only when the sizes differ.
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            # One embedding table per vocabulary cluster; the embedding dimension
            # shrinks by a factor of div_val for each successive (rarer) cluster.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            # Look up each vocabulary cluster separately, project to d_proj, and
            # scatter the results back into their original positions.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(-1) keeps a 1-D index even if only one token falls in this cluster.
                indices_i = mask_i.nonzero().squeeze(-1)
                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = nn.functional.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
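

# A minimal usage sketch, not part of the original module: the vocabulary size,
# cutoffs, and init scheme below are illustrative assumptions. The projection
# parameters are created with torch.FloatTensor and are otherwise uninitialized,
# so they are initialized here for the example.
def _demo_adaptive_embedding():
    n_token, d_embed, d_proj = 10000, 128, 128
    emb = AdaptiveEmbedding(n_token, d_embed, d_proj, cutoffs=[2000, 6000], div_val=2)
    for proj in emb.emb_projs:
        nn.init.normal_(proj, mean=0.0, std=0.02)
    tokens = torch.randint(0, n_token, (4, 16))  # (batch, seq_len) token ids
    return emb(tokens)                           # -> (4, 16, d_proj)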


class PositionalEmbeddingAux(nn.Module):
    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        # Standard sinusoidal inverse frequencies: 1 / 10000^(2i / demb).
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class PositionalEmbedding(PositionalEmbeddingAux):
    def forward(self, pos_seq, bsz=None):
        # Drop the leading singleton dim on the input positions and the singleton
        # middle dim on the output, yielding a (seq_len, demb) tensor.
        return super().forward(pos_seq.squeeze(0), bsz=bsz).squeeze(1)
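

# A minimal usage sketch, illustrative only: descending position indices
# (klen-1 down to 0) passed with a leading singleton batch dimension, as
# PositionalEmbedding.forward expects. demb and klen are assumed values.
def _demo_positional_embedding(demb=128, klen=16):
    pos_emb_layer = PositionalEmbedding(demb)
    pos_seq = torch.arange(klen - 1, -1, -1.0).unsqueeze(0)  # (1, klen), descending
    return pos_emb_layer(pos_seq)                            # -> (klen, demb)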