| import torch | |
| import torch.nn as nn | |
class PuppetCaptionModel(nn.Module):
    """Stand-in captioning model whose predictions are all-zero tensors.

    ``forward`` and ``sample`` ignore the content of their inputs and only
    use them for batch size, sequence length and device. ``build_loss`` is a
    real masked loss over per-token scores.
    """

    def __init__(self, opt):
        """Store the options object and the vocabulary size read from it.

        Args:
            opt: Options object; only its ``vocab_size`` attribute is read
                by this class.
        """
        super().__init__()
        self.vocab_size = opt.vocab_size
        self.opt = opt
        # Never called by any method of this class; presumably present so the
        # module owns at least one parameter -- TODO confirm with callers.
        self.puppet_layer = nn.Linear(1, 1)

    def forward(self, event, clip, clip_mask, seq):
        """Return zeros shaped ``(N, L - 1, vocab_size + 1)``.

        Args:
            event: Unused.
            clip: Unused.
            clip_mask: Unused.
            seq: 2-D tensor of shape ``(N, L)``; only its shape and device
                are read.

        Returns:
            A zero tensor on ``seq.device``. For ``L == 0`` the middle
            dimension is 0 rather than a negative size.
        """
        batch_size, seq_len = seq.shape
        # Clamp so an empty sequence yields an empty output instead of an
        # invalid negative dimension.
        num_steps = max(seq_len - 1, 0)
        return torch.zeros(
            (batch_size, num_steps, self.vocab_size + 1), device=seq.device
        )

    def sample(self, event, clip, clip_mask, opt=None):
        """Return two zero tensors of shape ``(N, 3)``.

        Args:
            event: Unused.
            clip: 3-D tensor; only its first dimension ``N`` and its device
                are read.
            clip_mask: Unused.
            opt: Unused. Defaults to ``None`` (instead of a shared mutable
                dict) so no state can leak between calls.

        Returns:
            ``(output, prob)``, both zero tensors of shape ``(N, 3)`` on
            ``clip.device``.
        """
        # Three-way unpacking keeps the requirement that ``clip`` is 3-D.
        batch_size, _, _ = clip.shape
        output = torch.zeros((batch_size, 3), device=clip.device)
        prob = torch.zeros((batch_size, 3), device=clip.device)
        return output, prob

    def build_loss(self, input, target, mask):
        """Masked negative mean of the target-class scores, per sequence.

        For each sequence the score at the target index is picked at every
        position, weighted by ``mask``, summed over positions, negated and
        divided by ``mask.sum(1) + 1e-6``.

        Args:
            input: Scores of shape ``(N, L, V)``; assumed to be
                log-probabilities with ``V == vocab_size + 1`` -- TODO
                confirm against the caller.
            target: Integer class indices of shape ``(N, L)``.
            mask: Per-position weights of shape ``(N, L)``; zero entries
                exclude a position.

        Returns:
            Tensor of shape ``(N,)`` with one loss value per sequence.
        """
        # Pick the target score directly. Multiplying by a one-hot tensor
        # would turn every -inf score of a non-target class into NaN
        # (0 * -inf) and allocate a full (N, L, V) tensor.
        picked = input.gather(2, target.unsqueeze(-1)).squeeze(-1)
        # Zero masked-out positions before weighting so a -inf score at a
        # padded position cannot produce NaN either.
        picked = picked.masked_fill(mask == 0, 0.0)
        return -(picked * mask).sum(1) / (mask.sum(1) + 1e-6)