Spaces:
Sleeping
Sleeping
| from math import pi, log | |
| from functools import wraps | |
| from multiprocessing import context | |
| from textwrap import indent | |
| import models.util_funcs as util_funcs | |
| import math, copy | |
| import numpy as np | |
| import torch | |
| from torch import nn, einsum | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from einops.layers.torch import Reduce | |
| import pdb | |
| from einops.layers.torch import Rearrange | |
| from options import get_parser_main_model | |
# Parse the command-line options once at import time. The classes and
# functions below read hyper-parameters from this module-level object, e.g.
# opts.hidden_size, opts.max_seq_len, opts.ref_nshot, opts.mode and the
# opts.loss_w_* loss weights.
# NOTE(review): parsing sys.argv on import may fail when this module is
# imported from another entry point with different CLI flags — confirm.
opts = get_parser_main_model().parse_args()
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The encoding table is stored in the buffer ``pe`` with shape
    ``[max_len, 1, d_model]``, so it broadcasts over the batch axis of a
    sequence-first input.

    Args:
        d_model: Embedding size. Even columns hold ``sin`` terms and odd
            columns hold ``cos`` terms; odd ``d_model`` is supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Maximum sequence length covered by the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 angle columns feed the cos terms.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        :param x: [x_len, batch_size, emb_size]; x_len must not exceed max_len
        :return: [x_len, batch_size, emb_size]
        """
        x = x + self.pe[:x.size(0), :].to(x.device)
        return self.dropout(x)
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as present)."""
    is_missing = val is None
    return not is_missing
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``."""
    if not exists(val):
        return d
    return val
def cache_fn(f):
    """Wrap the factory ``f`` so its results can be memoised per ``key``.

    The returned function accepts two extra keyword-only arguments:

    * ``_cache`` (default ``True``): when false, ``f`` is always called and
      nothing is looked up or stored.
    * ``key`` (default ``None``): the slot the result is stored under. Calls
      with the same ``key`` return the *same* object, which is how the
      layer-building code shares (weight-ties) modules.

    Only ``key`` selects the cache entry; ``*args``/``**kwargs`` are ignored
    when deciding whether a cached result exists.
    """
    cache = {}

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]

    return cached_fn
def fourier_encode(x, max_freq, num_bands = 4):
    """Fourier-feature encode ``x`` along a new trailing axis.

    Each value is multiplied by ``num_bands`` frequencies spaced linearly in
    ``[1, max_freq / 2]`` and by pi. The sines and cosines of those products
    are concatenated with the original value.

    Args:
        x: Tensor of any shape ``[*dims]``.
        max_freq: Highest frequency; the largest scale used is ``max_freq / 2``.
        num_bands: Number of frequencies.

    Returns:
        Tensor of shape ``[*dims, 2 * num_bands + 1]``, laid out as
        ``[sin..., cos..., x]``.
    """
    x = x.unsqueeze(-1)
    freqs = torch.linspace(1., max_freq / 2, num_bands, device=x.device, dtype=x.dtype)
    # Give the frequency vector leading singleton axes so it broadcasts
    # against every axis of x except the new trailing one.
    freqs = freqs.view(*([1] * (x.dim() - 1)), num_bands)
    phases = x * freqs * pi
    return torch.cat([phases.sin(), phases.cos(), x], dim=-1)
class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally to ``context``) before calling ``fn``.

    Args:
        dim: Feature size of ``x``.
        fn: Module/callable invoked as ``fn(normed_x, **kwargs)``.
        context_dim: If given, ``kwargs['context']`` is LayerNorm-ed over this
            feature size before being forwarded, and it must then be supplied.
    """

    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        normed_x = self.norm(x)
        if exists(self.norm_context):
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed_x, **kwargs)
class GEGLU(nn.Module):
    """Gated GELU: split the last axis in half and scale the first half by GELU of the second.

    The output's last axis is half the size of the input's.
    """

    def forward(self, x):
        halves = torch.chunk(x, 2, dim=-1)
        gate = F.gelu(halves[1])
        return halves[0] * gate
class FeedForward(nn.Module):
    """GEGLU feed-forward block: Linear -> GEGLU -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        mult: Hidden-size multiplier. The first Linear emits
            ``2 * dim * mult`` features, which GEGLU halves to ``dim * mult``.
        dropout: Dropout probability on the output.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        query_dim: Feature size of the query input ``x`` and of the output.
        context_dim: Feature size of ``context``; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability applied to the attention weights.
        cls_conv_dim: Unused; kept for backward-compatible construction.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.,cls_conv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        # Keys and values come from one projection and are split in forward().
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None, ref_cls_onehot=None):
        """Attend from ``x`` to ``context`` (``x`` itself when ``context`` is None).

        Args:
            x: Queries, ``[b, n, query_dim]``.
            context: Keys/values source, ``[b, m, context_dim]``.
            mask: Optional mask; entries equal to 0 set the matching scores to
                -1e9 before the softmax. NOTE(review): it is rearranged with
                ``'b j k -> (b h) k j'``, so a ``[b, 1, n]`` mask becomes
                ``[(b h), n, 1]`` and masks whole query rows rather than key
                columns — confirm this is the intended axis.
            ref_cls_onehot: Unused.

        Returns:
            ``(out, attn)``: ``out`` is ``[b, n, query_dim]``; ``attn`` holds the
            post-dropout attention weights, ``[b * heads, n, m]``.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            mask = repeat(mask, 'b j k -> (b h) k j', h = h)
            # masked_fill is out-of-place: the result must be assigned,
            # otherwise the mask has no effect at all.
            sim = sim.masked_fill(mask == 0, -1e9)
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out), attn
class SVGEmbedding(nn.Module):
    """Embed SVG commands and their quantised arguments, then add positional encoding.

    ``forward`` expects ``commands`` of shape ``[S, N, 1]`` holding command ids
    in ``[0, 4)`` and ``args`` of shape ``[S, N, 8]`` holding argument tokens
    in ``[0, 128)``. It returns ``[S, N, 512]``.
    """

    def __init__(self):
        super().__init__()
        self.command_embed = nn.Embedding(4, 512)
        self.arg_embed = nn.Embedding(128, 128,padding_idx=0)
        # The 8 argument embeddings of a step are flattened and projected to 512.
        self.embed_fcn = nn.Linear(128 * 8, 512)
        # NOTE(review): the embeddings above are hard-coded to width 512, so
        # opts.hidden_size must also be 512 for the addition to broadcast.
        self.pos_encoding = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len + 1)
        self._init_embeddings()

    def _init_embeddings(self):
        # NOTE(review): this also overwrites the zero row nn.Embedding sets for
        # padding_idx=0, so padding tokens get a non-zero (frozen) embedding.
        nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")

    def forward(self, commands, args, groups=None):
        """Return ``command_embed + embed_fcn(flattened arg embeddings)`` plus positional encoding.

        ``groups`` is unused.
        """
        S, GN, _ = commands.shape
        # Drop only the singleton command axis: [S, N, 1, 512] -> [S, N, 512].
        # A bare squeeze() would also drop S or N when either is 1, which
        # then silently broadcasts to [S, S, 512] in the addition below.
        cmd_emb = self.command_embed(commands.long()).squeeze(-2)
        arg_emb = self.embed_fcn(self.arg_embed(args.long()).view(S, GN, -1))
        src = cmd_emb + arg_emb
        return self.pos_encoding(src)
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network computing ``w_2(relu(dropout(w_1(x))))``."""

    def __init__(self, d_model, d_ff, dropout):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.w_1(x)
        hidden = self.dropout(hidden)
        hidden = F.relu(hidden)
        return self.w_2(hidden)
class Transformer_decoder(nn.Module):
    """SVG decoder with a 6-layer main stack and a 1-layer "parallel" stack.

    Both paths embed (command, args) sequences with ``SVGEmbedding`` and inject
    a learned embedding of the target character class at the first position.
    They then run ``DecoderLayer`` stacks that cross-attend to ``memory``, and
    project the 512-d output to 4 command logits and 8 x 128 argument logits
    per step.
    """

    def __init__(self):
        super().__init__()
        self.SVG_embedding = SVGEmbedding()
        # Output heads: 4 command logits and 8 * 128 argument logits per step
        # (the latter reshaped to [N, S, 8, 128] in the forward passes).
        self.command_fcn = nn.Linear(512, 4)
        self.args_fcn = nn.Linear(512, 8 * 128)
        c = copy.deepcopy
        attn = MultiHeadedAttention(h=8, d_model=512, dropout=0.0)
        ff = PositionwiseFeedForward(d_model=512, d_ff=1024, dropout=0.0)
        # Main decoder: 6 layers, each built from deep copies of the templates
        # above (self-attention, source attention, feed-forward).
        self.decoder_layers = clones(DecoderLayer(512, c(attn), c(attn),c(ff), dropout=0.0), 6)
        self.decoder_norm = nn.LayerNorm(512)
        # Single-layer decoder used only by parallel_decoder().
        self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
        self.decoder_norm_parallel = nn.LayerNorm(512)
        # Target-character class embedding: 96 classes when opts.ref_nshot is
        # 52, otherwise 52. NOTE(review): the meaning of these counts is not
        # visible in this file.
        if opts.ref_nshot == 52:
            self.cls_embedding = nn.Embedding(96,512)
        else:
            self.cls_embedding = nn.Embedding(52,512)
        # NOTE(review): cls_token is not read by any method of this class.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
        """Teacher-forced decoding pass through the 6-layer decoder.

        Args:
            x: Target sequence; ``x[:, :, :1]`` are command ids and
                ``x[:, :, 1:]`` the 8 argument tokens. It is given to
                SVGEmbedding as-is and the result is transposed to batch-first,
                so presumably ``x`` is sequence-first ``[S, B, 9]`` — TODO confirm.
            memory: Encoder output; a singleton axis is inserted at dim 1
                before it is used as cross-attention keys/values.
            trg_char: Target character class ids (cast to long and embedded).
            src_mask: Mask forwarded to each layer's source attention.
            tgt_mask: Self-attention mask; ``squeeze()`` is called on it, so it
                must not be None. NOTE(review): squeeze() also drops the batch
                axis when B == 1.

        Returns:
            ``(cmd_logits, args_logits, attn)``: ``[N, S, 4]``,
            ``[N, S, 8, 128]``, and the self-attention weights recorded by the
            last decoder layer.
        """
        memory = memory.unsqueeze(1)
        commands = x[:, :, :1]
        args = x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0,1)  # -> batch-first
        trg_char = trg_char.long()
        trg_char = self.cls_embedding(trg_char)
        # Overwrite the first step with the target-character embedding.
        x[:, 0:1, :] = trg_char
        tgt_mask = tgt_mask.squeeze()
        for layer in self.decoder_layers:
            x,attn = layer(x, memory, src_mask, tgt_mask)
        out = self.decoder_norm(x)
        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out) # [N, S, 8 * 128] before the reshape below
        args_logits = args_logits.reshape(N, S, 8, 128)
        return cmd_logits,args_logits,attn

    def parallel_decoder(self, cmd_logits, args_logits, memory, trg_char):
        """Re-decode a predicted sequence in one non-autoregressive pass.

        When ``opts.mode == 'train'``, the inputs are treated as logits and
        argmax-ed into command ids / argument tokens (made sequence-first).
        Otherwise ``cmd_logits`` and ``args_logits`` are used as-is as ids and
        tokens. In both cases:

        * positions are masked with ``_get_key_padding_mask`` (built from the
          first command id 0);
        * argument slots a command does not use are zeroed via ``cmd_args_mask``;
        * the target-character embedding is prepended and the sequence is
          truncated to ``opts.max_seq_len``;
        * the result goes through the single-layer parallel decoder.

        Returns:
            ``(cmd_logits, args_logits)`` of shapes ``[N, S, 4]`` and
            ``[N, S, 8, 128]``.
        """
        memory = memory.unsqueeze(1)
        # Row i marks which of the 8 argument slots command id i uses:
        # id 0 uses none, ids 1 and 2 use slots 0, 1, 6, 7, and id 3 uses all 8.
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
        if opts.mode == 'train':
            # Logits -> hard ids, transposed to sequence-first for SVGEmbedding.
            cmd2 = torch.argmax(cmd_logits, -1).unsqueeze(-1).transpose(0, 1)
            arg2 = torch.argmax(args_logits, -1).transpose(0, 1)
            cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0,1).unsqueeze(-1).to(cmd2.device)
            cmd2 = cmd2 * cmd2paddingmask
            args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1,-2).squeeze(-1)
            arg2 = arg2 * args_mask
            x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)
        else:
            # NOTE(review): inputs are used directly, so outside training the
            # caller presumably passes decoded ids (not logits) in
            # sequence-first layout — confirm.
            cmd2 = cmd_logits
            arg2 = args_logits
            cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
            cmd2 = cmd2 * cmd2paddingmask
            args_mask = torch.matmul(F.one_hot(cmd2.long(),4).float(), cmd_args_mask).transpose(-1, -2).squeeze(-1)
            arg2 = arg2 * args_mask
            x = self.SVG_embedding(cmd2, arg2).transpose(0,1)
        S = x.size(1)
        B = x.size(0)
        # Full (non-causal) self-attention mask, restricted by the padding mask;
        # the causal triangle is intentionally not applied (see "#*tri" below).
        tgt_mask = torch.ones(S,S).to(x.device).unsqueeze(0).repeat(B, 1, 1)
        cmd2paddingmask = cmd2paddingmask.transpose(0, 1).transpose(-1, -2)
        tgt_mask = tgt_mask * cmd2paddingmask
        trg_char = trg_char.long()
        trg_char = self.cls_embedding(trg_char)
        # Prepend the target-character embedding, then truncate back to
        # opts.max_seq_len steps.
        x = torch.cat([trg_char, x],1)
        x[:, 0:1, :] = trg_char  # NOTE(review): appears redundant after the concatenation above
        x = x[:,:opts.max_seq_len,:]
        tgt_mask = tgt_mask #*tri
        for layer in self.decoder_layers_parallel:
            x, attn = layer(x, memory, src_mask=None, tgt_mask=tgt_mask)
        out = self.decoder_norm_parallel(x)
        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out)
        args_logits = args_logits.reshape(N, S, 8, 128)
        return cmd_logits, args_logits
def _get_key_padding_mask(commands, seq_dim=0):
    """Build a per-sequence length mask from command ids.

    For each sequence, the length is (index of the first command id 0) + 1.
    If no 0 is present, the length is ``opts.max_seq_len + 1``. The lengths are
    turned into a mask by ``util_funcs.sequence_mask(lens, opts.max_seq_len)``.

    Args:
        commands: Command ids, expected sequence-first as ``[S, B, 1]``; they
            are transposed to ``[B, S]`` internally.
        seq_dim: Unused; kept for backward compatibility.

    Returns:
        The result of ``util_funcs.sequence_mask`` — presumably a
        ``[B, opts.max_seq_len]`` mask (TODO confirm against util_funcs).
        The lengths tensor passed to it is a CPU int64 tensor.
    """
    with torch.no_grad():
        per_sequence = commands.transpose(0, 1).squeeze(-1)  # [B, S]
        lens = []
        for seq in per_sequence:
            # Explicit emptiness check instead of catching the IndexError
            # (the original used a bare except for this).
            zero_positions = torch.nonzero(seq == 0, as_tuple=True)[0]
            if zero_positions.numel() > 0:
                lens.append(int(zero_positions[0]))
            else:
                lens.append(opts.max_seq_len)
        lens = torch.tensor(lens) + 1
        return util_funcs.sequence_mask(lens, opts.max_seq_len)
class Transformer(nn.Module):
    """SVG-sequence encoder built from Perceiver-style attention blocks, plus the SVG loss.

    ``forward`` embeds an SVG command/argument sequence, prepends a learned CLS
    token and runs only the self-attention / feed-forward pairs of
    ``self.layers``; the cross-attention modules are built but not called.
    ``att_residual`` runs the self-attention blocks of ``self.layers_cnnsvg``.
    ``loss`` computes the weighted command, argument, smoothness and
    auxiliary-point losses.
    """

    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels = 1,
        input_axis = 2,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,
        fourier_encode_data = True,
        self_per_cross_attn = 2,
        final_classifier_head = True
    ):
        """Build the attention stacks and auxiliary modules.

        Each of the ``depth`` entries of ``self.layers`` holds
        (cross attention, cross feed-forward, ``self_per_cross_attn`` x
        (self attention, self feed-forward)). ``forward`` only runs the
        self-attention part.

        Args:
            num_freq_bands: Number of Fourier frequency bands; only used to
                size ``input_dim`` (the cross-attention context width).
            depth: Number of layers in ``self.layers``.
            max_freq: Stored as ``self.max_freq``; not otherwise used here.
            input_channels: Channels per input token; added to the Fourier
                channels to form ``input_dim``.
            input_axis: Number of spatial axes ``forward`` requires in ``data``
                (2 for images, 3 for video).
            num_latents: Number of rows of ``self.latents``.
            latent_dim: Width of the attention / feed-forward blocks.
            cross_heads: Number of heads for cross attention.
            latent_heads: Number of heads for latent self attention.
            cross_dim_head: Dimensions per cross-attention head.
            latent_dim_head: Dimensions per latent self-attention head.
            num_classes: Output size of ``self.to_logits`` when
                ``final_classifier_head`` is true.
            attn_dropout: Attention dropout.
            ff_dropout: Feed-forward dropout.
            weight_tie_layers: If true, layers 1..depth-1 share one cached set
                of modules via ``cache_fn``; layer 0 keeps its own.
            fourier_encode_data: If true, reserve Fourier channels in ``input_dim``.
            self_per_cross_attn: Number of self-attention blocks per layer.
            final_classifier_head: If true, ``self.to_logits`` mean-pools and
                projects to ``num_classes``; otherwise it is ``nn.Identity``.

        NOTE(review): ``self.latents``, ``self.to_logits``, ``self.pre_lstm_fc``,
        ``self.posr``, ``self.to_patch_embedding`` and the cross-attention
        modules are not used by any method of this class.
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.fourier_encode_data = fourier_encode_data
        # sin + cos per band plus the raw value, per input axis
        # (26 for input_axis=2, num_freq_bands=6).
        fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0
        input_dim = fourier_channels + input_channels
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        # Factories for each sub-block. After wrapping with cache_fn, calls with
        # _cache=True and the same key return the same (shared) module.
        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
        #self_per_cross_attn=1
        self.layers = nn.ModuleList([])
        for i in range(depth):
            # Caching is off for layer 0 (cache_fn neither reads nor stores),
            # so with weight_tie_layers layer 0 is fresh and layers 1.. share.
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}
            self_attns = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn): # BUG fix: this loop bound was previously hard-coded to 2; it is now self_per_cross_attn
                # key=block_ind: each block position gets its own cached module.
                self_attns.append(nn.ModuleList([
                    get_latent_attn(**cache_args, key = block_ind),
                    get_latent_ff(**cache_args, key = block_ind)
                ]))
            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                self_attns
            ]))
        # A second, independent set of factories for the single extra layer
        # used by att_residual().
        get_cross_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_cross_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_latent_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2 = map(cache_fn, (get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2))
        self.layers_cnnsvg = nn.ModuleList([])
        # range(1): i is always 0, so should_cache is always False and every
        # module here is freshly constructed.
        for i in range(1):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}
            self_attns2 = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn):
                self_attns2.append(nn.ModuleList([
                    get_latent_attn2(**cache_args, key = block_ind),
                    get_latent_ff2(**cache_args, key = block_ind)
                ]))
            self.layers_cnnsvg.append(nn.ModuleList([
                get_cross_attn2(**cache_args),
                get_cross_ff2(**cache_args),
                self_attns2
            ]))
        self.to_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        ) if final_classifier_head else nn.Identity()
        self.pre_lstm_fc = nn.Linear(10,opts.hidden_size)
        self.posr = PositionalEncoding(d_model=opts.hidden_size,max_len=opts.max_seq_len)
        # Splits a channel-last image into 2x2 patches and projects each
        # (1 * 2 * 2 = 4 values, so a single channel) to 16 features.
        patch_height = 2
        patch_width = 2
        patch_dim = 1 * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, 16),
        )
        self.SVG_embedding = SVGEmbedding()
        # Learned CLS token prepended to every sequence in forward().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, data, seq, ref_cls_onehot=None, mask=None, return_embeddings=True):
        """Encode an SVG sequence with the latent self-attention blocks.

        Args:
            data: Only its rank is checked. It must have ``input_axis + 2``
                dimensions (batch, ``input_axis`` spatial axes, channels).
            seq: SVG sequence. ``seq[:, :, :1]`` holds command ids and
                ``seq[:, :, 1:]`` the 8 argument tokens; it is passed to
                SVGEmbedding and the result transposed to batch-first, so
                presumably ``seq`` is ``[S, B, 9]`` — TODO confirm.
            ref_cls_onehot: Unused.
            mask: Mask concatenated after a leading 1 (for the CLS token) along
                the last axis and passed to every self-attention block. It must
                not be None and must be 3-D with size 1 on dim 1, presumably
                ``[B, 1, S]``.
            return_embeddings: Unused.

        Returns:
            ``(x, self_atten)``: ``x`` is ``[B, S + 1, 512]`` (CLS first, if
            ``seq`` is ``[S, B, 9]``) with standard-normal noise added, and
            ``self_atten`` lists the attention maps of every self-attention block.
        """
        b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
        assert len(axis) == self.input_axis, 'input data must have the right number of axis' # input_axis is 2 for images
        x = seq
        commands=x[:, :, :1]
        args=x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0,1)
        # Prepend the learned CLS token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = x.size(0))
        x = torch.cat([cls_tokens,x],dim = 1)
        # Extend the mask with an always-valid entry for the CLS position.
        cls_one_pad = torch.ones((1,1,1)).to(x.device).repeat(x.size(0),1,1)
        mask = torch.cat([cls_one_pad,mask],dim=-1)
        self_atten = []
        # Only the self-attention / feed-forward pairs are run; cross_attn and
        # cross_ff are unpacked but unused.
        for cross_attn, cross_ff, self_attns in self.layers:
            for self_attn, self_ff in self_attns:
                x_,atten = self_attn(x,mask=mask)
                x = x_ + x
                self_atten.append(atten)
                x = self_ff(x) + x
        # NOTE(review): noise is added regardless of self.training, so outputs
        # are non-deterministic at inference too; also confirm it is meant to
        # be applied once after all layers.
        x = x + torch.randn_like(x) # add a perturbation
        return x, self_atten

    def att_residual(self, x, mask=None):
        """Apply the self-attention / feed-forward blocks of ``self.layers_cnnsvg`` with residuals.

        ``mask`` is accepted but not used; attention is unmasked.
        """
        for cross_attn, cross_ff, self_attns in self.layers_cnnsvg:
            for self_attn, self_ff in self_attns:
                x_, atten = self_attn(x)
                x = x_ + x
                x = self_ff(x) + x
        return x

    def loss(self, cmd_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux):
        '''Weighted SVG reconstruction loss.

        Inputs:
            cmd_logits: [b, S, 4] command logits.
            args_logits: presumably [b, S, 8, 128]; the code below indexes
                dim 2 as the argument slot and dim 3 as the bin.
            trg_seq: target sequence; [..., :1] command ids and [..., 1:]
                argument bins. Dims 0/1 are transposed (so presumably it is
                sequence-first [S, b, 9]) and one-hot encoded, so it must be
                an int64 tensor — TODO confirm.
            trg_seqlen: per-sample target lengths, used for the masks and as
                per-sample normalisers.
            trg_pts_aux: target auxiliary points, compared with the 6 values
                (3 sampled (x, y) points) predicted per step.

        S is expected to equal opts.max_seq_len, because all masks come from
        util_funcs.sequence_mask(..., opts.max_seq_len).

        Returns:
            dict with 'loss_total', 'loss_cmd', 'loss_args', 'loss_smt', 'loss_aux'.
        '''
        # Row i marks which of the 8 argument slots command id i uses.
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)
        tgt_commands = trg_seq[:,:,:1].transpose(0,1)
        tgt_args = trg_seq[:,:,1:].transpose(0,1)
        seqlen_mask = util_funcs.sequence_mask(trg_seqlen, opts.max_seq_len).unsqueeze(-1)
        # NOTE(review): seqlen_mask2 and seqlen_mask3 are not used below.
        seqlen_mask2 = seqlen_mask.repeat(1,1,4)# NOTE b,501,4
        seqlen_mask4 = seqlen_mask.repeat(1,1,8)
        seqlen_mask3 = seqlen_mask.unsqueeze(-1).repeat(1,1,8,128)
        tgt_commands_onehot = F.one_hot(tgt_commands, 4)
        tgt_args_onehot = F.one_hot(tgt_args, 128)
        # NOTE(review): the bare squeeze() calls below also drop the batch axis when b == 1.
        args_mask = torch.matmul(tgt_commands_onehot.float(),cmd_args_mask).squeeze()
        # Command cross-entropy: masked to valid steps, divided by each sample's
        # length, summed over steps, averaged over the batch.
        loss_cmd = torch.sum(- tgt_commands_onehot.squeeze() * F.log_softmax(cmd_logits, -1), -1)
        loss_cmd = torch.mul(loss_cmd, seqlen_mask.squeeze())
        loss_cmd = torch.mean(torch.sum(loss_cmd/trg_seqlen.unsqueeze(-1),-1))
        # Argument cross-entropy: only slots used by the target command and only
        # valid steps count; averaged over the 8 slots, then normalised as above.
        loss_args = (torch.sum(-tgt_args_onehot*F.log_softmax(args_logits,-1),-1)*seqlen_mask4*args_mask)
        loss_args = torch.mean(loss_args,dim=-1,keepdim=False)
        loss_args = torch.mean(torch.sum(loss_args/trg_seqlen.unsqueeze(-1),-1))
        # Smoothness term: compare each step's start-point distribution (slots
        # 0-1) with the end-point distribution (slots 6-7) of the previous step.
        # For command id 1 (SE_mask row [0, 0]) the step's own end point is
        # used instead of the previous one.
        SE_mask = torch.Tensor([[1, 1],
                                [0, 0],
                                [1, 1],
                                [1, 1]]).to(cmd_logits.device)
        SE_args_mask = torch.matmul(tgt_commands_onehot.float(),SE_mask).squeeze().unsqueeze(-1)
        args_prob = F.softmax(args_logits, -1)
        args_end = args_prob[:,:,6:]
        # Shift end points one step later (zeros at step 0), keeping the length.
        args_end_shifted = torch.cat((torch.zeros(args_end.size(0),1,args_end.size(2),args_end.size(3)).to(args_end.device),args_end),1)
        args_end_shifted = args_end_shifted[:,:opts.max_seq_len,:,:]
        args_end_shifted = args_end_shifted*SE_args_mask + args_end*(1-SE_args_mask)
        args_start = args_prob[:,:,:2]
        seqlen_mask5 = util_funcs.sequence_mask(trg_seqlen-1, opts.max_seq_len).unsqueeze(-1)
        seqlen_mask5 = seqlen_mask5.repeat(1,1,2)
        smooth_constrained = torch.sum(torch.pow((args_end_shifted - args_start), 2), -1) * seqlen_mask5
        smooth_constrained = torch.mean(smooth_constrained, dim=-1, keepdim=False)
        # NOTE(review): divides by trg_seqlen - 1, which is zero for length-1 targets.
        smooth_constrained = torch.mean(torch.sum(smooth_constrained / (trg_seqlen - 1).unsqueeze(-1), -1))
        # Straight-through argmax: the forward value of p_argmax equals the
        # argmax bin index, while gradients flow through the temperature-0.1
        # softmax.
        args_prob2 = F.softmax(args_logits / 0.1, -1)
        c = torch.argmax(args_prob2,-1).unsqueeze(-1).float() - args_prob2.detach()
        p_argmax = args_prob2 + c
        p_argmax = torch.mean(p_argmax,-1)
        # Bin indices -> coordinates, then split into 4 (x, y) control points.
        control_pts = denumericalize(p_argmax)
        p0 = control_pts[:,:,:2]
        p1 = control_pts[:,:,2:4]
        p2 = control_pts[:,:,4:6]
        p3 = control_pts[:,:,6:8]
        line_mask = (tgt_commands==2).float() + (tgt_commands==1).float()
        curve_mask = (tgt_commands==3).float()
        # Sample points at t = 0.25, 0.5, 0.75: linear interpolation p0 -> p3 for
        # command ids 1 and 2, cubic Bezier for command id 3.
        t=0.25
        aux_pts_line = p0 + t*(p3-p0)
        for t in [0.5,0.75]:
            coord_t = p0 + t*(p3-p0)
            aux_pts_line = torch.cat((aux_pts_line,coord_t),-1)
        aux_pts_line = aux_pts_line*line_mask
        t=0.25
        aux_pts_curve = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
        for t in [0.5, 0.75]:
            coord_t = (1-t)*(1-t)*(1-t)*p0 + 3*t*(1-t)*(1-t)*p1 + 3*t*t*(1-t)*p2 + t*t*t*p3
            aux_pts_curve = torch.cat((aux_pts_curve,coord_t),-1)
        aux_pts_curve = aux_pts_curve * curve_mask
        aux_pts_predict = aux_pts_curve + aux_pts_line
        # NOTE(review): this mask uses trg_seqlen - 1 but the normaliser below
        # uses trg_seqlen — confirm which is intended.
        seqlen_mask_aux = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
        aux_pts_loss = torch.pow((aux_pts_predict - trg_pts_aux), 2) * seqlen_mask_aux
        loss_aux = torch.mean(aux_pts_loss, dim=-1, keepdim=False)
        loss_aux = torch.mean(torch.sum(loss_aux / trg_seqlen.unsqueeze(-1), -1))
        loss = opts.loss_w_cmd * loss_cmd + opts.loss_w_args * loss_args + opts.loss_w_aux * loss_aux + opts.loss_w_smt * smooth_constrained
        svg_losses = {}
        svg_losses['loss_total'] = loss
        svg_losses["loss_cmd"] = loss_cmd
        svg_losses["loss_args"] = loss_args
        svg_losses["loss_smt"] = smooth_constrained
        svg_losses["loss_aux"] = loss_aux
        return svg_losses
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, source (memory) attention, then feed-forward.

    Each of the three sub-layers is wrapped in its own ``SublayerConnection``.
    ``forward`` returns the layer output together with the attention weights
    that ``self_attn`` stored in its ``attn`` attribute.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        self_sub, src_sub, ff_sub = self.sublayer
        x = self_sub(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        x = src_sub(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        # Read after the source attention, exactly as before.
        weights = self.self_attn.attn
        return ff_sub(x, self.feed_forward), weights
def subsequent_mask(size):
    """Return a ``[1, size, size]`` boolean mask that is True on and below the diagonal.

    Position ``i`` may attend to positions ``0..i``; later positions are masked out.
    """
    strictly_upper = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    return strictly_upper == 0
def numericalize(cmd, n=128):
    """Quantise coordinates into integer bins ``0..n-1``.

    Values are scaled by ``n / 30``, rounded, clamped to ``[0, n - 1]`` and
    cast to int32. NOTE: shall only be called after normalization; the
    constant 30 presumably is the upper end of the coordinate range.
    """
    binned = (cmd / 30 * n).round()
    binned = binned.clip(min=0, max=n-1)
    return binned.int()
def denumericalize(cmd, n=128):
    """Map bin indices back to coordinates by scaling with ``30 / n`` (inverse of ``numericalize``'s scaling)."""
    coords = cmd / n * 30
    return coords
def attention(query, key, value, mask=None, trg_tri_mask=None,dropout=None, posr=None):
    """Compute scaled dot-product attention.

    Args:
        query: ``[..., L_q, d_k]``.
        key: ``[..., L_k, d_k]``.
        value: ``[..., L_k, d_v]``.
        mask: Optional mask broadcastable to the scores ``[..., L_q, L_k]``;
            where it equals 0 the score is set to -1e9 before the softmax.
        trg_tri_mask: Optional second mask, applied the same way.
        dropout: Optional module applied to the attention weights.
        posr: Optional additive score bias; a singleton axis is inserted at
            dim 1 (the head axis for ``[b, h, L, d]`` inputs) before adding.

    Returns:
        ``(output, p_attn)`` where ``output = p_attn @ value``.

    Raises:
        RuntimeError: if ``mask`` cannot be broadcast to the scores.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if posr is not None:
        scores = scores + posr.unsqueeze(1)
    if mask is not None:
        try:
            scores = scores.masked_fill(mask == 0, -1e9)
        except RuntimeError as err:
            # Fail loudly with the offending shapes instead of dropping into
            # pdb, which hangs non-interactive jobs and then continues with
            # unmasked scores.
            raise RuntimeError(
                f"attention mask of shape {tuple(mask.shape)} cannot be broadcast "
                f"to scores of shape {tuple(scores.shape)}"
            ) from err
    if trg_tri_mask is not None:
        scores = scores.masked_fill(trg_tri_mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
    """Multi-head attention with four ``d_model x d_model`` projections (Q, K, V, output).

    The attention weights of the most recent call are kept in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout):
        """Take the number of heads ``h`` and model size ``d_model`` (which ``h`` must divide)."""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # Per-head size; value heads are assumed to have the same size as key heads.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None,trg_tri_mask=None, posr=None):
        # The same mask is shared by all heads: add a head axis so it broadcasts.
        if mask is not None:
            mask = mask.unsqueeze(1)
        batch = query.size(0)

        def split_heads(projection, tensor):
            # [b, L, d_model] -> [b, h, L, d_k]
            return projection(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q_heads = split_heads(self.linears[0], query)
        k_heads = split_heads(self.linears[1], key)
        v_heads = split_heads(self.linears[2], value)
        context, self.attn = attention(q_heads, k_heads, v_heads, mask=mask, trg_tri_mask=trg_tri_mask,
                                       dropout=self.dropout, posr=posr)
        # [b, h, L, d_k] -> [b, L, h * d_k]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.linears[-1](merged)
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = []
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return nn.ModuleList(replicas)
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(LayerNorm(x)))``.

    The LayerNorm is applied to the sub-layer's input, not after the residual sum.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (which must preserve the feature size) with a residual connection."""
        residual = x
        branch = sublayer(self.norm(x))
        return residual + self.dropout(branch)
if __name__ == '__main__':
    # Smoke test: build the encoder and run one forward pass on random inputs.
    model = Transformer(
        input_channels = 1,          # number of channels for each token of the input
        input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
        num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
        max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
        depth = 6,                   # depth of net: depth * (cross attention -> self_per_cross_attn * self attention)
        num_latents = 256,           # number of latents, or induced set points, or centroids
        latent_dim = 512,            # latent dimension
        cross_heads = 1,             # number of heads for cross attention
        latent_heads = 8,            # number of heads for latent self attention
        cross_dim_head = 64,         # number of dimensions per cross attention head
        latent_dim_head = 64,        # number of dimensions per latent self attention head
        num_classes = 1000,          # output number of classes
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,   # whether to weight tie layers
        fourier_encode_data = True,  # whether to reserve Fourier-encoding channels for the input
        self_per_cross_attn = 2      # number of self attention blocks per cross attention
    )
    # forward() requires the SVG sequence and its mask in addition to `data`;
    # the old `model(img)` call raised TypeError for the missing `seq`.
    batch_size = 2  # > 1 keeps every tensor axis non-singleton in the demo
    seq_len = opts.max_seq_len
    img = torch.randn(batch_size, 224, 224, 3)  # only its rank is checked by forward()
    commands = torch.randint(0, 4, (seq_len, batch_size, 1))  # command ids in [0, 4)
    args = torch.randint(0, 128, (seq_len, batch_size, 8))  # argument tokens in [0, 128)
    seq = torch.cat([commands, args], dim=-1)  # [S, B, 9], sequence-first
    mask = torch.ones(batch_size, 1, seq_len)  # no padding
    with torch.no_grad():
        latents, self_attn_maps = model(img, seq, mask=mask)
    # Expected: [batch_size, seq_len + 1, 512] and depth * self_per_cross_attn maps.
    print(tuple(latents.shape), len(self_attn_maps))