| import os |
| import sys |
| import glob |
| import h5py |
| import copy |
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
|
|
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Every copy has its own parameters; none share storage with ``module``.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
|
|
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value`` over the last two
    axes, where ``d_k`` is the size of the last axis of ``query``.

    Args:
        query: tensor whose last two axes are (len_q, d_k).
        key: tensor whose last two axes are (len_k, d_k).
        value: tensor whose last two axes are (len_k, d_v).
        mask: optional tensor broadcastable to the score shape
            ``(..., len_q, len_k)``. Positions where ``mask == 0`` receive a
            score of -1e9 and therefore (numerically) zero weight.
        dropout: optional callable (e.g. an ``nn.Dropout`` module) applied to
            the attention weights after the softmax. ``None`` disables it.

    Returns:
        Tuple ``(output, p_attn)``: the attention-weighted combination of
        ``value`` and the attention weights (after dropout, if any).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / math.sqrt(d_k)
    if mask is not None:
        # A large finite negative (rather than -inf) keeps a fully-masked row
        # finite: its softmax becomes uniform instead of NaN.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        # Honour the caller-supplied dropout on the attention weights.
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
|
|
def nearest_neighbor(src, dst):
    """For each column of ``src``, find the closest column of ``dst``.

    ``src`` is (D, N) and ``dst`` is (D, M); points are stored column-wise.
    The squared Euclidean distance is expanded as
    ``|s|^2 - 2 s.d + |d|^2`` and negated, so that ``topk`` (which picks the
    largest value) selects the nearest point.

    Returns:
        Tuple ``(distances, indices)``, each shaped (N, 1). ``distances``
        holds the NEGATIVE squared distance to the nearest ``dst`` point, and
        ``indices`` holds its column index in ``dst``.
    """
    src_rows = src.transpose(1, 0).contiguous()
    inner = -2 * torch.matmul(src_rows, dst)
    src_sq = torch.sum(src ** 2, dim=0, keepdim=True).transpose(1, 0).contiguous()
    dst_sq = torch.sum(dst ** 2, dim=0, keepdim=True)
    neg_sq_dist = -src_sq - inner - dst_sq
    best_dist, best_idx = neg_sq_dist.topk(k=1, dim=-1)
    return best_dist, best_idx
|
|
|
|
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper.

    The source is embedded and encoded into a memory. The target is then
    embedded and decoded against that memory, and the result is passed
    through ``generator``.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Registration order is kept as-is; it fixes parameters()/state_dict order.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, then decode ``tgt`` against the resulting memory."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` with ``src_embed`` and run it through the encoder."""
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt``, decode it against ``memory``, and apply the generator."""
        embedded = self.tgt_embed(tgt)
        hidden = self.decoder(embedded, memory, src_mask, tgt_mask)
        return self.generator(hidden)
|
|
|
|
class Generator(nn.Module):
    """Regression head that maps per-position embeddings to a rigid motion.

    The input is max-pooled over dim 1 and passed through a three-stage
    Linear/BatchNorm/ReLU MLP (emb_dims -> /2 -> /4 -> /8). Two linear heads
    then produce a 4-vector, rescaled to unit L2 norm (presumably a rotation
    quaternion; confirm with callers), and a 3-vector translation.
    """

    def __init__(self, emb_dims):
        super().__init__()
        half, quarter, eighth = emb_dims // 2, emb_dims // 4, emb_dims // 8
        # Construction order matches the original, so seeded initialisation is unchanged.
        self.nn = nn.Sequential(
            nn.Linear(emb_dims, half), nn.BatchNorm1d(half), nn.ReLU(),
            nn.Linear(half, quarter), nn.BatchNorm1d(quarter), nn.ReLU(),
            nn.Linear(quarter, eighth), nn.BatchNorm1d(eighth), nn.ReLU(),
        )
        self.proj_rot = nn.Linear(eighth, 4)
        self.proj_trans = nn.Linear(eighth, 3)

    def forward(self, x):
        """Return ``(rotation, translation)`` shaped (B, 4) and (B, 3)."""
        pooled, _ = x.max(dim=1)
        features = self.nn(pooled)
        raw_rot = self.proj_rot(features)
        translation = self.proj_trans(features)
        rotation = raw_rot / torch.norm(raw_rot, p=2, dim=1, keepdim=True)
        return rotation, translation
|
|
|
|
class Encoder(nn.Module):
    """Stack of ``N`` cloned encoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # `layer.size` is the feature width the final normalisation acts on.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Run ``x`` through every layer in order, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
|
|
|
|
class Decoder(nn.Module):
    """Stack of ``N`` cloned decoder layers followed by a final LayerNorm.

    Each layer receives the running target representation, the encoder
    ``memory``, and both masks.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply each layer in turn against ``memory``, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learnable gain and bias.

    Unlike ``nn.LayerNorm``, this divides by the unbiased standard deviation
    (``Tensor.std`` default) plus ``eps``, rather than by ``sqrt(var + eps)``.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        scale = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / scale + self.b_2
|
|
|
|
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + sublayer(LayerNorm(x))``.

    ``dropout`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, size, dropout=None):
        super().__init__()
        self.norm = LayerNorm(size)

    def forward(self, x, sublayer):
        normalized = self.norm(x)
        return x + sublayer(normalized)
|
|
|
|
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward.

    Each step is wrapped in a residual pre-norm SublayerConnection.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        attn_step, ff_step = self.sublayer
        x = attn_step(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_step(x, self.feed_forward)
|
|
|
|
class DecoderLayer(nn.Module):
    """One decoder layer with three residual steps.

    The steps are: self-attention over the target (``tgt_mask``), attention
    from the target to the encoder ``memory`` (``src_mask``), and the
    feed-forward network.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        self_step, cross_step, ff_step = self.sublayer
        x = self_step(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        x = cross_step(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ff_step(x, self.feed_forward)
|
|
|
|
class MultiHeadedAttention(nn.Module):
    """Multi-head attention with ``h`` heads of width ``d_model // h``.

    ``linears`` holds four ``d_model -> d_model`` projections: query, key,
    value, and output, in that order. ``dropout`` is accepted but unused:
    ``self.dropout`` is always ``None``. The last attention weights are
    cached on ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = None

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over every head.
            mask = mask.unsqueeze(1)
        batch = query.size(0)
        q_proj, k_proj, v_proj, out_proj = self.linears

        def split_heads(proj, t):
            # (B, L, d_model) -> (B, h, L, d_k)
            return proj(t).view(batch, -1, self.h, self.d_k).transpose(1, 2).contiguous()

        q = split_heads(q_proj, query)
        k = split_heads(k_proj, key)
        v = split_heads(v_proj, value)

        heads, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # (B, h, L, d_k) -> (B, L, h * d_k)
        merged = heads.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return out_proj(merged)
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise MLP: ``w_2(norm(relu(w_1(x))))``.

    ``norm`` is an empty ``nn.Sequential`` (an identity) applied with the
    channel axis moved to position 1 (input must therefore be at least 3-D).
    ``dropout`` is accepted but unused; ``self.dropout`` stays ``None``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.norm = nn.Sequential()
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = None

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        channel_first = hidden.transpose(2, 1).contiguous()
        hidden = self.norm(channel_first).transpose(2, 1).contiguous()
        return self.w_2(hidden)
|
|
|
|
class Identity(nn.Module):
    """Module that returns its positional arguments unchanged, as a tuple.

    Note: the result is always a tuple, so ``Identity()(x)`` is ``(x,)``,
    not ``x``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args):
        return args
|
|
|
|
class Transformer(nn.Module):
    """Encoder-decoder attention block that cross-conditions two feature sets.

    ``forward`` expects two tensors laid out as (batch, emb_dims, L): they
    are transposed so the linear layers act on the last axis. Each tensor is
    decoded against the other, and two embeddings are returned in the
    original layout. ``dropout`` is forwarded to the sub-modules, but none of
    the layers defined alongside this class apply it.
    """

    def __init__(self, emb_dims, n_blocks, dropout, ff_dims, n_heads):
        super().__init__()
        self.emb_dims = emb_dims
        self.N = n_blocks
        self.dropout = dropout
        self.ff_dims = ff_dims
        self.n_heads = n_heads
        # Prototypes; every layer below receives its own deep copy.
        attn = MultiHeadedAttention(self.n_heads, self.emb_dims)
        ff = PositionwiseFeedForward(self.emb_dims, self.ff_dims, self.dropout)
        enc_layer = EncoderLayer(self.emb_dims, copy.deepcopy(attn), copy.deepcopy(ff), self.dropout)
        encoder = Encoder(enc_layer, self.N)
        dec_layer = DecoderLayer(self.emb_dims, copy.deepcopy(attn), copy.deepcopy(attn),
                                 copy.deepcopy(ff), self.dropout)
        decoder = Decoder(dec_layer, self.N)
        # Empty nn.Sequential() modules act as identity embeddings and generator.
        self.model = EncoderDecoder(encoder, decoder, nn.Sequential(), nn.Sequential(), nn.Sequential())

    def forward(self, *inputs):
        """Return ``(src_embedding, tgt_embedding)`` for ``inputs[0], inputs[1]``."""
        src, tgt = inputs[0], inputs[1]
        src_seq = src.transpose(2, 1).contiguous()
        tgt_seq = tgt.transpose(2, 1).contiguous()
        # tgt_embedding: src encoded as memory, tgt decoded against it.
        tgt_embedding = self.model(src_seq, tgt_seq, None, None).transpose(2, 1).contiguous()
        # src_embedding: the symmetric direction.
        src_embedding = self.model(tgt_seq, src_seq, None, None).transpose(2, 1).contiguous()
        return src_embedding, tgt_embedding