Exclibur's picture
Upload folder using huggingface_hub
f1c1609 verified
import torch
import torch.nn as nn
class PuppetCaptionModel(nn.Module):
    """Placeholder captioning model whose outputs are all-zero tensors.

    ``forward`` and ``sample`` never look at the visual inputs; they only
    read tensor shapes and devices and return zeros of the matching size.
    ``build_loss`` is a real masked loss computation.
    """

    def __init__(self, opt):
        """Store the options object and cache ``opt.vocab_size``.

        Args:
            opt: Options object; only its ``vocab_size`` attribute is read.
        """
        super().__init__()
        self.vocab_size = opt.vocab_size
        self.opt = opt
        # Not used by any method of this class.
        # NOTE(review): presumably present so the module owns at least one
        # trainable parameter -- confirm against the training code.
        self.puppet_layer = nn.Linear(1, 1)

    def forward(self, event, clip, clip_mask, seq) -> torch.Tensor:
        """Return an all-zero score tensor sized from ``seq``.

        Args:
            event: Unused.
            clip: Unused.
            clip_mask: Unused.
            seq: 2-D tensor of shape ``(N, L)``; only its shape and device
                are read.

        Returns:
            Zeros of shape ``(N, L - 1, vocab_size + 1)`` on ``seq.device``.
        """
        batch_size, seq_len = seq.shape
        return torch.zeros(
            (batch_size, seq_len - 1, self.vocab_size + 1),
            device=seq.device,
        )

    def sample(self, event, clip, clip_mask, opt=None):
        """Return all-zero sampled outputs sized from ``clip``.

        Args:
            event: Unused.
            clip: 3-D tensor; only its first dimension and device are read.
            clip_mask: Unused.
            opt: Unused by this implementation. Defaults to ``None`` rather
                than a shared mutable ``{}``.

        Returns:
            A pair ``(output, prob)`` of zero tensors, each of shape
            ``(N, 3)`` on ``clip.device``.
        """
        # The three-way unpack also rejects a ``clip`` that is not 3-D.
        batch_size, _, _ = clip.shape
        output = torch.zeros((batch_size, 3), device=clip.device)
        prob = torch.zeros((batch_size, 3), device=clip.device)
        return output, prob

    def build_loss(self, input, target, mask) -> torch.Tensor:
        """Compute a masked, length-normalised loss per batch row.

        The parameter name ``input`` shadows the builtin but is kept so that
        existing keyword calls continue to work.

        Args:
            input: Per-class scores whose last dimension has
                ``vocab_size + 1`` entries.
                NOTE(review): presumably log-probabilities of shape
                ``(N, T, vocab_size + 1)`` -- confirm against the caller.
            target: Integer class indices, one-hot encoded over
                ``vocab_size + 1`` classes.
            mask: Per-position weights with the same shape as ``target``.

        Returns:
            For each row, the negated sum of ``input`` at the target indices
            weighted by ``mask``, divided by ``mask.sum(1) + 1e-6``.
        """
        one_hot = torch.nn.functional.one_hot(target, self.vocab_size + 1)
        # Keep only the score at each target index, zeroing masked positions.
        selected = (one_hot * input * mask[..., None]).sum(2)
        # The epsilon avoids division by zero for rows that are fully masked.
        return -selected.sum(1) / (mask.sum(1) + 1e-6)