| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from .encoding import get_encoder |
| | from .renderer import NeRFRenderer |
| |
|
| |
|
class Conv2d(nn.Module):
    """Conv2d followed by BatchNorm2d, an optional identity skip, and ReLU/LeakyReLU.

    With ``residual=True`` the block input is added to the normalized
    convolution output before the activation. That addition only works when
    ``cin == cout`` and stride/padding preserve the spatial size.
    Extra ``*args``/``**kwargs`` are forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, leakyReLU=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        conv = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(cout)
        # Keep the Sequential wrapper so state_dict keys stay conv_block.0.* / conv_block.1.*.
        self.conv_block = nn.Sequential(conv, norm)
        self.act = nn.LeakyReLU(0.02) if leakyReLU else nn.ReLU()
        self.residual = residual

    def forward(self, x):
        normalized = self.conv_block(x)
        pre_act = normalized + x if self.residual else normalized
        return self.act(pre_act)
| |
|
| |
|
| | |
class AudioAttNet(nn.Module):
    """Temporal attention that fuses a window of ``seq_len`` audio feature frames.

    A stack of 1D convolutions reduces each frame's ``dim_aud`` channels to a
    single score. A linear layer plus softmax turns the ``seq_len`` scores into
    weights, and the output is the weighted sum of the input frames.
    """

    def __init__(self, dim_aud=64, seq_len=8):
        super(AudioAttNet, self).__init__()
        self.seq_len = seq_len
        self.dim_aud = dim_aud
        # Channel widths dim_aud -> 16 -> 8 -> 4 -> 2 -> 1, each conv followed by
        # LeakyReLU. Built in the same order as before so Sequential indices
        # (and therefore state_dict keys) are unchanged.
        widths = (self.dim_aud, 16, 8, 4, 2, 1)
        conv_layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            conv_layers.append(nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True))
            conv_layers.append(nn.LeakyReLU(0.02, True))
        self.attentionConvNet = nn.Sequential(*conv_layers)
        self.attentionNet = nn.Sequential(
            nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` over its frame axis.

        Args:
            x: tensor of shape (batch, seq_len, dim_aud). Any batch size is
                accepted; previously only batch 1 worked.

        Returns:
            Tensor of shape (batch, dim_aud).
        """
        batch = x.shape[0]
        # Conv1d expects channels first: (batch, dim_aud, seq_len) -> (batch, 1, seq_len).
        scores = self.attentionConvNet(x.permute(0, 2, 1))
        # One softmax weight per frame, per sample.
        weights = self.attentionNet(scores.view(batch, self.seq_len))
        return torch.sum(weights.view(batch, self.seq_len, 1) * x, dim=1)
| |
|
| |
|
class AudioEncoder(nn.Module):
    """Stack of ``Conv2d`` blocks that maps a 1-channel 2D audio map to a 512-d code.

    Input is (batch, 1, H, W). Given the strides below, an 80x16 input reduces
    to a 1x1 map, and the output is then (batch, 512). Singleton spatial axes
    are squeezed; non-singleton ones are left in place.
    """

    def __init__(self):
        super(AudioEncoder, self).__init__()

        # (cin, cout, kernel, stride, padding, residual) per block, in order.
        # The Sequential indices match the original layout, so state_dict keys are unchanged.
        block_specs = [
            (1, 32, 3, 1, 1, False),
            (32, 32, 3, 1, 1, True),
            (32, 32, 3, 1, 1, True),

            (32, 64, 3, (3, 1), 1, False),
            (64, 64, 3, 1, 1, True),
            (64, 64, 3, 1, 1, True),

            (64, 128, 3, 3, 1, False),
            (128, 128, 3, 1, 1, True),
            (128, 128, 3, 1, 1, True),

            (128, 256, 3, (3, 2), 1, False),
            (256, 256, 3, 1, 1, True),

            (256, 512, 3, 1, 0, False),
            (512, 512, 1, 1, 0, False),
        ]
        self.audio_encoder = nn.Sequential(*[
            Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, residual=res)
            for c_in, c_out, k, s, p, res in block_specs
        ])

    def forward(self, x):
        features = self.audio_encoder(x)
        # Drop the H axis and then the W axis, each only if it has size 1.
        features = features.squeeze(2)
        return features.squeeze(2)
| |
|
| | |
class AudioNet(nn.Module):
    """Encode a (batch, dim_in, frames) audio feature window into a (batch, dim_aud) code.

    ``forward`` slices ``win_size`` frames around a fixed centre index of 8
    along the last axis. Four stride-2 1D convolutions then reduce a 16-frame
    window to length 1, and a small MLP produces the code.
    """

    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super(AudioNet, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        # conv, LeakyReLU pairs; same order/indices as an explicit Sequential.
        conv_layers = []
        for c_in, c_out in ((dim_in, 32), (32, 32), (32, 64), (64, 64)):
            conv_layers.append(nn.Conv1d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=True))
            conv_layers.append(nn.LeakyReLU(0.02, True))
        self.encoder_conv = nn.Sequential(*conv_layers)
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(64, 64),
            nn.LeakyReLU(0.02, True),
            nn.Linear(64, dim_aud),
        )

    def forward(self, x):
        # NOTE(review): the window centre is hard-coded to frame 8, i.e. it assumes
        # the caller supplies 16 frames; win_size > 16 would produce a negative start.
        centre = 8
        half = int(self.win_size / 2)
        window = x[:, :, centre - half:centre + half]
        pooled = self.encoder_conv(window).squeeze(-1)
        return self.encoder_fc1(pooled)
| |
|
| |
|
| | |
class AudioNet_ave(nn.Module):
    """Project 512-d audio features to ``dim_aud`` with a 3-layer MLP.

    ``dim_in`` and ``win_size`` are stored or accepted only for signature parity
    with ``AudioNet``; the MLP input width is fixed at 512. The input must be
    3-D. Axes 0 and 1 are swapped after projection and then axis 0 is squeezed
    if it has size 1, so an (N, 1, 512) input yields (N, dim_aud).
    """

    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super(AudioNet_ave, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        # Linear layers 512 -> 256 -> 128 -> dim_aud with LeakyReLU between them
        # (indices 0..4, identical to the explicit Sequential layout).
        widths = (512, 256, 128, dim_aud)
        layers = []
        for idx, (w_in, w_out) in enumerate(zip(widths, widths[1:])):
            if idx:
                layers.append(nn.LeakyReLU(0.02, True))
            layers.append(nn.Linear(w_in, w_out))
        self.encoder_fc1 = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.encoder_fc1(x)
        swapped = projected.permute(1, 0, 2)
        return swapped.squeeze(0)
| |
|
class MLP(nn.Module):
    """Bias-free fully connected network with an in-place ReLU between layers.

    Layer widths are dim_in -> dim_hidden (repeated) -> dim_out. There is no
    activation after the final layer. With ``num_layers == 0`` the module is
    the identity.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        final = num_layers - 1
        self.net = nn.ModuleList([
            nn.Linear(dim_in if idx == 0 else dim_hidden,
                      dim_out if idx == final else dim_hidden,
                      bias=False)
            for idx in range(num_layers)
        ])

    def forward(self, x):
        final = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx < final:
                x = F.relu(x, inplace=True)
        return x
| |
|
| |
|
class NeRFNetwork(NeRFRenderer):
    """Audio-conditioned NeRF with a tri-plane hash-grid head and an optional torso branch.

    Head: a 3D point is encoded by three 2D hash grids (xy, yz, xz planes).
    Per-point MLPs produce channel attention for the audio code and, optionally,
    for an eye feature. ``sigma_net`` predicts density plus a geometry feature,
    and ``color_net`` predicts RGB from that feature and the encoded view
    direction.

    Torso (built when ``self.torso`` is set by ``NeRFRenderer``): 2D points are
    deformed using learnable anchor points projected by the camera pose, then
    encoded with a tiled grid to predict alpha and colour.
    """

    def __init__(self,
                 opt,
                 audio_dim=32,
                 ):
        super().__init__(opt)

        self.emb = self.opt.emb

        # Width of the per-frame input audio features for each ASR front-end.
        if 'esperanto' in self.opt.asr_model:
            self.audio_in_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_in_dim = 29
        elif 'hubert' in self.opt.asr_model:
            self.audio_in_dim = 1024
        else:
            self.audio_in_dim = 32

        if self.emb:
            self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim)

        # Audio encoder, plus optional temporal attention over a window of codes.
        self.audio_dim = audio_dim
        if self.opt.asr_model == 'ave':
            self.audio_net = AudioNet_ave(self.audio_in_dim, self.audio_dim)
        else:
            self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)

        self.att = self.opt.att
        if self.att > 0:
            self.audio_att_net = AudioAttNet(self.audio_dim)

        # Spatial encoding: one 2D hash grid per axis-aligned plane, identical settings.
        self.num_levels = 12
        self.level_dim = 1
        plane_kwargs = dict(input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim,
                            base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
        self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', **plane_kwargs)
        self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', **plane_kwargs)
        self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', **plane_kwargs)

        self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz

        # Density network and the eye-feature attention branch.
        self.num_layers = 3
        self.hidden_dim = 64
        self.geo_feat_dim = 64
        if self.opt.au45:
            eye_out, eye_hidden = 1, 16
        else:
            # (output width, hidden width) of the eye attention MLP per blendshape area.
            eye_branches = {'upper': (7, 64), 'single': (4, 64), 'eye': (2, 64)}
            if self.opt.bs_area not in eye_branches:
                # Previously this fell through and failed later with an AttributeError on eye_dim.
                raise ValueError(
                    f"unsupported bs_area {self.opt.bs_area!r}; expected one of {sorted(eye_branches)}")
            eye_out, eye_hidden = eye_branches[self.opt.bs_area]
        self.eye_att_net = MLP(self.in_dim, eye_out, eye_hidden, 2)
        self.eye_dim = eye_out if self.exp_eye else 0
        self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers)

        # Colour network: input is the encoded direction, geo feature, and the
        # optional individual code (individual_dim comes from NeRFRenderer).
        self.num_layers_color = 2
        self.hidden_dim_color = 64
        self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics')
        self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color)

        self.unc_net = MLP(self.in_dim, 1, 32, 2)

        self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2)

        self.testing = False

        if self.torso:
            # Three learnable anchor points in homogeneous coordinates.
            self.register_parameter('anchor_points',
                                    nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]])))
            self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8)
            # 3 anchors x 2 projected coordinates = 6 inputs.
            self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3)
            self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3)

            self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
            self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3)

    def forward_torso(self, x, poses, c=None):
        """Predict torso alpha and colour at 2D points.

        Args:
            x: 2D points, (N, 2) (the deform encoder has input_dim=2).
            poses: camera pose matrices; presumably a single 4x4 pose with batch 1,
                since the projected anchors are flattened into one 6-wide row. TODO confirm.
            c: optional torso individual code, repeated once per point.

        Returns:
            (alpha, color, dx): alpha comes from output channel 0 and colour from
            the remaining channels. Both are mapped to (-0.001, 1.001).
            ``dx`` is the predicted 2D offset.
        """
        x = x * self.opt.torso_shrink

        # Transform anchors by the inverted transposed pose; divide x,y by w and then by z.
        wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse()
        wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1)

        enc_anchor = self.anchor_encoder(wrapped_anchor)
        enc_x = self.torso_deform_encoder(x)

        n_points = x.shape[0]
        if c is not None:
            h = torch.cat([enc_x, enc_anchor.repeat(n_points, 1), c.repeat(n_points, 1)], dim=-1)
        else:
            h = torch.cat([enc_x, enc_anchor.repeat(n_points, 1)], dim=-1)

        dx = self.torso_deform_net(h)

        x = (x + dx).clamp(-1, 1)

        x = self.torso_encoder(x, bound=1)

        h = torch.cat([x, h], dim=-1)

        h = self.torso_net(h)

        # Slightly widened sigmoid so exactly 0 and 1 are inside the output range.
        alpha = torch.sigmoid(h[..., :1]) * (1 + 2 * 0.001) - 0.001
        color = torch.sigmoid(h[..., 1:]) * (1 + 2 * 0.001) - 0.001

        return alpha, color, dx

    @staticmethod
    @torch.jit.script
    def split_xyz(x):
        # Split (N, 3) points into their (x, y), (y, z) and (x, z) plane coordinates.
        xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:, :1], x[:, -1:]], dim=-1)
        return xy, yz, xz

    def encode_x(self, xyz, bound):
        """Encode (N, 3) points with the three plane hash grids and concatenate the features.

        Raises:
            ValueError: if ``xyz`` is not 2-D. The original ``N, M = xyz.shape``
                unpack raised the same exception type.
        """
        if xyz.dim() != 2:
            raise ValueError(f"expected a 2-D (N, 3) tensor of points, got shape {tuple(xyz.shape)}")
        xy, yz, xz = self.split_xyz(xyz)
        feat_xy = self.encoder_xy(xy, bound=bound)
        feat_yz = self.encoder_yz(yz, bound=bound)
        feat_xz = self.encoder_xz(xz, bound=bound)

        return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1)

    def encode_audio(self, a):
        """Encode raw audio features into an audio code; return None when ``a`` is None.

        With ``self.emb``, ``a`` is looked up in ``self.embedding`` and the last two
        axes of the result are swapped. With ``self.att > 0``, the encoded codes are
        fused by ``audio_att_net``. The codes are presumably (frames, audio_dim),
        giving a (1, audio_dim) result. TODO confirm.
        """
        if a is None:
            return None

        if self.emb:
            a = self.embedding(a).transpose(-1, -2).contiguous()

        enc_a = self.audio_net(a)

        if self.att > 0:
            enc_a = self.audio_att_net(enc_a.unsqueeze(0))

        return enc_a

    def predict_uncertainty(self, unc_inp):
        """Return a per-point raw uncertainty with a single trailing channel.

        Zeros are returned when testing or when ``opt.unc_loss`` is off. Otherwise
        ``unc_net`` runs on the detached input, so no gradient reaches the encoders.
        """
        if self.testing or not self.opt.unc_loss:
            # Match unc_net's (..., 1) output; zeros_like(unc_inp) gave (..., in_dim).
            unc = unc_inp.new_zeros(*unc_inp.shape[:-1], 1)
        else:
            unc = self.unc_net(unc_inp.detach())

        return unc

    def forward(self, x, d, enc_a, c, e=None):
        """Query the head model at sample points.

        Args:
            x: (N, 3) sample positions.
            d: view directions, encoded by ``encoder_dir``.
            enc_a: audio code from ``encode_audio``; repeated once per point in ``density``.
            c: optional individual code (None to skip), repeated once per point.
            e: optional eye feature (None to skip).

        Returns:
            (sigma, color, ambient_aud, ambient_eye, uncertainty). ``uncertainty``
            is a non-negative softplus output with an extra trailing axis.
        """
        enc_x = self.encode_x(x, bound=self.bound)

        sigma_result = self.density(x, enc_a, e, enc_x)
        sigma = sigma_result['sigma']
        geo_feat = sigma_result['geo_feat']
        aud_ch_att = sigma_result['ambient_aud']
        eye_att = sigma_result['ambient_eye']

        enc_d = self.encoder_dir(d)

        if c is not None:
            h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1)
        else:
            h = torch.cat([enc_d, geo_feat], dim=-1)

        h_color = self.color_net(h)
        color = torch.sigmoid(h_color) * (1 + 2 * 0.001) - 0.001

        uncertainty = self.predict_uncertainty(enc_x)
        # softplus(u) == log(1 + exp(u)), but does not overflow to inf for large u.
        uncertainty = F.softplus(uncertainty)

        return sigma, color, aud_ch_att, eye_att, uncertainty[..., None]

    def density(self, x, enc_a, e=None, enc_x=None):
        """Predict density and a geometry feature at points ``x``.

        Args:
            x: (N, 3) points; only used when ``enc_x`` is not supplied.
            enc_a: audio code, presumably (1, audio_dim); repeated once per point.
            e: optional eye feature, presumably (1, eye_dim); repeated once per point.
            enc_x: optional precomputed ``encode_x(x)`` result.

        Returns:
            dict with 'sigma' (exp of the first output channel), 'geo_feat', and
            'ambient_aud' / 'ambient_eye' (L2 norms of the attention vectors,
            shape (N, 1)). 'ambient_eye' is all zeros when ``e`` is None.
        """
        if enc_x is None:
            enc_x = self.encode_x(x, bound=self.bound)

        n_points = enc_x.shape[0]
        enc_a = enc_a.repeat(n_points, 1)
        aud_ch_att = self.aud_ch_att_net(enc_x)
        enc_w = enc_a * aud_ch_att

        if e is not None:
            e = e.repeat(n_points, 1)
            eye_att = self.eye_att_net(enc_x)
            e = e * eye_att
            h = torch.cat([enc_x, enc_w, e], dim=-1)
            ambient_eye = eye_att.norm(dim=-1, keepdim=True)
        else:
            h = torch.cat([enc_x, enc_w], dim=-1)
            # No eye input: previously this path hit a NameError on eye_att.
            ambient_eye = enc_x.new_zeros(n_points, 1)

        h = self.sigma_net(h)

        sigma = torch.exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
            'ambient_aud': aud_ch_att.norm(dim=-1, keepdim=True),
            'ambient_eye': ambient_eye,
        }

    def get_params(self, lr, lr_net, wd=0):
        """Build optimizer parameter groups.

        In torso mode only the torso modules are returned. Otherwise the head
        modules are returned. Encoders use ``lr`` and MLPs use ``lr_net``.
        ``audio_att_net`` uses ``lr_net * 5`` with weight decay 1e-4, and camera
        corrections use a fixed 1e-5.
        """
        if self.torso:
            params = [
                {'params': self.torso_encoder.parameters(), 'lr': lr},
                {'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd},
                {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd}
            ]

            if self.individual_dim_torso > 0:
                params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd})

            return params

        params = [
            {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},

            {'params': self.encoder_xy.parameters(), 'lr': lr},
            {'params': self.encoder_yz.parameters(), 'lr': lr},
            {'params': self.encoder_xz.parameters(), 'lr': lr},

            {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
        ]
        if self.att > 0:
            params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001})
        if self.emb:
            params.append({'params': self.embedding.parameters(), 'lr': lr})
        if self.individual_dim > 0:
            params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
        if self.train_camera:
            params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0})
            params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0})

        params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})

        return params
| |
|