# NOTE(review): the original lines "tasin" / "init" / "f075308" look like a stray
# paste or VCS artifact (possibly a truncated commit hash). As bare names they
# would raise NameError at import time, so they are preserved here as a comment.
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from src.torch_utils import misc
from src.torch_utils import persistence
from src.training.layers import (
MappingNetwork,
EqLRConv1d,
FullyConnectedLayer,
)
#----------------------------------------------------------------------------
@persistence.persistent_class
class MotionMappingNetwork(torch.nn.Module):
    """Produces per-frame motion codes `v` from a trajectory of noise codes.

    A trajectory of noise codes is sampled (one code every
    `cfg.motion.motion_z_distance` frames), transformed by either an LSTM
    ('autoregressive') or a stack of padding-free convolutions ('conv'), and then
    linearly interpolated at each frame timestamp to obtain intermediate codes `u`.
    The `u` codes are finally mapped to `v` codes either by a Fourier time encoder
    (`cfg.motion.fourier`) or by a plain MappingNetwork.
    """
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        assert self.cfg.motion.gen_strategy in ["autoregressive", "conv"], f"Unknown generation strategy: {self.cfg.motion.gen_strategy}"

        if self.cfg.motion.fourier:
            self.time_encoder = AlignedTimeEncoder(
                cfg=self.cfg,
                latent_dim=self.cfg.motion.v_dim
            )
        else:
            # BUGFIX: `get_dim()` probes `self.time_encoder is None`, but the attribute
            # was never assigned in this branch, so non-fourier configs raised
            # AttributeError (nn.Module.__getattr__). Assign it explicitly.
            self.time_encoder = None
            self.mapping = MappingNetwork(
                z_dim=self.cfg.motion.z_dim,
                c_dim=self.cfg.c_dim,
                w_dim=self.cfg.motion.v_dim,
                num_ws=None,
                num_layers=2,
                activation='lrelu',
                w_avg_beta=None,
                cfg=self.cfg,
            )

        if self.cfg.motion.gen_strategy == 'autoregressive':
            self.rnn = nn.LSTM(
                input_size=self.cfg.motion.z_dim + self.cfg.c_dim,
                hidden_size=self.cfg.motion.z_dim,
                bidirectional=False,
                batch_first=True)
            self._parameters_flattened = False
            self.num_additional_codes = 0
        elif self.cfg.motion.gen_strategy == 'conv':
            # Using Conv1d without paddings instead of LSTM makes the generations good
            # for any time t in (0, +inf), while LSTM would diverge for large `t`.
            # Also, this allows us to use equalized learning rates.
            self.conv = nn.Sequential(
                EqLRConv1d(self.cfg.motion.z_dim + self.cfg.c_dim, self.cfg.motion.z_dim, self.cfg.motion.kernel_size, padding=0, activation='lrelu', lr_multiplier=0.01),
                EqLRConv1d(self.cfg.motion.z_dim, self.cfg.motion.v_dim, self.cfg.motion.kernel_size, padding=0, activation='lrelu', lr_multiplier=0.01),
            )
            # Two valid (padding-free) convolutions each consume (kernel_size - 1)
            # positions, so the sampled trajectory must be that much longer.
            self.num_additional_codes = (self.cfg.motion.kernel_size - 1) * 2
        else:
            raise NotImplementedError(f'Unknown generation strategy: {self.cfg.motion.gen_strategy}')

    def get_max_traj_len(self, t: torch.Tensor) -> int:
        """Number of trajectory codes needed to cover all timestamps in `t`.

        Always covers at least `cfg.sampling.max_num_frames - 1` frames; the `+ 2`
        reserves the right-hand neighbour used for interpolation.
        """
        max_t = max(self.cfg.sampling.max_num_frames - 1, t.max().item()) # [1]
        max_traj_len = np.ceil(max_t / self.cfg.motion.motion_z_distance).astype(int).item() + 2 # [1]
        return max_traj_len

    def generate_motion_u_codes(self, c: torch.Tensor, t: torch.Tensor, motion_z: torch.Tensor=None) -> Dict:
        """Sample (or reuse) a noise trajectory and interpolate it at each frame time.

        Arguments:
        - c of shape [batch_size, c_dim] --- class conditioning (unused when c_dim == 0)
        - t of shape [batch_size, num_frames] --- frame timestamps
        - motion_z of shape [batch_size, max_traj_len, motion_z_dim] --- in case we want to reuse some existing motion noise

        Returns a dict with the interpolated codes `motion_u`, the neighbouring codes
        and timestamps, the interpolation weights, and the raw noise `motion_z`.
        """
        out = {}
        batch_size, num_frames = t.shape

        # Construct trajectories (from noise codes)
        max_traj_len = self.get_max_traj_len(t) + self.num_additional_codes # [1]
        if motion_z is None:
            motion_z = torch.randn(batch_size, max_traj_len, self.cfg.motion.z_dim, device=c.device) # [batch_size, max_traj_len, motion.z_dim]

        # Input motion trajectories are just random noise
        # (sliced in case a larger buffer than needed was passed in).
        input_trajs = motion_z[:batch_size, :max_traj_len, :self.cfg.motion.z_dim].to(c.device) # [batch_size, max_traj_len, motion.z_dim]

        if self.cfg.c_dim > 0:
            # Different classes might have different motions, so it should be useful to condition on c
            misc.assert_shape(c, [batch_size, None])
            input_trajs = torch.cat([input_trajs, c.unsqueeze(1).repeat(1, max_traj_len, 1)], dim=2) # [batch_size, max_traj_len, motion.z_dim + cond_dim]

        if self.cfg.motion.gen_strategy == 'autoregressive':
            # Somehow, RNN parameters do not get flattened on their own and we get a lot of warnings...
            if not self._parameters_flattened:
                self.rnn.flatten_parameters()
                self._parameters_flattened = True
            trajs, _ = self.rnn(input_trajs) # [batch_size, max_traj_len, motion.z_dim]
        elif self.cfg.motion.gen_strategy == 'conv':
            trajs = self.conv(input_trajs.permute(0, 2, 1)).permute(0, 2, 1) # [batch_size, max_traj_len, motion.v_dim]
        else:
            raise NotImplementedError(f'Unknown generation strategy: {self.cfg.motion.gen_strategy}')

        # Now, we should select neighbouring codes for each frame
        left_idx = (t / self.cfg.motion.motion_z_distance).floor().long() # [batch_size, num_frames]
        batch_idx = torch.arange(batch_size, device=c.device).unsqueeze(1).repeat(1, num_frames) # [batch_size, num_frames]
        motion_u_left = trajs[batch_idx, left_idx] # [batch_size, num_frames, motion.z_dim]
        motion_u_right = trajs[batch_idx, left_idx + 1] # [batch_size, num_frames, motion.z_dim]

        # Timestamps of the two neighbouring codes
        t_left = t - t % self.cfg.motion.motion_z_distance # [batch_size, num_frames]
        t_right = t_left + self.cfg.motion.motion_z_distance # [batch_size, num_frames]

        # Compute interpolation weights `alpha` (we'll use them later)
        interp_weights = ((t % self.cfg.motion.motion_z_distance) / self.cfg.motion.motion_z_distance).unsqueeze(2).to(torch.float32) # [batch_size, num_frames, 1]

        # Compute `u` codes as the interpolation between `u_left` and `u_right`
        motion_u = motion_u_left * (1 - interp_weights) + motion_u_right * interp_weights # [batch_size, num_frames, motion.z_dim]
        motion_u = motion_u.view(batch_size * num_frames, motion_u.shape[2]).to(torch.float32) # [batch_size * num_frames, motion.z_dim]

        # Save the results into the output dict
        out['motion_u_left'] = motion_u_left # [batch_size, num_frames, motion.z_dim]
        out['motion_u_right'] = motion_u_right # [batch_size, num_frames, motion.z_dim]
        out['t_left'] = t_left # [batch_size, num_frames]
        out['t_right'] = t_right # [batch_size, num_frames]
        out['interp_weights'] = interp_weights # [batch_size, num_frames, 1]
        out['motion_u'] = motion_u # [batch_size * num_frames, motion.z_dim]
        out['motion_z'] = motion_z # [batch_size+, max_traj_len+, motion.z_dim+]

        return out

    def get_dim(self) -> int:
        """Dimensionality of the produced motion codes `v`."""
        return self.cfg.motion.v_dim if self.time_encoder is None else self.time_encoder.get_dim()

    def forward(self, c: torch.Tensor, t: torch.Tensor, motion_z: Dict=None) -> Dict:
        """Compute motion codes `v` for each (sample, frame) pair.

        Arguments:
        - c of shape [batch_size, c_dim]
        - t of shape [batch_size, num_frames]
        - motion_z: optional pre-sampled motion noise to reuse.

        Returns a dict with 'motion_v' of shape [batch_size * num_frames, motion.v_dim]
        and the (possibly newly sampled) 'motion_z'.
        """
        assert len(c) == len(t), f"Wrong shape: {c.shape}, {t.shape}"
        assert t.ndim == 2, f"Wrong shape: {t.shape}"

        out = {} # We'll be aggregating the result here
        motion_u_info: Dict = self.generate_motion_u_codes(c, t, motion_z=motion_z) # Dict of tensors
        motion_u = motion_u_info['motion_u'].view(t.shape[0] * t.shape[1], -1) # [batch_size * num_frames, motion.z_dim]

        # Compute the `v` motion codes
        if self.cfg.motion.fourier:
            motion_v = self.time_encoder(
                t=t,
                motion_u_left=motion_u_info['motion_u_left'],
                motion_u_right=motion_u_info['motion_u_right'],
                t_left=motion_u_info['t_left'],
                t_right=motion_u_info['t_right'],
                interp_weights=motion_u_info['interp_weights'],
            ) # [batch_size * num_frames, motion_v_dim]
        else:
            motion_v = self.mapping(z=motion_u, c=c.repeat_interleave(t.shape[1], dim=0)) # [batch_size * num_frames, motion.v_dim]

        out['motion_v'] = motion_v # [batch_size * num_frames, motion.v_dim]
        out['motion_z'] = motion_u_info['motion_z'] # (Any shape)

        return out
#----------------------------------------------------------------------------
@persistence.persistent_class
class AlignedTimeEncoder(nn.Module):
    """Fourier-feature time encoder with segment-aligned embeddings.

    For each frame, a sin/cos positional embedding of its timestamp is computed
    with per-sample predicted period multipliers and phases. The embedding values
    at the segment endpoints (t_left/t_right) are subtracted out and replaced with
    learned "aligner" vectors interpolated with the same weights, so the resulting
    time embeddings vary smoothly across trajectory-segment boundaries.
    """
    def __init__(self,
            latent_dim: int=512,
            cfg: DictConfig=None,
        ):
        """
        Arguments:
        - latent_dim: dimensionality of the motion codes `u` fed to the predictors.
        - cfg: config with a `time_enc` section (dim, min_period_len, max_period_len).
          BUGFIX: the default used to be a mutable `{}` (shared-default anti-pattern;
          a plain dict has no `.time_enc` attribute, so the default was unusable
          anyway) --- `None` makes "cfg is required in practice" explicit.
        """
        super().__init__()
        self.cfg = cfg
        self.latent_dim = latent_dim

        freqs = construct_linspaced_frequencies(self.cfg.time_enc.dim, self.cfg.time_enc.min_period_len, self.cfg.time_enc.max_period_len)
        self.register_buffer('freqs', freqs) # [1, num_fourier_feats]

        # Creating the affine without bias to prevent motion mode collapse
        self.periods_predictor = FullyConnectedLayer(latent_dim, freqs.shape[1], activation='linear', bias=False)
        self.phase_predictor = FullyConnectedLayer(latent_dim, freqs.shape[1], activation='linear', bias=False)

        period_lens = 2 * np.pi / self.freqs # [1, num_fourier_feats]
        # Scales so a predicted phase shifts each feature proportionally to how many
        # of its periods fit into the longest period.
        phase_scales = self.cfg.time_enc.max_period_len / period_lens # [1, num_fourier_feats]
        self.register_buffer('phase_scales', phase_scales)

        self.aligners_predictor = FullyConnectedLayer(latent_dim, self.freqs.shape[1] * 2, activation='linear', bias=False)

    def get_dim(self) -> int:
        """Output dimensionality: sin features + cos features."""
        return self.freqs.shape[1] * 2

    def forward(self, t: torch.Tensor, motion_u_left: torch.Tensor, motion_u_right: torch.Tensor, interp_weights: torch.Tensor, t_left: torch.Tensor, t_right: torch.Tensor) -> torch.Tensor:
        """Encode timestamps `t` into aligned Fourier time embeddings.

        Arguments (bf = batch_size * num_frames):
        - t, t_left, t_right of shape [batch_size, num_frames]
        - motion_u_left, motion_u_right of shape [batch_size, num_frames, motion_u_dim]
          NOTE(review): motion_u_dim is assumed to equal `latent_dim` of the
          predictors --- confirm against the caller's config.
        - interp_weights of shape [batch_size, num_frames, 1]

        Returns time embeddings of shape [bf, num_fourier_feats * 2].
        """
        batch_size, num_frames, motion_u_dim = motion_u_left.shape # [1], [1], [1]
        misc.assert_shape(t, [batch_size, num_frames])
        misc.assert_shape(motion_u_left, [batch_size, num_frames, None])
        misc.assert_shape(motion_u_right, [batch_size, num_frames, None])
        misc.assert_shape(interp_weights, [batch_size, num_frames, 1])
        assert t.shape == t_left.shape == t_right.shape, f"Wrong shape: {t.shape} vs {t_left.shape} vs {t_right.shape}"

        motion_u_left = motion_u_left.view(batch_size * num_frames, motion_u_dim) # [batch_size * num_frames, motion_u_dim]
        motion_u_right = motion_u_right.view(batch_size * num_frames, motion_u_dim) # [batch_size * num_frames, motion_u_dim]

        # Per-sample period multipliers in (0, 2) and raw phases, predicted from the left code
        periods = self.periods_predictor(motion_u_left).tanh() + 1 # [batch_size * num_frames, feat_dim]
        phases = self.phase_predictor(motion_u_left) # [batch_size * num_frames, feat_dim]

        # Learned aligner vectors at the two segment endpoints
        aligners_left = self.aligners_predictor(motion_u_left) # [batch_size * num_frames, out_dim]
        aligners_right = self.aligners_predictor(motion_u_right) # [batch_size * num_frames, out_dim]

        # Raw (pre sin/cos) positional arguments at t, t_left, t_right
        raw_pos_embs = self.freqs * periods * t.view(-1).float().unsqueeze(1) + phases * self.phase_scales # [bf, feat_dim]
        raw_pos_embs_left = self.freqs * periods * t_left.view(-1).float().unsqueeze(1) + phases * self.phase_scales # [bf, feat_dim]
        raw_pos_embs_right = self.freqs * periods * t_right.view(-1).float().unsqueeze(1) + phases * self.phase_scales # [bf, feat_dim]

        pos_embs = torch.cat([raw_pos_embs.sin(), raw_pos_embs.cos()], dim=1) # [bf, out_dim]
        pos_embs_left = torch.cat([raw_pos_embs_left.sin(), raw_pos_embs_left.cos()], dim=1) # [bf, out_dim]
        pos_embs_right = torch.cat([raw_pos_embs_right.sin(), raw_pos_embs_right.cos()], dim=1) # [bf, out_dim]

        interp_weights = interp_weights.view(-1, 1) # [bf, 1]

        # Remove the interpolated endpoint embeddings and add interpolated learned aligners instead
        aligners_remove = pos_embs_left * (1 - interp_weights) + pos_embs_right * interp_weights # [bf, out_dim]
        aligners_add = aligners_left * (1 - interp_weights) + aligners_right * interp_weights # [bf, out_dim]
        time_embs = pos_embs - aligners_remove + aligners_add # [bf, out_dim]

        return time_embs
#----------------------------------------------------------------------------
def construct_linspaced_frequencies(num_freqs: int, min_period_len: int, max_period_len: int) -> torch.Tensor:
    """Build `num_freqs` angular frequencies whose period lengths are log2-linearly
    spaced between `min_period_len` and `max_period_len`.

    Returns a float32 tensor of shape [1, num_freqs], sorted in ascending
    frequency order (i.e. descending period length).
    """
    exponents = np.linspace(np.log2(min_period_len), np.log2(max_period_len), num_freqs) # [num_freqs]
    period_lens = np.exp2(exponents) # [num_freqs], ascending periods
    ascending_freqs = (2.0 * np.pi / period_lens)[::-1] # [num_freqs], ascending frequencies
    return torch.from_numpy(np.ascontiguousarray(ascending_freqs, dtype=np.float32)).unsqueeze(0) # [1, num_freqs]
#----------------------------------------------------------------------------