| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from encoding import get_encoder |
|
|
| |
class AudioAttNet(nn.Module):
    """Temporal attention over a short window of per-frame audio features.

    Given a window of `seq_len` feature vectors of width `dim_aud`, predicts
    one softmax weight per frame and returns the attention-weighted sum,
    i.e. a temporally smoothed audio feature.
    """

    def __init__(self, dim_aud=64, seq_len=8):
        super(AudioAttNet, self).__init__()
        self.seq_len = seq_len
        self.dim_aud = dim_aud
        # Stack of stride-1 convs that collapse the feature channels
        # dim_aud -> 16 -> 8 -> 4 -> 2 -> 1, keeping the temporal length
        # (kernel_size=3 with padding=1), yielding one score per frame.
        conv_layers = []
        channels = [self.dim_aud, 16, 8, 4, 2, 1]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            conv_layers.append(nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True))
            conv_layers.append(nn.LeakyReLU(0.02, True))
        self.attentionConvNet = nn.Sequential(*conv_layers)
        # Linear + softmax turns the raw per-frame scores into weights.
        self.attentionNet = nn.Sequential(
            nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """x: [1, seq_len, dim_aud] -> [1, dim_aud] attention-weighted sum.

        NOTE(review): the `view(1, ...)` reshapes assume batch size 1 —
        confirm callers never pass a larger batch.
        """
        scores = self.attentionConvNet(x.permute(0, 2, 1))       # [1, 1, seq_len]
        weights = self.attentionNet(scores.view(1, self.seq_len))  # [1, seq_len]
        return torch.sum(weights.view(1, self.seq_len, 1) * x, dim=1)
|
|
|
|
| |
class AudioNet(nn.Module):
    """Per-frame audio feature encoder.

    Takes raw audio features shaped [B, dim_in, T], crops a `win_size`-wide
    temporal window, and maps each item to a `dim_aud`-dimensional embedding.
    """

    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super(AudioNet, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        # Four stride-2 convs halve the temporal axis each step
        # (16 -> 8 -> 4 -> 2 -> 1 for the default window).
        self.encoder_conv = nn.Sequential(
            nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
        )
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(64, 64),
            nn.LeakyReLU(0.02, True),
            nn.Linear(64, dim_aud),
        )

    def forward(self, x):
        """x: [B, dim_in, T] -> [B, dim_aud]."""
        half_w = self.win_size // 2  # floor division instead of int(x/2)
        # NOTE(review): the window is centered on temporal index 8, which
        # assumes T >= 8 + half_w — confirm against the caller's features.
        x = x[:, :, 8 - half_w:8 + half_w]
        x = self.encoder_conv(x).squeeze(-1)
        x = self.encoder_fc1(x)
        return x
|
|
|
|
class MLP(nn.Module):
    """Plain fully-connected network.

    `num_layers` bias-free Linear layers with ReLU between them; the final
    layer has no activation.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        layers = []
        for idx in range(num_layers):
            in_features = self.dim_in if idx == 0 else self.dim_hidden
            out_features = self.dim_out if idx == num_layers - 1 else self.dim_hidden
            layers.append(nn.Linear(in_features, out_features, bias=False))
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx != last:
                x = F.relu(x, inplace=True)
        return x
|
|
|
|
class MotionNetwork(nn.Module):
    """Audio- and expression-conditioned deformation network.

    Query points are encoded with three 2-D hash grids (one per axis-aligned
    plane); the audio embedding and the eye/expression code are gated by
    per-point attention maps before being regressed into per-point deltas
    for position, rotation, opacity and scale.
    """

    def __init__(self,
                 audio_dim = 32,
                 ind_dim = 0,
                 args = None,
                 ):
        super(MotionNetwork, self).__init__()

        # Raw audio-feature width depends on the extractor backend.
        if 'esperanto' in args.audio_extractor:
            self.audio_in_dim = 44
        elif 'deepspeech' in args.audio_extractor:
            self.audio_in_dim = 29
        elif 'hubert' in args.audio_extractor:
            self.audio_in_dim = 1024
        else:
            raise NotImplementedError

        self.bound = 0.15
        self.exp_eye = True

        # Optional per-frame learned latent codes.
        self.individual_dim = ind_dim
        if self.individual_dim > 0:
            self.individual_codes = nn.Parameter(torch.randn(10000, self.individual_dim) * 0.1)

        # Audio branch: per-frame encoder + temporal attention smoothing.
        self.audio_dim = audio_dim
        self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)
        self.audio_att_net = AudioAttNet(self.audio_dim)

        # Tri-plane hash-grid encoders for the 3-D query points.
        self.num_levels = 12
        self.level_dim = 1
        self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound)
        self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound)
        self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound)
        self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz

        self.num_layers = 3
        self.hidden_dim = 64

        # Expression branch: the first 5 coefficients are re-encoded to
        # eye_dim - 1 features and later concatenated with the last raw
        # coefficient (see forward), giving eye_dim features in total.
        self.exp_in_dim = 6 - 1
        self.eye_dim = 6 if self.exp_eye else 0
        self.exp_encode_net = MLP(self.exp_in_dim, self.eye_dim - 1, 16, 2)
        self.eye_att_net = MLP(self.in_dim, self.eye_dim, 16, 2)

        # Head: 3 (xyz) + 4 (rot) + 1 (opacity) + 3 (scale) = 11 outputs.
        self.out_dim = 11
        self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim + self.individual_dim, self.out_dim, self.hidden_dim, self.num_layers)
        self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 32, 2)

    @staticmethod
    @torch.jit.script
    def split_xyz(x):
        # Project each 3-D point onto the three axis-aligned planes.
        xy = x[:, :-1]
        yz = x[:, 1:]
        xz = torch.cat([x[:, :1], x[:, -1:]], dim=-1)
        return xy, yz, xz

    def encode_x(self, xyz, bound):
        """xyz: [N, 3] -> concatenated tri-plane features [N, in_dim]."""
        N, M = xyz.shape
        xy, yz, xz = self.split_xyz(xyz)
        return torch.cat([
            self.encoder_xy(xy, bound=bound),
            self.encoder_yz(yz, bound=bound),
            self.encoder_xz(xz, bound=bound),
        ], dim=-1)

    def encode_audio(self, a):
        """Raw audio features -> smoothed embedding [1, audio_dim].

        Returns None unchanged so callers can skip the audio branch.
        """
        if a is None:
            return None
        return self.audio_att_net(self.audio_net(a).unsqueeze(0))

    def forward(self, x, a, e=None, c=None):
        enc_x = self.encode_x(x, bound=self.bound)

        # Broadcast the single audio embedding to every query point and gate
        # it with a per-point channel attention.
        enc_a = self.encode_audio(a).repeat(enc_x.shape[0], 1)
        aud_ch_att = self.aud_ch_att_net(enc_x)
        enc_w = enc_a * aud_ch_att

        # Eye/expression code, spatially gated the same way.
        eye_att = torch.relu(self.eye_att_net(enc_x))
        enc_e = torch.cat([self.exp_encode_net(e[:-1]), e[-1:]], dim=-1) * eye_att

        parts = [enc_x, enc_w, enc_e]
        if c is not None:
            parts.append(c.repeat(enc_x.shape[0], 1))
        h = self.sigma_net(torch.cat(parts, dim=-1))

        return {
            'd_xyz': h[..., :3] * 1e-2,
            'd_rot': h[..., 3:7],
            'd_opa': h[..., 7:8],
            'd_scale': h[..., 8:11],
            'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True),
            'ambient_eye' : eye_att.norm(dim=-1, keepdim=True),
        }

    def get_params(self, lr, lr_net, wd=0):
        """Optimizer parameter groups: hash-grid encoders at `lr`, networks
        at `lr_net`; the audio attention net gets a 5x boosted rate with its
        own weight decay."""
        params = [
            {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.encoder_xy.parameters(), 'lr': lr},
            {'params': self.encoder_yz.parameters(), 'lr': lr},
            {'params': self.encoder_xz.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001},
        ]
        if self.individual_dim > 0:
            params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
        params.extend([
            {'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.exp_encode_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
        ])
        return params
|
|
|
|
|
|
|
|
class MouthMotionNetwork(nn.Module):
    """Audio-conditioned deformation network for the mouth region.

    Same tri-plane hash-grid encoding as MotionNetwork but without the
    expression branch; regresses only a positional offset per query point.
    """

    def __init__(self,
                 audio_dim = 32,
                 ind_dim = 0,
                 args = None,
                 ):
        super(MouthMotionNetwork, self).__init__()

        # Raw audio-feature width depends on the extractor backend.
        if 'esperanto' in args.audio_extractor:
            self.audio_in_dim = 44
        elif 'deepspeech' in args.audio_extractor:
            self.audio_in_dim = 29
        elif 'hubert' in args.audio_extractor:
            self.audio_in_dim = 1024
        else:
            raise NotImplementedError

        self.bound = 0.15

        # Optional per-frame learned latent codes.
        self.individual_dim = ind_dim
        if self.individual_dim > 0:
            self.individual_codes = nn.Parameter(torch.randn(10000, self.individual_dim) * 0.1)

        # Audio branch: per-frame encoder + temporal attention smoothing.
        self.audio_dim = audio_dim
        self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)
        self.audio_att_net = AudioAttNet(self.audio_dim)

        # Tri-plane hash-grid encoders (finer base resolution than the head
        # network: 64 vs 16, desired 384*bound vs 256*bound).
        self.num_levels = 12
        self.level_dim = 1
        self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound)
        self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound)
        self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound)
        self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz

        self.num_layers = 3
        self.hidden_dim = 32

        # Head: 3 outputs (xyz offset only).
        self.out_dim = 3
        self.sigma_net = MLP(self.in_dim + self.audio_dim + self.individual_dim, self.out_dim, self.hidden_dim, self.num_layers)
        # NOTE(review): created and registered for get_params, but not used
        # in forward — kept for checkpoint/optimizer compatibility.
        self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 32, 2)

    def encode_audio(self, a):
        """Raw audio features -> smoothed embedding [1, audio_dim].

        Returns None unchanged so callers can skip the audio branch.
        """
        if a is None:
            return None
        return self.audio_att_net(self.audio_net(a).unsqueeze(0))

    @staticmethod
    @torch.jit.script
    def split_xyz(x):
        # Project each 3-D point onto the three axis-aligned planes.
        xy = x[:, :-1]
        yz = x[:, 1:]
        xz = torch.cat([x[:, :1], x[:, -1:]], dim=-1)
        return xy, yz, xz

    def encode_x(self, xyz, bound):
        """xyz: [N, 3] -> concatenated tri-plane features [N, in_dim]."""
        N, M = xyz.shape
        xy, yz, xz = self.split_xyz(xyz)
        return torch.cat([
            self.encoder_xy(xy, bound=bound),
            self.encoder_yz(yz, bound=bound),
            self.encoder_xz(xz, bound=bound),
        ], dim=-1)

    def forward(self, x, a):
        enc_x = self.encode_x(x, bound=self.bound)
        # Broadcast the single audio embedding to every query point.
        enc_w = self.encode_audio(a).repeat(enc_x.shape[0], 1)

        h = self.sigma_net(torch.cat([enc_x, enc_w], dim=-1))

        d_xyz = h * 1e-2
        # Damp the first and third axes 5x — presumably mouth motion is
        # dominated by the remaining (vertical) axis; confirm axis convention.
        d_xyz[..., 0] = d_xyz[..., 0] / 5
        d_xyz[..., 2] = d_xyz[..., 2] / 5
        return {
            'd_xyz': d_xyz,
        }

    def get_params(self, lr, lr_net, wd=0):
        """Optimizer parameter groups: hash-grid encoders at `lr`, networks
        at `lr_net`; the audio attention net gets a 5x boosted rate with its
        own weight decay."""
        params = [
            {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.encoder_xy.parameters(), 'lr': lr},
            {'params': self.encoder_yz.parameters(), 'lr': lr},
            {'params': self.encoder_xz.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001},
        ]
        if self.individual_dim > 0:
            params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        return params
|
|