# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were web-scraping
# residue (HuggingFace Spaces status text), not part of the source file.
from models.image_encoder import ImageEncoder
from models.image_decoder import ImageDecoder
from models.modality_fusion import ModalityFusion
from models.vgg_perceptual_loss import VGGPerceptualLoss
from models.transformers import *  # also expected to supply torch/nn/F helpers used below (subsequent_mask, numericalize, ...) — TODO confirm
from torch.autograd import Variable

# All tensors created in this module are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ModelMain(nn.Module):
    """Main dual-modality glyph-generation model.

    Encodes reference glyph images (CNN) and reference SVG command
    sequences (transformer), fuses the two modalities into a latent code,
    then decodes a target glyph image and — autoregressively at test
    time — a target SVG command sequence.
    """

    def __init__(self, opts, mode='train'):
        # `mode` is accepted for interface compatibility but is unused here;
        # the fusion module reads opts.mode instead.
        super().__init__()
        self.opts = opts
        # CNN encoder over the ref_nshot stacked reference glyph images.
        self.img_encoder = ImageEncoder(img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm)
        # Decodes latent code concatenated with the one-hot target char class
        # back into a single-channel glyph image.
        self.img_decoder = ImageDecoder(img_size=opts.img_size, input_nc=opts.bottleneck_bits + opts.char_num, output_nc=1, ngf=opts.ngf, norm_layer=nn.LayerNorm)
        # VGG-based perceptual loss between generated and target images.
        self.vggptlossfunc = VGGPerceptualLoss()
        # Fuses image features with per-step sequence features into the latent.
        self.modality_fusion = ModalityFusion(img_size=opts.img_size, ref_nshot=opts.ref_nshot, bottleneck_bits=opts.bottleneck_bits, ngf=opts.ngf, mode=opts.mode)
        # Main transformer encoder over the reference SVG sequences.
        self.transformer_main = Transformer(
            input_channels = 1,
            input_axis = 2,              # number of axes for input data (2 for images, 3 for video)
            num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
            max_freq = 10.,              # maximum frequency; hyperparameter depending on how fine the data is
            depth = 6,                   # depth of net; final attention shape:
                                         # depth * (cross attention -> self_per_cross_attn * self attention)
            num_latents = 256,           # number of latents / induced set points / centroids
            latent_dim = opts.dim_seq_latent,  # latent dimension
            cross_heads = 1,             # number of heads for cross attention (paper uses 1)
            latent_heads = 8,            # number of heads for latent self attention
            cross_dim_head = 64,         # dimensions per cross attention head
            latent_dim_head = 64,        # dimensions per latent self attention head
            num_classes = 1000,          # output number of classes
            attn_dropout = 0.,
            ff_dropout = 0.,
            weight_tie_layers = False,   # whether to weight-tie layers
            fourier_encode_data = True,  # auto Fourier-encode the data using input_axis;
                                         # turn off if the data is Fourier-encoded upstream
            self_per_cross_attn = 2      # number of self attention blocks per cross attention
        )
        # Autoregressive decoder producing SVG commands and arguments.
        self.transformer_seqdec = Transformer_decoder()
    def forward(self, data, mode='train'):
        """Run one pass over a batch.

        In 'train'/'val' mode, teacher-forces the sequence decoder and
        returns image/KL/SVG losses in ``loss_dict``. Otherwise (test mode),
        autoregressively samples an SVG sequence (plus a parallel-refined
        variant) and returns it in ``ret_dict['svg']`` with ``loss_dict``
        left empty.

        Returns:
            (ret_dict, loss_dict) — nested dicts of outputs and losses.
        """
        imgs, seqs, scalars = self.fetch_data(data, mode)
        ref_img, trg_img = imgs
        ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
        trg_char_onehot, trg_cls, trg_seqlen = scalars
        # image encoding
        img_encoder_out = self.img_encoder(ref_img)
        img_feat = img_encoder_out['img_feat']  # bs, ngf * (2 ** 6)
        # seq encoding: flatten (bs, n_ref) into one batch axis and add a
        # trailing channel axis before the transformer.
        # NOTE(review): the original shape comment said [max_seq_len, n_bs * n_ref, 9],
        # which does not match this view — confirm the intended layout.
        ref_img_ = ref_img.view(ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)).unsqueeze(-1)
        seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask)  # [n_bs * n_ref, max_seq_len + 1, 9]
        # modality fusion
        mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
        latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq)  # [n_bs, max_seq_len + 1, bottleneck_bits]
        z = mf_output['latent']
        kl_loss = mf_output['kl_loss']
        # image decoding
        img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)
        ret_dict = {}
        loss_dict = {}
        ret_dict['img'] = {}
        ret_dict['img']['out'] = img_decoder_out['gen_imgs']
        ret_dict['img']['ref'] = ref_img
        ret_dict['img']['trg'] = trg_img
        if mode in {'train', 'val'}:
            # seq decoding (training or val mode): teacher-forced with a
            # causal (subsequent) mask over the shifted target sequence.
            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).to(device).float()
            command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
            # Second (parallel) refinement pass; memory is detached so its
            # gradients do not flow back through the fusion path.
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
            total_loss = self.transformer_main.loss(command_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux)
            total_loss_parallel = self.transformer_main.loss(command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux)
            vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)
            # loss and output
            loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
            # for image
            loss_dict['img'] = {}
            loss_dict['img']['l1'] = img_decoder_out['img_l1loss']
            loss_dict['img']['vggpt'] = vggpt_loss['pt_c_loss']
            # for latent
            loss_dict['kl'] = kl_loss
            # for svg
            loss_dict['svg'] = {}
            loss_dict['svg_para'] = {}
            for item in loss_svg_items:
                loss_dict['svg'][item] = total_loss[f'loss_{item}']
                loss_dict['svg_para'][item] = total_loss_parallel[f'loss_{item}']
        else:  # testing (inference): greedy autoregressive sampling
            trg_len = trg_seq_shifted.size(0)
            # Sequence grows along dim 0; starts with a single all-zero step
            # (start token) which is stripped after the loop.
            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).to(device)
            for t in range(0, trg_len):
                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).to(device).float()
                command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                # Greedy decode of the last position only.
                prob_comand = F.softmax(command_logits[:, -1, :], -1)
                prob_args = F.softmax(args_logits[:, -1, :], -1)
                next_command = torch.argmax(prob_comand, -1).unsqueeze(-1)
                next_args = torch.argmax(prob_args, -1)
                predict_tmp = torch.cat((next_command, next_args), -1).unsqueeze(1).transpose(0, 1)
                sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)
            sampled_svg = sampled_svg[1:]  # drop the zero start token
            # Parallel refinement of the greedy sample.
            cmd2 = sampled_svg[:, :, 0].unsqueeze(-1)
            arg2 = sampled_svg[:, :, 1:]
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls)
            prob_comand = F.softmax(command_logits_2, -1)
            prob_args = F.softmax(args_logits_2, -1)
            update_command = torch.argmax(prob_comand, -1).unsqueeze(-1)
            update_args = torch.argmax(prob_args, -1)
            sampled_svg_parralel = torch.cat((update_command, update_args), -1).transpose(0, 1)
            # Convert both samples to output format: one-hot commands (4 classes)
            # + de-quantized args; args[..., :2] are dropped — presumably
            # reserved/unused coordinates, TODO confirm against denumericalize.
            commands1 = F.one_hot(sampled_svg[:, :, :1].long(), 4).squeeze().transpose(0, 1)
            args1 = denumericalize(sampled_svg[:, :, 1:]).transpose(0, 1)
            sampled_svg_1 = torch.cat([commands1.cpu().detach(), args1[:, :, 2:].cpu().detach()], dim=-1)
            commands2 = F.one_hot(sampled_svg_parralel[:, :, :1].long(), 4).squeeze().transpose(0, 1)
            args2 = denumericalize(sampled_svg_parralel[:, :, 1:]).transpose(0, 1)
            sampled_svg_2 = torch.cat([commands2.cpu().detach(), args2[:, :, 2:].cpu().detach()], dim=-1)
            ret_dict['svg'] = {}
            ret_dict['svg']['sampled_1'] = sampled_svg_1
            ret_dict['svg']['sampled_2'] = sampled_svg_2
            ret_dict['svg']['trg'] = trg_seq_gt
        return ret_dict, loss_dict
| def fetch_data(self, data, mode): | |
| input_image = data['rendered'] # [bs, opts.char_num, opts.img_size, opts.img_size] | |
| input_sequence = data['sequence'] # [bs, opts.char_num, opts.max_seq_len] | |
| input_seqlen = data['seq_len'] | |
| input_seqlen = input_seqlen + 1 | |
| input_pts_aux = data['pts_aux'] | |
| arg_quant = numericalize(input_sequence[:, :, :, 4:]) | |
| cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1) | |
| input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1) # 1 + 8 = 9 dimension | |
| # choose reference classes and target classes | |
| if mode == 'train': | |
| ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).to(device) | |
| if opts.ref_nshot == 52: # For ENG to TH | |
| ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).to(device) | |
| ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).to(device) | |
| ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1) | |
| elif mode == 'val': | |
| ref_cls = torch.arange(0, self.opts.ref_nshot, 1).to(device).unsqueeze(0).expand(input_image.size(0), -1) | |
| else: | |
| ref_ids = self.opts.ref_char_ids.split(',') | |
| ref_ids = list(map(int, ref_ids)) | |
| assert len(ref_ids) == self.opts.ref_nshot | |
| ref_cls = torch.tensor(ref_ids).to(device).unsqueeze(0).expand(self.opts.char_num, -1) | |
| if mode in {'train', 'val'}: | |
| trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).to(device) | |
| if opts.ref_nshot == 52: | |
| trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device) | |
| else: | |
| trg_cls = torch.arange(0, self.opts.char_num).to(device) | |
| if opts.ref_nshot == 52: | |
| trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device) | |
| trg_cls = trg_cls.view(self.opts.char_num, 1) | |
| input_image = input_image.expand(self.opts.char_num, -1, -1, -1) | |
| input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1) | |
| input_pts_aux = input_pts_aux.expand(self.opts.char_num, -1, -1, -1) | |
| input_seqlen = input_seqlen.expand(self.opts.char_num, -1, -1) | |
| ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts) | |
| # select a target glyph image | |
| trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts) | |
| # randomly select ref vector glyphs | |
| ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short) # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr] | |
| # randomly select a target vector glyph | |
| trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short) | |
| trg_seq = trg_seq.squeeze(1) | |
| trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, opts.n_aux_pts) | |
| trg_pts_aux = trg_pts_aux.squeeze(1) | |
| # the one-hot target char class | |
| trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts) | |
| # shift target sequence | |
| trg_seq_gt = trg_seq.clone().detach() | |
| trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1) | |
| trg_seq = trg_seq.transpose(0, 1) | |
| trg_seq_shifted = util_funcs.shift_right(trg_seq) | |
| ref_seq_cat = ref_seq.view(ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3)) | |
| ref_seq_cat = ref_seq_cat.transpose(0,1) | |
| ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts) | |
| ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2)) | |
| ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked | |
| for i in range(ref_seqlen_cat.size(0)): | |
| ref_pad_mask[i,:ref_seqlen_cat[i]] = 1 | |
| ref_pad_mask = ref_pad_mask.to(device).float().unsqueeze(1) | |
| trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts) | |
| trg_seqlen = trg_seqlen.squeeze() | |
| return [ref_img, trg_img], [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux], [trg_char_onehot, trg_cls, trg_seqlen] |