diff --git a/tools/visualization_0416/utils/model_0506/__init__.py b/tools/visualization_0416/utils/model_0506/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/__init__.py b/tools/visualization_0416/utils/model_0506/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/visualization_0416/utils/model_0506/model/basic_model/__init__.py b/tools/visualization_0416/utils/model_0506/model/basic_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/basic_model/basic_block.py b/tools/visualization_0416/utils/model_0506/model/basic_model/basic_block.py new file mode 100644 index 0000000000000000000000000000000000000000..7ace88863fa7e41c7c0096c2809c47e0c6ba7910 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/basic_model/basic_block.py @@ -0,0 +1,163 @@ +""" +Reference: https://github.com/hedra-labs/one-shot-face/blob/mega-portraits/models/building_blocks.py +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.parametrizations import spectral_norm + +USE_BIAS = False + +# 
https://github.com/joe-siyuan-qiao/WeightStandardization?tab=readme-ov-file#pytorch +class WSConv2d(nn.Conv2d): + def __init__(self, *args, **kwargs): + super(WSConv2d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class WSConv3d(nn.Conv3d): + def __init__(self, *args, **kwargs): + super(WSConv3d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv3d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ResBlock2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, use_spectral_norm: bool = False): + super().__init__() + + norm_func = lambda x: x + if use_spectral_norm: + norm_func = spectral_norm + + if in_channels != out_channels: + self.skip_layer = norm_func(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(WSConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBlock3d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBasic(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda 
x: x + + self.layers = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + self.layers(inp)) + + +class ResBottleneck(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda x: x + + temp_out_channels = out_channels // 4 + self.layers = nn.Sequential( + WSConv2d(in_channels, temp_out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, temp_out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + self.layers(inp)) + + +class ReshapeTo3DLayer(nn.Module): + def __init__(self, out_depth: int): + super().__init__() + + self.out_depth = out_depth + + def forward(self, inp: torch.Tensor): + batch_size, channels, height, width = inp.shape + return inp.view(batch_size, channels // self.out_depth, self.out_depth, height, width) + + +class ReshapeTo2DLayer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inp: torch.Tensor): + batch_size, channels, depth, height, width = inp.shape + return inp.view(batch_size, channels * depth, height, width) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/basic_model/frnet.py b/tools/visualization_0416/utils/model_0506/model/basic_model/frnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b2052bdec31571a4a03489e39370da219ccb72b8 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/basic_model/frnet.py @@ -0,0 +1,177 @@ +""" +Reference: https://github.com/yfeng95/DECA/blob/a11554ae2a2b0f3998cf1fa94dd4db03babb34a2/decalib/models/frnet.py +""" +import torch.nn as nn +import numpy as np +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import math + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = 
nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, include_top=True): + self.inplanes = 64 + super(ResNet, self).__init__() + self.include_top = include_top + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + + if not self.include_top: + return x + + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + +def resnet50(**kwargs): + """Constructs a ResNet-50 model. 
+ """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + return model + +import pickle +def load_state_dict(model, fname): + """ + Set parameters converted from Caffe models authors of VGGFace2 provide. + See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. + Arguments: + model: model + fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. + """ + with open(fname, 'rb') as f: + weights = pickle.load(f, encoding='latin1') + + own_state = model.state_dict() + for name, param in weights.items(): + if name in own_state: + try: + own_state[name].copy_(torch.from_numpy(param)) + except Exception: + raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ + 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) + else: + raise KeyError('unexpected key "{}" in state_dict'.format(name)) diff --git a/tools/visualization_0416/utils/model_0506/model/basic_model/resnet.py b/tools/visualization_0416/utils/model_0506/model/basic_model/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c473a9053ac1a2490d5899ddc5e5d07c352d46 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/basic_model/resnet.py @@ -0,0 +1,182 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=7): + super(ResNet, self).__init__() + self.in_planes = 64 + + # self.conv1 = 
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + self.fc = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +class ResNet_AE(nn.Module): + def __init__(self, block, num_blocks, num_classes=7): + super(ResNet_AE, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.Sigmoid() + ) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + + def forward(self, x): + out1 = self.conv1(x) + + decoded = self.decoder(out1) + + out = F.relu(self.bn1(out1)) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out, x, decoded + + +def ResNet18_AE(): + return ResNet_AE(BasicBlock, [2, 2, 2, 2]) + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net = ResNet18() + y = net(torch.randn(1, 1, 32, 32)) # conv1 takes single-channel input, not RGB + print(y.size()) + +# test() \ No newline at end of file
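Note on resnet.py above: unlike the stock CIFAR-style ResNet it is adapted from, conv1 takes a single grayscale channel and the head defaults to 7 classes (expression labels), so RGB input fails at the first layer. A minimal smoke test — a sketch only, assuming the file is importable as model.basic_model.resnet (the same path expression_embedder.py uses later in this diff):

import torch
from model.basic_model.resnet import ResNet18  # import path as used in expression_embedder.py

net = ResNet18()                         # ResNet(BasicBlock, [2, 2, 2, 2]), num_classes=7
logits = net(torch.randn(2, 1, 32, 32))  # grayscale 32x32 crops; conv1 is Conv2d(1, 64, ...)
assert logits.shape == (2, 7)            # 4x4 avg-pool -> 512*expansion features -> 7 logits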
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/appearance_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/appearance_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..17c0c231223ab0cd612d89e7361aaad376391c16 --- /dev/null +++ 
b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/appearance_encoder.py @@ -0,0 +1,110 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import models +from torch.cuda import amp +import math + +from .utils import blocks, norm_layers, activations +from dataclasses import dataclass + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + +class AppearanceEncoder(nn.Module): + + @dataclass + class Config: + gen_upsampling_type: str + gen_downsampling_type: str + gen_input_image_size: int + gen_latent_texture_size: int + gen_latent_texture_depth: int + gen_latent_texture_channels: int + gen_num_channels: int + enc_channel_mult: float + norm_layer_type: str + gen_max_channels: int + enc_block_type: str + gen_activation_type: str + warp_norm_grad: bool + in_channels: int = 3 + + def __init__(self, cfg: Config): + super(AppearanceEncoder, self).__init__() + + self.cfg = cfg + self.upsample_type = self.cfg.gen_upsampling_type + self.downsample_type = self.cfg.gen_downsampling_type + self.ratio = self.cfg.gen_input_image_size // self.cfg.gen_latent_texture_size + self.num_2d_blocks = int(math.log(self.ratio, 2)) + self.init_depth = self.cfg.gen_latent_texture_depth + spatial_size = self.cfg.gen_input_image_size + + out_channels = int(self.cfg.gen_num_channels * self.cfg.enc_channel_mult) + + setattr( + self, + f'from_rgb_{spatial_size}px', + nn.Conv2d( + in_channels=self.cfg.in_channels, + out_channels=out_channels, + kernel_size=7, + padding=3, + )) + + norm = self.cfg.norm_layer_type + + for i in range(self.num_2d_blocks): + in_channels = out_channels + out_channels = min(out_channels * 2, self.cfg.gen_max_channels) + setattr( + self, + f'enc_{i}_block={spatial_size}px', + blocks[self.cfg.enc_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + norm_layer_type=norm, + activation_type=self.cfg.gen_activation_type, + resize_layer_type=self.cfg.gen_downsampling_type)) + spatial_size //= 2 + + in_channels = out_channels + out_channels = self.cfg.gen_latent_texture_channels + finale_layers = [] + if self.cfg.enc_block_type == 'res': + finale_layers += [ + norm_layers[norm](in_channels), + activations[self.cfg.gen_activation_type](inplace=True)] + + finale_layers += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels * self.init_depth, + kernel_size=1)] + + + self.finale_layers = nn.Sequential(*finale_layers) + + def forward(self, source_img): + + s = source_img.shape[2] + + x = getattr(self, f'from_rgb_{s}px')(source_img) + + for i in range(self.num_2d_blocks): + x = getattr(self, f'enc_{i}_block={s}px')(x) + s //= 2 + + x = self.finale_layers(x) + + return x + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/args.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/args.py new file mode 100644 index 0000000000000000000000000000000000000000..bc93600490072e1b224bea84dda88cd1c4ca69ee --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/args.py @@ -0,0 +1,66 @@ +import argparse + + + +def str2bool(string): + if string == 'True': + return True + elif string == 'False': + return False + else: + raise ValueError(f"Expected 'True' or 'False', got '{string}'") + +def parse_str_to_list(string, value_type=str, sep=','): + if string: + outputs = string.replace(' ', '').split(sep) + else: + outputs = [] + + outputs = [value_type(output) for output in 
outputs] + + return outputs + +def parse_str_to_dict(string, value_type=str, sep=','): + items = [s.split(':') for s in string.replace(' ', '').split(sep)] + return {k: value_type(v) for k, v in items} + +def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + +def parse_args_line(line): + # Parse a value from string + parts = line[:-1].split(': ') + if len(parts) > 2: + parts = [parts[0], ': '.join(parts[1:])] + k, v = parts + v_type = str + if v.isdigit(): + v = int(v) + v_type = int + elif isfloat(v): + v_type = float + v = float(v) + elif v == 'True': + v = True + elif v == 'False': + v = False + + return k, v, v_type + +def parse_args(args_path): + parser = argparse.ArgumentParser(conflict_handler='resolve') + parser.add = parser.add_argument + + with open(args_path, 'rt') as args_file: + lines = args_file.readlines() + for line in lines: + k, v, v_type = parse_args_line(line) + parser.add('--%s' % k, type=v_type, default=v) + + args, _ = parser.parse_known_args() + + return args \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/decoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..df11ede60ef3a1d25fc93917b411b3d56a6bc154 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/decoder.py @@ -0,0 +1,364 @@ +import torch +from torch import nn +from torch import optim +from torch.cuda import amp +import torch.nn.functional as F +import math +import numpy as np +import itertools +import copy +from torch.cuda import amp +from scipy import linalg +from . import utils +from .utils import ProjectorConv, ProjectorNorm, ProjectorNormLinear, assign_adaptive_conv_params, \ + assign_adaptive_norm_params, Upsample_sg2 +from dataclasses import dataclass +# from .sg3_generator import Generator + + +class Decoder(nn.Module): + + @dataclass + class Config: + eps : float + image_size : int + gen_embed_size : int + gen_adaptive_kernel : bool + gen_adaptive_conv_type: str + gen_latent_texture_size: int + in_channels: int + gen_num_channels: int + dec_max_channels: int + gen_use_adanorm: bool + gen_activation_type: str + gen_use_adaconv: bool + dec_channel_mult: float + dec_num_blocks: int + dec_up_block_type: str + dec_pred_seg: bool + dec_seg_channel_mult: float + num_gpus: int + norm_layer_type: str + bigger: bool = False + vol_render: bool = False + im_dec_num_lrs_per_resolution: int = 1 + im_dec_ch_div_factor: float = 2.0 + emb_v_exp: bool = False + dec_use_sg3_img_dec: bool = False + no_detach_frec: int = 10 + dec_key_emb: str = 'orig' + + def __init__(self, cfg:Config): + super(Decoder, self).__init__() + self.cfg = cfg + self.adaptive_conv_type = self.cfg.gen_adaptive_conv_type + num_blocks = self.cfg.dec_num_blocks + num_up_blocks = int(math.log(self.cfg.image_size // self.cfg.gen_latent_texture_size, 2)) + self.in_channels = self.cfg.in_channels + out_channels = min(int(self.cfg.gen_num_channels * self.cfg.dec_channel_mult * 2**num_up_blocks), self.cfg.dec_max_channels) + # print(num_up_blocks, out_channels) + self.gen_max_channels = self.cfg.dec_max_channels + # self.num_gpus = self.cfg.num_gpus + self.norm_layer_type = self.cfg.norm_layer_type + norm_layer_type = self.cfg.norm_layer_type + + # if norm_layer_type == 'bn': + # if self.num_gpus > 1: + # norm_layer_type = 'sync_' + norm_layer_type + # if self.cfg.gen_use_adanorm: + # norm_layer_type = 'ada_' + 
norm_layer_type + + # print(norm_layer_type) + if self.cfg.vol_render: + layers = [] + else: + layers = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + bias=False) + ] + + for i in range(num_blocks): + layers += [ + utils.blocks['res']( + in_channels=out_channels, + out_channels=out_channels, + norm_layer_type=norm_layer_type, + activation_type=self.cfg.gen_activation_type, + conv_layer_type=('ada_' if self.cfg.gen_use_adaconv else '') + 'conv')] + + + self.res_decoder = nn.Sequential(*layers) + + if self.cfg.dec_use_sg3_img_dec: + self.img_decoder = Generator( + 512, + 512, + 512, + 3, + num_layers=14, # Total number of layers, excluding Fourier features and ToRGB. # NOTE Original 6 + num_critical=2, # Number of critically sampled layers at the end. + first_cutoff=10.079, # Cutoff frequency of the first layer (f_{c,0}). # NOTE Original 2.0 + first_stopband=17.959, # Minimum stopband of the first layer (f_{t,0}). # NOTE Original 2**3.1 + last_stopband_rel=2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. + margin_size=16, # Number of additional pixels outside the image. + num_fp16_res=False, # Use FP16 for the N highest resolutions. + ) + else: + self.img_decoder = ImageDecoder( + self.cfg.image_size, + self.cfg.gen_latent_texture_size, + self.cfg.gen_use_adanorm, + self.cfg.gen_num_channels, + self.cfg.dec_up_block_type, + self.cfg.gen_activation_type, + self.cfg.gen_use_adaconv, + self.cfg.dec_pred_seg, + self.cfg.dec_seg_channel_mult, + out_channels, + # self.num_gpus, + norm_layer_type=norm_layer_type, + bigger=self.cfg.bigger, + im_dec_num_lrs_per_resolution = self.cfg.im_dec_num_lrs_per_resolution, + im_dec_ch_div_factor = self.cfg.im_dec_ch_div_factor + ) + + + self.gen_use_adanorm = self.cfg.gen_use_adanorm + if self.cfg.gen_use_adanorm: + # self.projector = ProjectorNormLinear(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + # gen_embed_size=32, + # gen_max_channels=64) + + # self.projector = ProjectorNorm(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + # gen_embed_size=self.cfg.gen_embed_size, gen_max_channels=self.gen_max_channels,) + self.projector = ProjectorNormLinear(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_embed_size=self.cfg.gen_embed_size, + gen_max_channels=self.gen_max_channels, emb_v_exp = self.cfg.emb_v_exp, no_detach_frec=self.cfg.no_detach_frec, key_emb=self.cfg.dec_key_emb) + else: + self.projector = ProjectorNorm(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_embed_size=self.cfg.gen_embed_size, + gen_max_channels=self.gen_max_channels) + + # print(sum(p.numel() for p in self.res_decoder.parameters() if p.requires_grad), sum(p.numel() for p in self.img_decoder.parameters() if p.requires_grad)) + if self.cfg.gen_use_adaconv: + self.projector_conv = ProjectorConv(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_adaptive_kernel=self.cfg.gen_adaptive_kernel, + gen_max_channels=self.gen_max_channels) + + def forward(self, data_dict, embed_dict, feat_2d, input_flip_feat=False, annealing_alpha=0.0, embed=None, stage_two=False, iteration=0): + if self.gen_use_adanorm: + # b, c, es, _ = data_dict['ada_v'].shape + # params_norm = self.projector(data_dict['ada_v'].view(b, c, es ** 2)) + + params_norm = self.projector(embed_dict, iter=iteration) + annealing_alpha = 1 + # print('aaaa') + else: + params_norm = self.projector(embed_dict, iter=iteration) + + if 
input_flip_feat: + # Repeat params for flipped feat + params_norm_ = [] + for param in params_norm: + if isinstance(param, tuple): + params_norm_.append((torch.cat([p] * 2) for p in param)) + else: + params_norm_.append(torch.cat([param] * 2)) + else: + params_norm_ = params_norm + + assign_adaptive_norm_params([self.res_decoder, self.img_decoder], params_norm_, annealing_alpha) + + if hasattr(self, 'projector_conv'): + params_conv = self.projector_conv(embed_dict) + + if input_flip_feat: + # Repeat params for flipped feat + params_conv_ = [] + for param in params_conv: + if isinstance(param, tuple): + params_conv_.append((torch.cat([p] * 2) for p in param)) + else: + params_conv_.append(torch.cat([param] * 2)) + else: + params_conv_ = params_conv + + assign_adaptive_conv_params([self.res_decoder, self.img_decoder], params_conv, self.adaptive_conv_type, annealing_alpha) + + feat_2d = self.res_decoder(feat_2d) + img, seg, img_f = self.img_decoder(feat_2d, stage_two=stage_two) + + # Predict conf + if hasattr(self, 'conf_decoder') and self.training and input_flip_feat: + feat, feat_flip = feat_2d.split(feat_2d.shape[0] // 2) + + conf_ms, conf_ms_flip, conf, conf_flip = self.conf_decoder(feat, feat_flip) + + for conf_ms_k, conf_ms_flip_k, conf_name in zip(conf_ms, conf_ms_flip, self.conf_ms_names): + data_dict[f'{conf_name}_ms'] = conf_ms_k + data_dict[f'{conf_name}_flip_ms'] = conf_ms_flip_k + + data_dict[conf_name] = conf_ms_k[0] + data_dict[f'{conf_name}_flip'] = conf_ms_flip_k[0] + + for conf_k, conf_flip_k, conf_name in zip(conf, conf_flip, self.conf_names): + data_dict[f'{conf_name}'] = conf_k + data_dict[f'{conf_name}_flip'] = conf_flip_k + + if stage_two: + return img, seg, feat_2d, img_f + else: + return img, seg, None, None + +class ImageDecoder(nn.Module): + def __init__(self, + image_size, + gen_latent_texture_size, + gen_use_adanorm, + gen_num_channels, + dec_up_block_type, + gen_activation_type, + gen_use_adaconv, + dec_pred_seg, + dec_seg_channel_mult, + shared_in_channels, + # num_gpus, + norm_layer_type, + bigger=False, + im_dec_num_lrs_per_resolution=1, + im_dec_ch_div_factor = 2 + ): + super(ImageDecoder, self).__init__() + num_up_blocks = int(math.log(image_size // gen_latent_texture_size, 2)) + out_channels = shared_in_channels + self.bigger = bigger + self.im_dec_num_lrs_per_resolution = im_dec_num_lrs_per_resolution + + layers = [] + + if self.bigger: + num_up_blocks = num_up_blocks - 1 + + for i in range(num_up_blocks): + in_channels = out_channels + # out_channels = max(out_channels // 2, gen_num_channels) + out_channels = max(int(out_channels / im_dec_ch_div_factor/32)*32, gen_num_channels) + + + if self.bigger: + out_channels = max(out_channels, 256) + # out_channels = max(out_channels, gen_num_channels) + + # if out_channels%32!=0: + # c_norm_layer_type = 'gn_24' + # else: + # c_norm_layer_type = norm_layer_type + k=0 + for _ in range(self.im_dec_num_lrs_per_resolution): + layers += [ + utils.blocks[dec_up_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2 if k==0 else 1, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + resize_layer_type='nearest' if k==0 else 'none'), + ] + in_channels = out_channels + k+=1 + + if self.bigger: + layers += [ + utils.blocks[dec_up_block_type]( + in_channels=out_channels, + out_channels=out_channels//2, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if 
gen_use_adaconv else '') + 'conv'), + + utils.blocks[dec_up_block_type]( + in_channels=out_channels//2, + out_channels=out_channels//2, + stride=2, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + + resize_layer_type='nearest'), + utils.blocks[dec_up_block_type]( + in_channels=out_channels // 2, + out_channels=out_channels // 4, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv'), + + ] + out_channels = out_channels // 4 + + self.dec_img_blocks = nn.Sequential(*layers) + + layers = [ + utils.norm_layers[norm_layer_type](out_channels), + utils.activations[gen_activation_type](inplace=True), + nn.Conv2d( + in_channels=out_channels, + out_channels=3, + kernel_size=1), + nn.Sigmoid()] + + self.dec_img_head = nn.Sequential(*layers) + + if dec_pred_seg: + in_channels = shared_in_channels + out_channels = int(gen_num_channels * dec_seg_channel_mult * 2**num_up_blocks) + + layers = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=False)] + + for i in range(num_up_blocks): + in_channels = out_channels + out_channels = max(out_channels // 2, int(gen_num_channels * dec_seg_channel_mult)) + layers += [ + utils.blocks[dec_up_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + resize_layer_type='nearest')] + + def forward(self, feat, stage_two=False): + img_feat = self.dec_img_blocks(feat) + img = self.dec_img_head(img_feat.float()) + + seg = None + if hasattr(self, 'dec_seg_blocks'): + seg_feat = self.dec_seg_blocks(feat) + seg = self.dec_seg_head(seg_feat.float()) + + if stage_two: + return img, None, img_feat + else: + return img, None, None + +def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + return img + + + + + + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/expression_embedder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/expression_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..b32313e6ef7b09c331b720592dab193309431db0 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/expression_embedder.py @@ -0,0 +1,491 @@ +import torch +from torch import nn +from torch import optim +import torch.nn.functional as F +from torch.cuda import amp +from torchvision import models +import torchvision.transforms.functional as FF +import math +import numpy as np +import os +from . 
import utils +# from utils import args as args_utils +# import apex +from dataclasses import dataclass +import copy + +from model.basic_model.resnet import ResNet18 + +# import torch +# torch.manual_seed(0) + +# import random +# random.seed(0) + +# import numpy as np +# np.random.seed(0) + +# os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2" + +# import torch +# torch.use_deterministic_algorithms(True) + +class ExpressionEmbed(nn.Module): + + + @dataclass + class Config: + lpe_head_backbone: str + lpe_face_backbone: str + image_size: int + project_dir: str + # num_gpus: int + lpe_output_channels: int + lpe_output_channels_expression: int + lpe_final_pooling_type: str + lpe_output_size: int + lpe_head_transform_sep_scales: bool + norm_layer_type: str + use_head_net: bool = False + use_smart_scale: bool = False + smart_scale_max_scale: float = 0.75 + smart_scale_max_tol_angle: float = 0.8 # 45 degrees + dropout:float = 0.0 + # custom_w: bool =False + + def __init__(self, cfg = Config): + super(ExpressionEmbed, self).__init__() + + self.cfg = cfg + + # self.num_gpus = self.cfg.num_gpus + + if self.cfg.use_head_net: + self.net_head = ResNetWrapper( + lpe_output_channels=self.cfg.lpe_output_channels, + lpe_final_pooling_type=self.cfg.lpe_final_pooling_type, + lpe_output_size=self.cfg.lpe_output_size, + image_size=self.cfg.image_size, + lpe_head_transform_sep_scales=self.cfg.lpe_head_transform_sep_scales, + backbone=self.cfg.lpe_head_backbone, pred_head_pose=True) + else: + self.net_head = None + + + self.net_face = ResNetWrapper( + lpe_output_channels=self.cfg.lpe_output_channels_expression, + lpe_final_pooling_type=self.cfg.lpe_final_pooling_type, + lpe_output_size=self.cfg.lpe_output_size, + image_size=self.cfg.image_size, + lpe_head_transform_sep_scales=self.cfg.lpe_head_transform_sep_scales, + backbone=self.cfg.lpe_face_backbone, pred_expression=True, dropout=self.cfg.dropout, #custom_w=cfg.custom_w + ) + + # import pdb;pdb.set_trace() + + self.grid_size = self.cfg.image_size // 2 + grid = torch.linspace(-1, 1, self.grid_size) + v, u = torch.meshgrid(grid, grid) + identity_grid = torch.stack([u, v, torch.ones_like(u)], dim=2).view(1, -1, 3) + self.register_buffer('identity_grid', identity_grid) + + + + grid = torch.linspace(-1, 1, 512) + v, u = torch.meshgrid(grid, grid) + identity_grid = torch.stack([u, v, torch.ones_like(u)], dim=2).view(1, -1, 3) + self.register_buffer('identity_grid_512', identity_grid) + + + + self.lpe_head_transform_sep_scales = self.cfg.lpe_head_transform_sep_scales + aligned_keypoints = torch.from_numpy(np.load(f'{self.cfg.project_dir}/data/aligned_keypoints_3d.npy')) + + # import pdb; pdb.set_trace() + + aligned_keypoints = aligned_keypoints /self.cfg.image_size + aligned_keypoints[:, :2] -= 0.5 + aligned_keypoints *= 2 # map to [-1, 1] + self.register_buffer('aligned_keypoints', aligned_keypoints[None]) + self.image_size = self.cfg.image_size + + if self.cfg.norm_layer_type=='bn': + pass + # if self.cfg.num_gpus > 1: + # # self.net_head = nn.SyncBatchNorm.convert_sync_batchnorm(self.net_head) + # # self.net_face = nn.SyncBatchNorm.convert_sync_batchnorm(self.net_face) + # # self.net_head = apex.parallel.convert_syncbn_model(self.net_head) + # self.net_face = apex.parallel.convert_syncbn_model(self.net_face) + elif self.cfg.norm_layer_type=='in': + + # self.net_head = utils.replace_bn_to_in(self.net_head, 'ExpressionEmbed_net_head') + self.net_face = utils.replace_bn_to_in(self.net_face, 'ExpressionEmbed_net_face') + + elif self.cfg.norm_layer_type=='gn': + # 
self.net_head = utils.replace_bn_to_gn(self.net_head, 'ExpressionEmbed_net_head') + self.net_face = utils.replace_bn_to_gn(self.net_face, 'ExpressionEmbed_net_face') + elif self.cfg.norm_layer_type=='bcn': + # self.net_head = utils.replace_bn_to_gn(self.net_head, 'ExpressionEmbed_net_head') + self.net_face = utils.replace_bn_to_bcn(self.net_face, 'ExpressionEmbed_net_face') + else: + raise ValueError('Wrong norm type') + + def forward(self, data_dict, use_aug = True): + + if 'masked_source_img' in data_dict: + inputs_face = torch.cat([data_dict['masked_source_img'], data_dict['masked_target_img']]) + else: + inputs_face = torch.cat([data_dict['source_img'], data_dict['target_img']]) + + n = data_dict['source_img'].shape[0] + t = data_dict['target_img'].shape[0] + + data_dict['pred_source_theta'] = data_dict['source_theta'] + data_dict['pred_target_theta'] = data_dict['target_theta'] + theta = torch.cat([data_dict['source_theta'][:, :3, :], data_dict['target_theta'][:, :3, :]]).detach() + + # if self.training: + # # Calc ground truth thetas + # theta = torch.cat([data_dict['source_theta'][:,:3,:], data_dict['target_theta'][:,:3,:]]) + # inputs_face = torch.cat([data_dict['source_img'], data_dict['target_img']]) + # else: + # inputs_face = torch.cat([data_dict['source_img'], data_dict['target_img']]) + + with torch.no_grad(): + # Align input images using theta + eye_vector = torch.zeros(theta.shape[0], 1, 4) + eye_vector[:, :, 3] = 1 + eye_vector = eye_vector.type(theta.type()).to(theta.device) + theta_targ = torch.cat([data_dict['target_theta'][:, :3, :], data_dict['target_theta'][:, :3, :]]) + theta_ = torch.cat([theta, eye_vector], dim=1) + theta_targ_ = torch.cat([theta_targ, eye_vector], dim=1) + inv_theta_2d = theta_.float().inverse()[:, :, [0, 1, 3]][:, [0, 1, 3]] # leave only rows and cols corresponding to 2d transform + inv_theta_2d_ = inv_theta_2d + inv_theta_targ_2d = theta_targ_.float().inverse()[:, :, [0, 1, 3]][:, [0, 1, 3]] + + scale = torch.zeros_like(inv_theta_2d) + scale_full = torch.zeros_like(inv_theta_2d) + + if not self.cfg.use_smart_scale: + scale[:, [0, 1], [0, 1]] = 0.5 + scale[:, 2, 2] = 1 + scale_full[:, [0, 1, 2], [0, 1, 2]] = 2 + inv_theta_2d = torch.bmm(inv_theta_2d, scale)[:, :2] + inv_theta_targ_2d = torch.bmm(inv_theta_targ_2d, scale)[:, :2] + inv_theta_targ_2d_full = torch.bmm(inv_theta_targ_2d, scale_full)[:, :2] + else: + yaw_s, pitch_s, roll_s = data_dict['source_rotation'].split(1, dim=1) + yaw_t, pitch_t, roll_t = data_dict['target_rotation'].split(1, dim=1) + all_yaws = torch.cat([yaw_s.view(-1), yaw_t.view(-1)]) + max_scale = self.cfg.smart_scale_max_scale + max_tol_angle = self.cfg.smart_scale_max_tol_angle # 0.8 rad = 45 degrees no change. 
1.6 rad ~=90 degrees + for s, yaw in enumerate(all_yaws): + # scale_ = min(0.5+max((torch.abs(yaw)-max_tol_angle)*(1-max_scale)/(1.6-max_tol_angle), 0), max_scale) + scale_ = min(0.5+max((torch.abs(yaw)-max_tol_angle)*(max_scale-0.5)/(1.6-max_tol_angle), 0), max_scale) + scale[s, [0, 1], [0, 1]] = scale_ + if scale_>0.55: + print(torch.abs(yaw), scale_) + scale[:, 2, 2] = 1 + inv_theta_2d = torch.bmm(inv_theta_2d, scale)[:, :2].detach() + + scale_t = copy.deepcopy(scale) + scale_t[0, [0, 1], [0, 1]] = scale[-1, [0, 1], [0, 1]] + scale_t[1, [0, 1], [0, 1]] = scale[-1, [0, 1], [0, 1]] + inv_theta_targ_2d = torch.bmm(inv_theta_targ_2d, scale_t)[:, :2].detach() + scale_full[:, [0, 1, 2], [0, 1, 2]] = 2 + inv_theta_targ_2d_full = torch.bmm(inv_theta_targ_2d, scale_full)[:, :2].detach() + + + align_warp = self.identity_grid.repeat_interleave(n+t, dim=0) + align_warp = align_warp.bmm(inv_theta_2d.transpose(1, 2)).view(n+t, self.grid_size, self.grid_size, 2) + + + inputs_face_aligned = F.grid_sample(inputs_face.float(), align_warp.float()) + + data_dict['source_img_align'], data_dict['target_img_align'] = inputs_face_aligned.split([n, t], dim=0) + + align_warp_targ = self.identity_grid.repeat_interleave(n + t, dim=0) + data_dict['align_warp'] = align_warp_targ.bmm(inv_theta_targ_2d.transpose(1, 2)).view(n + t, self.grid_size, self.grid_size, 2) + + align_warp_targ = self.identity_grid_512.repeat_interleave(n + t, dim=0) + data_dict['align_warp_full'] = align_warp_targ.bmm(inv_theta_targ_2d_full.transpose(1, 2)).view(n + t, 512, 512, 2) + + exp_embed = self.net_face(inputs_face_aligned)[0] + data_dict['source_exp_embed'] = exp_embed[:n] + data_dict['target_exp_embed'] = exp_embed[-t:] + + return data_dict + + def estimate_theta(self, source_keypoints, target_keypoints): + keypoints = torch.cat([source_keypoints, target_keypoints], dim=0) + keypoints = torch.cat([keypoints, torch.ones(keypoints.shape[0], keypoints.shape[1], 1).to(keypoints.device)], dim=2) + + m = keypoints.shape[0] + + # Solve for ground-truth transform + if self.lpe_head_transform_sep_scales: + # scale_x, scale_y, scale_z, yaw, pitch, roll, dx, dy, dz + param = torch.FloatTensor([1, 1, 1, 0, 0, 0, 0, 0, 0]) + param = param[None].repeat_interleave(m, dim=0) + else: + param = torch.FloatTensor([1, 0, 0, 0, 0, 0, 0]) # scale, yaw, pitch, roll, dx, dy, dz + param = param[None].repeat_interleave(m, dim=0) + + param = param.to(keypoints.device) + + if self.lpe_head_transform_sep_scales: + scale, rotation, translation = param.split([3, 3, 3], dim=1) + else: + scale, rotation, translation = param.split([1, 3, 3], dim=1) + + params = [scale, rotation, translation] + params = [p.clone().requires_grad_() for p in params] + opt = optim.LBFGS(params) + + def closure(): + opt.zero_grad() + + transform_matrix = get_similarity_transform_matrix(*params) + pred_aligned_keypoints = keypoints @ transform_matrix.transpose(1, 2) + + loss = ((pred_aligned_keypoints - self.aligned_keypoints)**2).mean() + loss.backward() + + return loss + + for i in range(5): + opt.step(closure) + + theta = get_similarity_transform_matrix(*params).detach().float() + + source_theta, target_theta = theta.split(m//2) + + return source_theta, target_theta + + def forward_image(self, image, normalize = False, delta_yaw = None, delta_pitch = None): + scale, rotation, translation = self.net_head(image)[0] + pred_rotation = rotation.clone() + + if normalize: + rotation[:, [0, 1]] = 0.0 # zero rotations + translation = torch.zeros_like(translation) # and translations + + if 
delta_yaw is not None: + rotation[:, 0] = rotation[:, 0].clamp(-math.pi/2, math.pi) + delta_yaw + + if delta_pitch is not None: + rotation[:, 1] = rotation[:, 1].clamp(-math.pi/2, math.pi) + delta_pitch + + theta = get_similarity_transform_matrix(scale, rotation, translation) + + # Align input images using theta + eye_vector = torch.zeros(theta.shape[0], 1, 4) + eye_vector[:, :, 3] = 1 + eye_vector = eye_vector.type(theta.type()).to(theta.device) + + theta_ = torch.cat([theta, eye_vector], dim=1) + inv_theta_2d = theta_.float().inverse()[:, :, [0, 1, 3]][:, [0, 1, 3]] # leave only rows and cols corresponding to 2d transform + + # Perform 2x zoom-in compared to default theta + scale_t = torch.zeros_like(inv_theta_2d) + scale_t[:, [0, 1], [0, 1]] = 0.5 + scale_t[:, 2, 2] = 1 + + inv_theta_2d = torch.bmm(inv_theta_2d, scale_t)[:, :2] + + align_warp = self.identity_grid.repeat_interleave(image.shape[0], dim=0) + align_warp = align_warp.bmm(inv_theta_2d.transpose(1, 2)).view(image.shape[0], self.grid_size, self.grid_size, 2) + + image_align = F.grid_sample(image.float(), align_warp.float()) + + pose_embed = self.net_face(image_align)[0] + + return pose_embed, scale, rotation, translation, pred_rotation, image_align + + +class ResNetWrapper(nn.Module): + def __init__(self, + lpe_output_channels, + lpe_final_pooling_type, + lpe_output_size, + image_size, + backbone, + lpe_head_transform_sep_scales, + + pred_expression=False, + pred_head_pose=False, + dropout=0.0, + # custom_w=False + ): + super(ResNetWrapper, self).__init__() + self.pred_expression = pred_expression + self.pred_head_pose = pred_head_pose + + self.lpe_output_channels =lpe_output_channels + self.lpe_final_pooling_type = lpe_final_pooling_type + self.lpe_output_size = lpe_output_size + self.image_size = image_size + self.backbone = backbone + self.lpe_head_transform_sep_scales = lpe_head_transform_sep_scales + self.pred_expression = pred_expression + # self.custom_w = custom_w + self.dropout = dropout + + + self.net = getattr(models, backbone)(pretrained=True) + # import pdb;pdb.set_trace() + + expansion = 1 if (backbone == 'resnet18' or backbone == 'resnet34') else 4 + num_outputs = lpe_output_channels + + # self.custom_w = custom_w + + self.net.fc = nn.Conv2d( + in_channels=512 * expansion, + out_channels=num_outputs, + kernel_size=1, + bias=False) + + self.drop = nn.Dropout(p=dropout) + + if self.pred_expression: + if lpe_final_pooling_type == 'avg': + self.pose_avgpool = nn.AdaptiveAvgPool2d(lpe_output_size) + self.pose_head = nn.Linear(num_outputs * lpe_output_size**2, num_outputs, bias=False) + + elif lpe_final_pooling_type == 'transformer': + num_inputs = (image_size // 2**5)**2 + self.pose_head = nn.Sequential( + utils.TransformerHead(num_inputs, num_outputs, depth=3, heads=8, dim_head=64, mlp_dim=1024, dropout=0.1, emb_dropout=0.1), + nn.LayerNorm(num_outputs), + nn.Linear(num_outputs, num_outputs, bias=False)) + + if self.pred_head_pose: + self.theta_avgpool = nn.AdaptiveAvgPool2d(1) + + if lpe_head_transform_sep_scales: + num_params = 9 + else: + num_params = 7 + + self.param_head = nn.Linear(num_outputs, num_params) + self.param_head.weight.data.zero_() + + if lpe_head_transform_sep_scales: + # scale_x, scale_y, scale_z, yaw, pitch, roll, dx, dy, dz + self.param_head.bias.data.copy_(torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=torch.float)) + else: + # scale, yaw, pitch, roll, dx, dy, dz + self.param_head.bias.data.copy_(torch.tensor([1, 0, 0, 0, 0, 0, 0], dtype=torch.float)) + + self.register_buffer('mean', 
torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]) + self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]) + + def _forward_impl(self, x): + + x = self.net.conv1(x) + x = self.net.bn1(x) + x = self.net.relu(x) + x = self.net.maxpool(x) #if not self.custom_w else x + + x = self.net.layer1(x) + x = self.net.layer2(x) + x = self.net.layer3(x) + x = self.net.layer4(x) + + x = self.net.fc(x) + + x = self.drop(x) + + return x + + def forward(self, x): + # if self.custom_w: + # x = FF.rgb_to_grayscale(x) + # else: + # x = (x - self.mean) / self.std + + x = (x - self.mean) / self.std + x = self._forward_impl(x) + + outputs = [] + + if self.pred_expression: + if hasattr(self, 'pose_avgpool'): + pose_embed = self.pose_avgpool(x) + pose_embed = torch.flatten(pose_embed, 1) + else: + pose_embed = x + + pose_embed = self.pose_head(pose_embed) + + outputs += [pose_embed] + + if self.pred_head_pose: + param = self.theta_avgpool(x) + param = torch.flatten(param, 1) + param = self.param_head(param) + + if param.shape[1] == 7: + scale, rotation, translation = param.split([1, 3, 3], dim=1) + elif param.shape[1] == 9: + scale, rotation, translation = param.split([3, 3, 3], dim=1) + else: + raise + + outputs += [(scale, rotation, translation)] + + return outputs + + +def get_similarity_transform_matrix(scale, rotation, translation): + eye_matrix = torch.eye(4)[None].repeat_interleave(scale.shape[0], dim=0).type(scale.type()).to(scale.device) + + # Scale transform + S = eye_matrix.clone() + + if scale.shape[1] == 3: + S[:, 0, 0] = scale[:, 0] + S[:, 1, 1] = scale[:, 1] + S[:, 2, 2] = scale[:, 2] + else: + S[:, 0, 0] = scale[:, 0] + S[:, 1, 1] = scale[:, 0] + S[:, 2, 2] = scale[:, 0] + + # Rotation transform + R = eye_matrix.clone() + + rotation = rotation.clamp(-math.pi/2, math.pi) + + yaw, pitch, roll = rotation.split(1, dim=1) + yaw, pitch, roll = yaw[:, 0], pitch[:, 0], roll[:, 0] # squeeze angles + yaw_cos = yaw.cos() + yaw_sin = yaw.sin() + pitch_cos = pitch.cos() + pitch_sin = pitch.sin() + roll_cos = roll.cos() + roll_sin = roll.sin() + + R[:, 0, 0] = yaw_cos * pitch_cos + R[:, 0, 1] = yaw_cos * pitch_sin * roll_sin - yaw_sin * roll_cos + R[:, 0, 2] = yaw_cos * pitch_sin * roll_cos + yaw_sin * roll_sin + + R[:, 1, 0] = yaw_sin * pitch_cos + R[:, 1, 1] = yaw_sin * pitch_sin * roll_sin + yaw_cos * roll_cos + R[:, 1, 2] = yaw_sin * pitch_sin * roll_cos - yaw_cos * roll_sin + + R[:, 2, 0] = -pitch_sin + R[:, 2, 1] = pitch_cos * roll_sin + R[:, 2, 2] = pitch_cos * roll_cos + + # Translation transform + T = eye_matrix.clone() + + T[:, 0, 3] = translation[:, 0] + T[:, 1, 3] = translation[:, 1] + T[:, 2, 3] = translation[:, 2] + + theta = (S @ R @ T)[:, :3] + + return theta \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3634a66886572518cf83fafddacf3ec260ed540d --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder.py @@ -0,0 +1,378 @@ +import torch +from torch import nn +from torch import optim +from torch.cuda import amp +import torch.nn.functional as F +import math +import numpy as np +import itertools +import copy +from torch.cuda import amp +from scipy import linalg +from . 
import utils +from .utils import ProjectorConv, ProjectorNorm, ProjectorNormLinear, assign_adaptive_conv_params, \ + assign_adaptive_norm_params, Upsample_sg2 +from dataclasses import dataclass +# from .sg3_generator import Generator + + +class Generator(nn.Module): + + @dataclass + class Config: + eps : float + image_size : int + gen_embed_size : int + gen_adaptive_kernel : bool + gen_adaptive_conv_type: str + gen_latent_texture_size: int + in_channels: int + gen_num_channels: int + dec_max_channels: int + gen_use_adanorm: bool + gen_activation_type: str + gen_use_adaconv: bool + dec_channel_mult: float + dec_num_blocks: int + dec_up_block_type: str + dec_pred_seg: bool + dec_seg_channel_mult: float + # num_gpus: int + norm_layer_type: str + bigger: bool = False + vol_render: bool = False + im_dec_num_lrs_per_resolution: int = 1 + im_dec_ch_div_factor: float = 2.0 + emb_v_exp: bool = False + dec_use_sg3_img_dec: bool = False + no_detach_frec: int = 10 + dec_key_emb: str = 'orig' + zero_to_one: bool = False + + def __init__(self, cfg:Config): + super(Generator, self).__init__() + self.cfg = cfg + self.adaptive_conv_type = self.cfg.gen_adaptive_conv_type + num_blocks = self.cfg.dec_num_blocks + num_up_blocks = int(math.log(self.cfg.image_size // self.cfg.gen_latent_texture_size, 2)) + self.in_channels = self.cfg.in_channels + out_channels = min(int(self.cfg.gen_num_channels * self.cfg.dec_channel_mult * 2**num_up_blocks), self.cfg.dec_max_channels) + # print(num_up_blocks, out_channels) + self.gen_max_channels = self.cfg.dec_max_channels + self.norm_layer_type = self.cfg.norm_layer_type + norm_layer_type = self.cfg.norm_layer_type + + self.final_activation_en = cfg.get('final_activation', False) + if self.final_activation_en and self.cfg.zero_to_one: + self.final_activation = nn.Sigmoid() + print('Using Sigmoid') + else: + self.final_activation = nn.Identity() if self.cfg.zero_to_one else nn.Tanh() + + if self.cfg.vol_render: + layers = [] + else: + layers = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + bias=False) + ] + + for i in range(num_blocks): + layers += [ + utils.blocks['res']( + in_channels=out_channels, + out_channels=out_channels, + norm_layer_type=norm_layer_type, + activation_type=self.cfg.gen_activation_type, + conv_layer_type=('ada_' if self.cfg.gen_use_adaconv else '') + 'conv')] + + + self.res_decoder = nn.Sequential(*layers) + + if self.cfg.dec_use_sg3_img_dec: + self.img_decoder = Generator( + 512, + 512, + 512, + 3, + num_layers=14, # Total number of layers, excluding Fourier features and ToRGB. # NOTE Original 6 + num_critical=2, # Number of critically sampled layers at the end. + first_cutoff=10.079, # Cutoff frequency of the first layer (f_{c,0}). # NOTE Original 2.0 + first_stopband=17.959, # Minimum stopband of the first layer (f_{t,0}). # NOTE Original 2**3.1 + last_stopband_rel=2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. + margin_size=16, # Number of additional pixels outside the image. + num_fp16_res=False, # Use FP16 for the N highest resolutions. 
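+ # NOTE (added for clarity): num_fp16_res is a count of resolutions, so passing False here behaves like 0, i.e. FP16 stays disabled at every resolution.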
+ ) + else: + self.img_decoder = ImageDecoder( + self.cfg.image_size, + self.cfg.gen_latent_texture_size, + self.cfg.gen_use_adanorm, + self.cfg.gen_num_channels, + self.cfg.dec_up_block_type, + self.cfg.gen_activation_type, + self.cfg.gen_use_adaconv, + self.cfg.dec_pred_seg, + self.cfg.dec_seg_channel_mult, + out_channels, + norm_layer_type=norm_layer_type, + bigger=self.cfg.bigger, + im_dec_num_lrs_per_resolution=self.cfg.im_dec_num_lrs_per_resolution, + im_dec_ch_div_factor=self.cfg.im_dec_ch_div_factor + ) + + + self.gen_use_adanorm = self.cfg.gen_use_adanorm + if self.cfg.gen_use_adanorm: + self.projector = ProjectorNormLinear(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_embed_size=self.cfg.gen_embed_size, + gen_max_channels=self.gen_max_channels, emb_v_exp=self.cfg.emb_v_exp, no_detach_frec=self.cfg.no_detach_frec, key_emb=self.cfg.dec_key_emb) + else: + self.projector = ProjectorNorm(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_embed_size=self.cfg.gen_embed_size, + gen_max_channels=self.gen_max_channels) + + # print(sum(p.numel() for p in self.res_decoder.parameters() if p.requires_grad), sum(p.numel() for p in self.img_decoder.parameters() if p.requires_grad)) + if self.cfg.gen_use_adaconv: + self.projector_conv = ProjectorConv(net_or_nets=[self.res_decoder, self.img_decoder], eps=self.cfg.eps, + gen_adaptive_kernel=self.cfg.gen_adaptive_kernel, + gen_max_channels=self.gen_max_channels) + + def forward(self, data_dict, embed_dict, feat_2d, input_flip_feat=False, annealing_alpha=0.0, embed=None, stage_two=False, iteration=0): + if self.gen_use_adanorm: + # b, c, es, _ = data_dict['ada_v'].shape + # params_norm = self.projector(data_dict['ada_v'].view(b, c, es ** 2)) + params_norm = self.projector(embed_dict, iter=iteration) + annealing_alpha = 1 + else: + params_norm = self.projector(embed_dict, iter=iteration) + + if input_flip_feat: + # Repeat params for flipped feat + params_norm_ = [] + for param in params_norm: + if isinstance(param, tuple): + params_norm_.append(tuple(torch.cat([p] * 2) for p in param)) # materialize as a tuple; a bare generator can only be consumed once + else: + params_norm_.append(torch.cat([param] * 2)) + else: + params_norm_ = params_norm + + assign_adaptive_norm_params([self.res_decoder, self.img_decoder], params_norm_, annealing_alpha) + + if hasattr(self, 'projector_conv'): + params_conv = self.projector_conv(embed_dict) + + if input_flip_feat: + # Repeat params for flipped feat + params_conv_ = [] + for param in params_conv: + if isinstance(param, tuple): + params_conv_.append(tuple(torch.cat([p] * 2) for p in param)) # materialize as a tuple; a bare generator can only be consumed once + else: + params_conv_.append(torch.cat([param] * 2)) + else: + params_conv_ = params_conv + + assign_adaptive_conv_params([self.res_decoder, self.img_decoder], params_conv_, self.adaptive_conv_type, annealing_alpha) # pass the flip-repeated params, not the raw ones + + feat_2d = self.res_decoder(feat_2d) + img, seg, img_f = self.img_decoder(feat_2d, stage_two=stage_two) + + img = self.final_activation(img) + + # Predict conf + if hasattr(self, 'conf_decoder') and self.training and input_flip_feat: + feat, feat_flip = feat_2d.split(feat_2d.shape[0] // 2) + + conf_ms, conf_ms_flip, conf, conf_flip = self.conf_decoder(feat, feat_flip) + + for conf_ms_k, conf_ms_flip_k, conf_name in zip(conf_ms, conf_ms_flip, self.conf_ms_names): + data_dict[f'{conf_name}_ms'] = conf_ms_k + data_dict[f'{conf_name}_flip_ms'] = conf_ms_flip_k + + data_dict[conf_name] = conf_ms_k[0] + data_dict[f'{conf_name}_flip'] = conf_ms_flip_k[0] + + for conf_k, conf_flip_k, conf_name in zip(conf, conf_flip, 
self.conf_names): + data_dict[f'{conf_name}'] = conf_k + data_dict[f'{conf_name}_flip'] = conf_flip_k + + if stage_two: + return img, seg, feat_2d, img_f + else: + return img, seg, None, None + +class ImageDecoder(nn.Module): + def __init__(self, + image_size, + gen_latent_texture_size, + gen_use_adanorm, + gen_num_channels, + dec_up_block_type, + gen_activation_type, + gen_use_adaconv, + dec_pred_seg, + dec_seg_channel_mult, + shared_in_channels, + norm_layer_type, + bigger=False, + im_dec_num_lrs_per_resolution=1, + im_dec_ch_div_factor = 2 + ): + super(ImageDecoder, self).__init__() + num_up_blocks = int(math.log(image_size // gen_latent_texture_size, 2)) + out_channels = shared_in_channels + self.bigger = bigger + self.im_dec_num_lrs_per_resolution = im_dec_num_lrs_per_resolution + + layers = [] + + if self.bigger: + num_up_blocks = num_up_blocks - 1 + + for i in range(num_up_blocks): + in_channels = out_channels + # out_channels = max(out_channels // 2, gen_num_channels) + out_channels = max(int(out_channels / im_dec_ch_div_factor/32)*32, gen_num_channels) + + + if self.bigger: + out_channels = max(out_channels, 256) + k=0 + for _ in range(self.im_dec_num_lrs_per_resolution): + layers += [ + utils.blocks[dec_up_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2 if k==0 else 1, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + resize_layer_type='nearest' if k==0 else 'none'), + ] + in_channels = out_channels + k+=1 + + if self.bigger: + layers += [ + # utils.blocks[dec_up_block_type]( + # in_channels=out_channels, + # out_channels=out_channels//2, + # norm_layer_type=norm_layer_type, + # activation_type=gen_activation_type, + # conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv'), + + + + utils.blocks[dec_up_block_type]( + in_channels=out_channels, + out_channels=out_channels//2, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv'), + + utils.blocks[dec_up_block_type]( + in_channels=out_channels//2, + out_channels=out_channels//2, + stride=2, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + + resize_layer_type='nearest'), + utils.blocks[dec_up_block_type]( + in_channels=out_channels // 2, + out_channels=out_channels // 4, + norm_layer_type=norm_layer_type, + activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv'), + + ] + out_channels = out_channels // 4 + + self.dec_img_blocks = nn.Sequential(*layers) + + layers = [ + utils.norm_layers[norm_layer_type](out_channels), + utils.activations[gen_activation_type](inplace=True), + nn.Conv2d( + in_channels=out_channels, + out_channels=3, + kernel_size=1), + # nn.Sigmoid() + ] + + self.dec_img_head = nn.Sequential(*layers) + + if dec_pred_seg: + in_channels = shared_in_channels + out_channels = int(gen_num_channels * dec_seg_channel_mult * 2**num_up_blocks) + + layers = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=False)] + + for i in range(num_up_blocks): + in_channels = out_channels + out_channels = max(out_channels // 2, int(gen_num_channels * dec_seg_channel_mult)) + layers += [ + utils.blocks[dec_up_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + norm_layer_type=norm_layer_type, + 
activation_type=gen_activation_type, + conv_layer_type=('ada_' if gen_use_adaconv else '') + 'conv', + resize_layer_type='nearest')] + + # self.dec_seg_blocks = nn.Sequential(*layers) + # + # layers = [ + # # utils.norm_layers['bn'](out_channels), + # utils.norm_layers[norm_layer_type](out_channels), + # utils.activations[gen_activation_type](inplace=True), + # nn.Conv2d( + # in_channels=out_channels, + # out_channels=1, + # kernel_size=(1,1)), + # nn.Sigmoid()] + # + # self.dec_seg_head = nn.Sequential(*layers) + + def forward(self, feat, stage_two=False): + img_feat = self.dec_img_blocks(feat) + img = self.dec_img_head(img_feat.float()) + + seg = None + if hasattr(self, 'dec_seg_blocks'): + seg_feat = self.dec_seg_blocks(feat) + seg = self.dec_seg_head(seg_feat.float()) + + # import pdb; pdb.set_trace() + + if stage_two: + return img, None, img_feat + else: + return img, None, None + +def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + return img + + + + + + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder_spade.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder_spade.py new file mode 100644 index 0000000000000000000000000000000000000000..944fdeed02d0aab902a36a1aa412086faf332b75 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_decoder_spade.py @@ -0,0 +1,90 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path +import math + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.util import * + +class SPADEGenerator(nn.Module): + def __init__(self, cfg): + super(SPADEGenerator, self).__init__() + + # settings for image_generator + self.in_channels = cfg.in_channels + self.proj_channels = cfg.proj_channels # channel of projected feature volume + self.flag_estimate_occlusion_map = cfg.flag_estimate_occlusion_map + self.final_activation = cfg.final_activation + self.zero_to_one = cfg.zero_to_one + + self.norm_G = 'spadespectralinstance' + self.label_num_channels = self.proj_channels + self.out_channels = 64 # output channel of final SPADEResnetBlock + + ### generator + self.init_image_generator() + + def init_image_generator(self): + # Projection layers + self.projection = nn.Sequential( + SameBlock2d(self.in_channels, self.proj_channels, kernel_size=(3, 3), padding=(1, 1), lrelu=True), + nn.Conv2d(self.proj_channels, self.proj_channels, kernel_size=1, stride=1) + ) + + self.fc = nn.Conv2d(self.proj_channels, 2 * self.proj_channels, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.G_middle_1 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.G_middle_2 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.G_middle_3 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.G_middle_4 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.G_middle_5 = SPADEResnetBlock(2 * self.proj_channels, 2 * self.proj_channels, self.norm_G, self.label_num_channels) + self.up_0 = SPADEResnetBlock(2 * self.proj_channels, self.proj_channels, self.norm_G, self.label_num_channels) + self.up_1 = 
SPADEResnetBlock(self.proj_channels, self.out_channels, self.norm_G, self.label_num_channels) + self.up = nn.Upsample(scale_factor=2) + + self.conv_img = nn.Sequential( + nn.Conv2d(self.out_channels, 3 * (2 * 2), kernel_size=3, padding=1), + nn.PixelShuffle(upscale_factor=2) + ) + + if self.final_activation: + self.final_activation_fn = nn.Sigmoid() if self.zero_to_one else nn.Tanh() + else: + self.final_activation_fn = nn.Identity() + + if self.flag_estimate_occlusion_map: + self.occlusion = nn.Conv2d(self.in_channels, 1, kernel_size=7, padding=3) + + def image_generation(self, warping_feature_volume): + seg = self.projection(warping_feature_volume) # Bx256x64x64 + + if self.flag_estimate_occlusion_map: + occlusion_map = torch.sigmoid(self.occlusion(warping_feature_volume)) # Bx1x64x64 + seg = seg * occlusion_map + + x = self.fc(seg) # Bx512x64x64 + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + + x = self.up(x) # Bx512x64x64 -> Bx512x128x128 + x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 + x = self.up(x) # Bx256x128x128 -> Bx256x256x256 + x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW + x = self.final_activation_fn(x) # Bx3xHxW + + return x + + def forward(self, data_dict, target_warp_embed_dict, aligned_target_volume): + # decoding + img = self.image_generation(aligned_target_volume) + + return img, None, None, None diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..80bcd608d8303c3a572cf68668949b370553da5c --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_encoder.py @@ -0,0 +1,47 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import models +from torch.cuda import amp +import math + +from model.head_animation.EMOP.appearance_encoder import AppearanceEncoder +from model.head_animation.EMOP.identity_embedder import IdtEmbed +from model.head_animation.EMOP.spectral_norm import apply_sp_to_nets +from model.head_animation.EMOP.utils import apply_ws_to_nets + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + +class FaceEncoder(nn.Module): + def __init__(self, cfg): + super(FaceEncoder, self).__init__() + + self.app_encoder = AppearanceEncoder(cfg.face_encoder) + self.idt_encoder = IdtEmbed(cfg.idt_embedder) + + apply_sp_to_nets(self) + apply_ws_to_nets(self) + + def forward(self, source_img): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[source_img], + use_reentrant=False) + else: + return self.manual_forward(*[source_img]) + + def manual_forward(self, source_img): + + latent_volume = self.app_encoder(source_img) + idt_embedding = self.idt_encoder(source_img) + + return latent_volume, idt_embedding + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ca6474244dd692982e5d21796292b51836e599 --- /dev/null +++ 
b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/face_generator.py @@ -0,0 +1,49 @@ +import torch +from torch import nn +from torch import optim +from torch.cuda import amp +import torch.nn.functional as F +import math +import numpy as np +import itertools +import copy +from scipy import linalg +from model.head_animation.EMOP.face_decoder import Generator +from model.head_animation.EMOP.face_decoder_spade import SPADEGenerator +from model.head_animation.EMOP.spectral_norm import apply_sp_to_nets +from model.head_animation.EMOP.utils import apply_ws_to_nets + + +class FaceGenerator(nn.Module): + def __init__(self, cfg): + super(FaceGenerator, self).__init__() + + use_spade = cfg.get('use_spade', False) + if use_spade: + self.decoder = SPADEGenerator(cfg) + else: + self.decoder = Generator(cfg) + + apply_sp_to_nets(self) + apply_ws_to_nets(self) + + def forward(self, data_dict, target_warp_embed_dict, aligned_target_volume): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[data_dict, target_warp_embed_dict, aligned_target_volume], + use_reentrant=False) + else: + return self.manual_forward(*[data_dict, target_warp_embed_dict, aligned_target_volume]) + + def manual_forward(self, data_dict, target_warp_embed_dict, aligned_target_volume): + data_dict['pred_target_img'], _, _, _ = self.decoder(data_dict, target_warp_embed_dict, aligned_target_volume) + return data_dict + + + + + + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/flow_estimator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/flow_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..89361c2ca82cb7865bb57147c25823e734391d07 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/flow_estimator.py @@ -0,0 +1,105 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +import sys +from pathlib import Path + +from model.head_animation.EMOP.warp_generator_resnet import WarpGenerator +from model.head_animation.EMOP.vpn_resblocks import VPN_ResBlocks +from model.head_animation.EMOP.unet_3d import Unet3D + +from model.head_animation.EMOP.spectral_norm import apply_sp_to_nets +from model.head_animation.EMOP.utils import apply_ws_to_nets + +class FlowEstimator(nn.Module): + def __init__(self, cfg): + super(FlowEstimator, self).__init__() + + # Operator that transforms exp_emb into an extended exp_emb (to match the idt_emb size) + self.pose_unsqueeze_nw = nn.Linear(cfg.exp_embedder.lpe_output_channels_expression, cfg.face_encoder.gen_max_channels * cfg.exp_embedder.embed_size ** 2, bias=False) + + # Operator that combines idt_emb and the extended exp_emb (the "+" node of the scheme) + self.warp_embed_head_orig_nw = nn.Conv2d( + in_channels=cfg.face_encoder.gen_max_channels, + out_channels=cfg.face_encoder.gen_max_channels, + kernel_size=(1, 1), + bias=False) ## apply one more transform after the addition; this layer was missing before + + self.src2ref = WarpGenerator(cfg.warp_generator) + self.ref2tgt = WarpGenerator(cfg.warp_generator) + + self.volume_source_nw = VPN_ResBlocks(cfg.vpn_resblocks) + + # Net that processes the volume after the first double-warping + self.volume_process_nw = Unet3D(cfg.unet3d) + + self.grid_sample = lambda inputs, grid: F.grid_sample(inputs.float(), grid.float(), padding_mode=cfg.grid_sample_padding_mode) + + self.warp_size = [cfg.face_encoder.gen_max_channels, cfg.exp_embedder.embed_size, cfg.exp_embedder.embed_size] + self.volume_size = 
[cfg.face_encoder.latent_volume_channels, cfg.face_encoder.latent_volume_depth, cfg.face_encoder.latent_volume_size, cfg.face_encoder.latent_volume_size] + + apply_sp_to_nets(self) + apply_ws_to_nets(self) + + + def forward(self, data_dict): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[data_dict], + use_reentrant=False) + else: + return self.manual_forward(*[data_dict]) + + def predict_warping(self, data_dict): + warp_source_embed = self.pose_unsqueeze_nw(data_dict['source_exp_embed']).view(-1, self.warp_size[0], self.warp_size[1], self.warp_size[2]) # Bx512x4x4 + warp_source_embed_orig = self.warp_embed_head_orig_nw((warp_source_embed + data_dict['idt_embed']) * 0.5) # Bx512x4x4 + source_warp_embed_dict = {'orig': warp_source_embed_orig.view(-1, self.warp_size[0], self.warp_size[1] ** 2), 'ada_v': data_dict['source_exp_embed']} + data_dict['source_warp'], data_dict['source_delta_xy'] = self.src2ref(source_warp_embed_dict) # predict warping field from source embedding + + warp_target_embed = self.pose_unsqueeze_nw(data_dict['target_exp_embed']).view(-1, self.warp_size[0], self.warp_size[1], self.warp_size[2]) + warp_target_embed_orig = self.warp_embed_head_orig_nw((warp_target_embed + data_dict['idt_embed']) * 0.5) + target_warp_embed_dict = {'orig': warp_target_embed_orig.view(-1, self.warp_size[0], self.warp_size[1] ** 2), 'ada_v': data_dict['target_exp_embed']} + data_dict['target_warp'], data_dict['target_delta_uv'] = self.ref2tgt(target_warp_embed_dict) # predict warping field from target embedding + + data_dict['target_warp_embed_dict'] = target_warp_embed_dict + return data_dict + + def predict_target2canonical_warping(self, data_dict): + target_warp_embed_dict = data_dict['target_warp_embed_dict'] + target2ref_warp, _ = self.src2ref(target_warp_embed_dict) # predict warping field from target embedding + return target2ref_warp + + def manual_forward(self, data_dict): + b = data_dict['source_exp_embed'].size(0) + # Predict warping from src and tgt expression embedding + data_dict = self.predict_warping(data_dict) + + # Reshape latents into 3D volume + c, d, s, s = self.volume_size + latent_volume = data_dict['latent_volume'] + latent_volume = latent_volume.view(b, c, d, s, s) + latent_volume = self.volume_source_nw(latent_volume) + + # Warp from source pose + embed_dict = {} + warp_latent_volume = self.grid_sample(self.grid_sample(latent_volume, data_dict['source_rotation_warp']), data_dict['source_warp']) + canonical_latent_volume = self.volume_process_nw(warp_latent_volume, embed_dict) + + # Warp to target pose + aligned_target_volume = self.grid_sample(self.grid_sample(canonical_latent_volume, data_dict['target_warp']), data_dict['target_rotation_warp']) + aligned_target_volume = aligned_target_volume.view(b, c * d, s, s) + + data_dict['target_volume'] = aligned_target_volume + + if 'ref_matching' in data_dict and data_dict['ref_matching'] > 0: + # import pdb; pdb.set_trace() + + target2ref_warp = self.predict_target2canonical_warping(data_dict) + _latent_volume = self.volume_source_nw(data_dict['latent_volume_target'].view(b, c, d, s, s)) + _latent_volume = self.grid_sample(self.grid_sample(_latent_volume, data_dict['target_inv_rotation_warp']), target2ref_warp) + data_dict['canonical_volume_from_tgt'] = self.volume_process_nw(_latent_volume, embed_dict) + data_dict['canonical_volume_from_src'] = canonical_latent_volume + + return data_dict \ No newline at end of file diff --git 
a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/head_pose_regressor.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/head_pose_regressor.py new file mode 100644 index 0000000000000000000000000000000000000000..469abd047656daf3352222cfad42cc8a753203cc --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/head_pose_regressor.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import models +import math + +from . import point_transforms + + + +class HeadPoseRegressor(nn.Module): + def __init__(self, model_path, size=128) -> None: + super(HeadPoseRegressor, self).__init__() + self.net = models.resnet18(num_classes=9) + self.net.load_state_dict(torch.load(model_path, map_location='cpu')) + # self.net.eval() + # for param in self.net.parameters(): + # param.requires_grad = False + self.size = size + + # @torch.no_grad() + def forward(self, x, return_srt=False): + if x.shape[2] != self.size or x.shape[3] != self.size: + x = F.interpolate(x, size=(self.size, self.size), mode='bilinear') + + scale, rotation, translation = self.net(x).split([3, 3, 3], dim=1) + thetas = point_transforms.get_transform_matrix(scale, rotation, translation) + + if return_srt: + return thetas, scale, rotation, translation + else: + return thetas \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/identity_embedder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/identity_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..1474dd5d51567850d0c06e7dc681f91b41dc150e --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/identity_embedder.py @@ -0,0 +1,96 @@ +import torch +from torch import nn +from torch import optim +import torch.nn.functional as F +from torch.cuda import amp +from torchvision import models +import itertools +# import apex +from .utils import replace_bn_to_in, replace_bn_to_gn, replace_bn_to_bcn # import all three helpers so every norm_layer_type branch below resolves +from dataclasses import dataclass + +class IdtEmbed(nn.Module): + @dataclass + class Config: + idt_backbone: str + num_source_frames: int + idt_output_size: int + idt_output_channels: int + # num_gpus: int + norm_layer_type: str + idt_image_size: int + + + def __init__(self, cfg): + super(IdtEmbed, self).__init__() + self.cfg = cfg + EXPANSION = 1 if self.cfg.idt_backbone == 'resnet18' else 4 + self.num_source_frames = self.cfg.num_source_frames # number of source imgs per identity + self.net = getattr(models, self.cfg.idt_backbone)(pretrained=True) + self.idt_image_size = self.cfg.idt_image_size + # Patch backbone according to args + self.net.avgpool = nn.AdaptiveAvgPool2d(self.cfg.idt_output_size) + self.zero_to_one = cfg.zero_to_one + + num_outputs = self.cfg.idt_output_channels + + self.net.fc = nn.Conv2d( + in_channels=512 * EXPANSION, + out_channels=num_outputs, + kernel_size=1, + bias=False) + + if self.cfg.norm_layer_type == 'bn': + pass + # if self.cfg.num_gpus > 1: + # # self.net = nn.SyncBatchNorm.convert_sync_batchnorm(self.net) + # self.net = apex.parallel.convert_syncbn_model(self.net) + elif self.cfg.norm_layer_type == 'in': + self.net = replace_bn_to_in(self.net, 'IdtEmbed') + elif self.cfg.norm_layer_type == 'gn': + self.net = replace_bn_to_gn(self.net, 'IdtEmbed') + elif self.cfg.norm_layer_type == 'bcn': + self.net = replace_bn_to_bcn(self.net, 'IdtEmbed') + else: + raise ValueError('wrong norm type') + 
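+ # The buffers below hold the ImageNet channel statistics expected by the pretrained torchvision backbone; forward_image() normalizes inputs with them before the ResNet stem.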
+ self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]) + self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]) + + def _forward_impl(self, x): + x = self.net.conv1(x) + x = self.net.bn1(x) + x = self.net.relu(x) + x = self.net.maxpool(x) + + x = self.net.layer1(x) + x = self.net.layer2(x) + x = self.net.layer3(x) + x = self.net.layer4(x) + + x = self.net.fc(x) + x = self.net.avgpool(x) + + return x + + def forward(self, source_img): + if not self.zero_to_one: + source_img = (source_img + 1)/2 + idt_embed = self.forward_image(source_img) + + return idt_embed + + def forward_image(self, source_img): + source_img = F.interpolate(source_img, size=(self.idt_image_size, self.idt_image_size), mode='bilinear') + n = self.num_source_frames + b = source_img.shape[0] // n + + inputs = (source_img - self.mean) / self.std + idt_embed_tensor = self._forward_impl(inputs) + idt_embed_tensor = idt_embed_tensor.view(b, n, *idt_embed_tensor.shape[1:]).mean(1) + + return idt_embed_tensor + + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/local_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/local_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8eeb1d64561f4339f8085f4e58e9cfd8c9593428 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/local_encoder.py @@ -0,0 +1,131 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import models +from torch.cuda import amp +import math + +# from . import GridSample +from .utils import blocks, norm_layers, activations +# from argparse import ArgumentParser + +# from . import GridSample +# from . 
import utils +# import numpy as np +# import copy +# from scipy import linalg +# import itertools +# from .utils import ProjectorConv, ProjectorNorm, assign_adaptive_conv_params,assign_adaptive_norm_params +from dataclasses import dataclass + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + +class LocalEncoder(nn.Module): + + @dataclass + class Config: + gen_upsampling_type: str + gen_downsampling_type: str + gen_input_image_size: int + gen_latent_texture_size: int + gen_latent_texture_depth: int + gen_latent_texture_channels: int + gen_num_channels: int + enc_channel_mult: float + norm_layer_type: str + gen_max_channels: int + enc_block_type: str + gen_activation_type: str + # num_gpus: int + warp_norm_grad: bool + in_channels: int = 3 + + + + def __init__(self, cfg: Config): + super(LocalEncoder, self).__init__() + + self.cfg = cfg + self.upsample_type = self.cfg.gen_upsampling_type + self.downsample_type = self.cfg.gen_downsampling_type + self.ratio = self.cfg.gen_input_image_size // self.cfg.gen_latent_texture_size + self.num_2d_blocks = int(math.log(self.ratio, 2)) + self.init_depth = self.cfg.gen_latent_texture_depth + spatial_size = self.cfg.gen_input_image_size + if self.cfg.warp_norm_grad: + self.grid_sample = GridSample(self.cfg.gen_latent_texture_size) + else: + self.grid_sample = lambda inputs, grid: F.grid_sample(inputs.float(), grid.float(), padding_mode='reflection') + + out_channels = int(self.cfg.gen_num_channels * self.cfg.enc_channel_mult) + + setattr( + self, + f'from_rgb_{spatial_size}px', + nn.Conv2d( + in_channels=self.cfg.in_channels, + out_channels=out_channels, + kernel_size=7, + padding=3, + )) + + # if self.cfg.norm_layer_type!='bn': + # norm = self.cfg.norm_layer_type + # else: + # norm = 'bn' if self.cfg.num_gpus < 2 else 'sync_bn' + norm = self.cfg.norm_layer_type + + for i in range(self.num_2d_blocks): + in_channels = out_channels + out_channels = min(out_channels * 2, self.cfg.gen_max_channels) + setattr( + self, + f'enc_{i}_block={spatial_size}px', + blocks[self.cfg.enc_block_type]( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + norm_layer_type=norm, + activation_type=self.cfg.gen_activation_type, + resize_layer_type=self.cfg.gen_downsampling_type)) + spatial_size //= 2 + + in_channels = out_channels + out_channels = self.cfg.gen_latent_texture_channels + finale_layers = [] + if self.cfg.enc_block_type == 'res': + finale_layers += [ + norm_layers[norm](in_channels), + activations[self.cfg.gen_activation_type](inplace=True)] + + finale_layers += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels * self.init_depth, + kernel_size=1)] + + + self.finale_layers = nn.Sequential(*finale_layers) + + def forward(self, source_img): + + s = source_img.shape[2] + + x = getattr(self, f'from_rgb_{s}px')(source_img) + + for i in range(self.num_2d_blocks): + x = getattr(self, f'enc_{i}_block={s}px')(x) + s //= 2 + + x = self.finale_layers(x) + + return x + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/motion_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0fac9a75aa504b4c3a669e889edd06db604f8ad3 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/motion_encoder.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F 
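+# Usage sketch (illustrative comment, not part of the original file): MotionEncoder fills a +# data_dict with rigid pose from HeadPoseRegressor (rotation warp grids, theta matrices, +# scale, translation) and with expression codes from ExpressionEmbed, roughly: +# enc = MotionEncoder(cfg) +# data_dict = enc({'source_img': src, 'target_img': tgt})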
+import os +from pathlib import Path +import torchvision + +from model.head_animation.EMOP.expression_embedder import ExpressionEmbed +from model.head_animation.EMOP.head_pose_regressor import HeadPoseRegressor + +from model.head_animation.EMOP.spectral_norm import apply_sp_to_nets +from model.head_animation.EMOP.utils import apply_ws_to_nets + +class MotionEncoder(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.use_mask_image = False # this setting is only for rigid_pose_encoder + self.zero_to_one = cfg.zero_to_one + self.expression_encoder = ExpressionEmbed(cfg.exp_embedder) + self.rigid_pose_encoder = HeadPoseRegressor(cfg.head_pose_regressor_path, size=128) + # for param in self.rigid_pose_encoder.parameters(): # we need to finetune it on masked images + # param.requires_grad = True + + grid_s = torch.linspace(-1, 1, cfg.face_encoder.latent_volume_size) + grid_z = torch.linspace(-1, 1, cfg.face_encoder.latent_volume_depth) + w, v, u = torch.meshgrid(grid_z, grid_s, grid_s) + e = torch.ones_like(u) + self.identity_grid_3d = torch.stack([u, v, w, e], dim=3).view(1, -1, 4) + + apply_sp_to_nets(self) + apply_ws_to_nets(self) + + def get_motion_latent(self, img, src): + rotation_warp, theta, rotation, scale, translation = self.predict_pose(img, src) # predict_pose returns five values + ## TODO: predict expression + # motion_latent: rotation, scale, translation, expression + raise NotImplementedError('the expression part of the motion latent is not implemented yet') + + def predict_pose(self, img, src): + b, d, s = img.shape[0], 16, 64 + + if not self.zero_to_one: + img = (img + 1) / 2 + theta, scale, rotation, translation = self.rigid_pose_encoder.forward(img, return_srt=True) + + grid = self.identity_grid_3d.repeat_interleave(b, dim=0).to(img.device) + if src: + inv_source_theta = theta.float().inverse().type(theta.type()) + rotation_warp = grid.bmm(inv_source_theta[:, :3].transpose(1, 2)).view(-1, d, s, s, 3) + else: + rotation_warp = grid.bmm(theta[:, :3].transpose(1, 2)).view(-1, d, s, s, 3) + + # theta = theta.detach() + # rotation = rotation.detach() + rotation_warp = rotation_warp.detach() + + return rotation_warp, theta, rotation, scale, translation + + def predict_expression(self, data_dict): + data_dict = self.expression_encoder(data_dict) ## align src and tgt images, then predict expression codes + return data_dict + + def forward(self, data_dict): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[data_dict], + use_reentrant=False) + else: + return self.manual_forward(*[data_dict]) + + def manual_forward(self, data_dict): + # predict pose + if self.use_mask_image and 'masked_source_img' in data_dict: + src_img = data_dict['masked_source_img'] + tgt_img = data_dict['masked_target_img'] + else: + src_img = data_dict['source_img'] + tgt_img = data_dict['target_img'] + + data_dict['source_rotation_warp'], data_dict['source_theta'], data_dict['source_rotation'], data_dict['source_scale'], data_dict['source_translation'] = self.predict_pose(src_img, src=True) + data_dict['target_rotation_warp'], data_dict['target_theta'], data_dict['target_rotation'], data_dict['target_scale'], data_dict['target_translation'] = self.predict_pose(tgt_img, src=False) + + # predict expression + data_dict = self.predict_expression(data_dict) + + ## predict the inverse transformation from the target image for canonical volume consistency supervision + if 'ref_matching' in data_dict and data_dict['ref_matching'] > 0: + data_dict['target_inv_rotation_warp'], _, _, _, _ = self.predict_pose(tgt_img, src=True) + + return data_dict + + + + + \ No 
newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/point_transforms.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/point_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b7de4822f3fdc27d3b660ab97915dd886796ed5b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/point_transforms.py @@ -0,0 +1,307 @@ +import torch +from torch import optim +import math + + + +def parse_3dmm_param(param): + """matrix pose form + param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10 + """ + + # pre-defined templates for parameter + n = param.shape[0] + if n == 62: + trans_dim, shape_dim, exp_dim = 12, 40, 10 + elif n == 72: + trans_dim, shape_dim, exp_dim = 12, 40, 20 + elif n == 141: + trans_dim, shape_dim, exp_dim = 12, 100, 29 + else: + raise Exception(f'Undefined templated param parsing rule') + + R_ = param[:trans_dim].reshape(3, -1) + R = R_[:, :3] + offset = R_[:, -1].reshape(3, 1) + alpha_shp = param[trans_dim:trans_dim + shape_dim].reshape(-1, 1) + alpha_exp = param[trans_dim + shape_dim:].reshape(-1, 1) + + return R, offset, alpha_shp, alpha_exp + +def world_to_camera(pts_world, params): + R, offset, roi_box, size = params['R'], params['offset'], params['roi_box'], params['size'] + crop_box = params['crop_box'] if 'crop_box' in params.keys() and len(params['crop_box']) else None + + if pts_world.shape[0] < R.shape[0]: + pts_camera = pts_world.repeat_interleave(R.shape[0] // pts_world.shape[0], dim=0) + + elif pts_world.shape[0] > R.shape[0]: + num_repeats = pts_world.shape[0] // R.shape[0] + + R = R.repeat_interleave(num_repeats, dim=0) + offset = offset.repeat_interleave(num_repeats, dim=0) + roi_box = roi_box.repeat_interleave(num_repeats, dim=0) + size = size.repeat_interleave(num_repeats, dim=0) + if crop_box is not None: + crop_box = crop_box.repeat_interleave(num_repeats, dim=0) + + pts_camera = pts_world.clone() + + else: + pts_camera = pts_world.clone() + + pts_camera[..., 2] += 0.5 + pts_camera *= 2e5 + + pts_camera = pts_camera @ R.transpose(1, 2) + offset.transpose(1, 2) + + pts_camera[..., 0] -= 1 + pts_camera[..., 2] -= 1 + pts_camera[..., 1] = 120 - pts_camera[..., 1] + + sx, sy, ex, ey = [chunk[..., 0] for chunk in roi_box.split(1, dim=2)] + scale_x = (ex - sx) / 120 + scale_y = (ey - sy) / 120 + scale_z = (scale_x + scale_y) / 2 + + pts_camera[..., 0] = pts_camera[..., 0] * scale_x + sx + pts_camera[..., 1] = pts_camera[..., 1] * scale_y + sy + pts_camera[..., 2] = pts_camera[..., 2] * scale_z + + pts_camera /= size + pts_camera[..., 0] -= 0.5 + pts_camera[..., 1] -= 0.5 + pts_camera[..., :2] *= 2 + + if crop_box is not None: + crop_shift_x = (crop_box[..., 0] + crop_box[..., 2]) / 2 + crop_shift_y = (crop_box[..., 1] + crop_box[..., 3]) / 2 + + pts_camera[..., 0] -= crop_shift_x + pts_camera[..., 1] -= crop_shift_y + + crop_scale_x = (crop_box[..., 2] - crop_box[..., 0]) / 2 + crop_scale_y = (crop_box[..., 3] - crop_box[..., 1]) / 2 + crop_scale_z = (crop_scale_x + crop_scale_y) / 2 + + pts_camera[..., 0] /= crop_scale_x + pts_camera[..., 1] /= crop_scale_y + pts_camera[..., 2] /= crop_scale_z + + return pts_camera + +def camera_to_world(pts_camera, params): + R, offset, roi_box, size = params['R'], params['offset'], params['roi_box'], params['size'] + crop_box = params['crop_box'] if 'crop_box' in params.keys() and len(params['crop_box']) else None + + if pts_camera.shape[0] < R.shape[0]: + pts_world = 
pts_camera.repeat_interleave(R.shape[0] // pts_camera.shape[0], dim=0) + + elif pts_camera.shape[0] > R.shape[0]: + num_repeats = pts_camera.shape[0] // R.shape[0] + + R = R.repeat_interleave(num_repeats, dim=0) + offset = offset.repeat_interleave(num_repeats, dim=0) + roi_box = roi_box.repeat_interleave(num_repeats, dim=0) + size = size.repeat_interleave(num_repeats, dim=0) + if crop_box is not None: + crop_box = crop_box.repeat_interleave(num_repeats, dim=0) + + pts_world = pts_camera.clone() + + else: + pts_world = pts_camera.clone() + + if crop_box is not None: + crop_scale_x = (crop_box[..., 2] - crop_box[..., 0]) / 2 + crop_scale_y = (crop_box[..., 3] - crop_box[..., 1]) / 2 + crop_scale_z = (crop_scale_x + crop_scale_y) / 2 + + pts_world[..., 0] *= crop_scale_x + pts_world[..., 1] *= crop_scale_y + pts_world[..., 2] *= crop_scale_z + + crop_shift_x = (crop_box[..., 0] + crop_box[..., 2]) / 2 + crop_shift_y = (crop_box[..., 1] + crop_box[..., 3]) / 2 + + pts_world[..., 0] += crop_shift_x + pts_world[..., 1] += crop_shift_y + + pts_world[..., :2] /= 2 + pts_world[..., 0] += 0.5 + pts_world[..., 1] += 0.5 + pts_world *= size + + sx, sy, ex, ey = [chunk[..., 0] for chunk in roi_box.split(1, dim=2)] + scale_x = (ex - sx) / 120 + scale_y = (ey - sy) / 120 + scale_z = (scale_x + scale_y) / 2 + + pts_world[..., 0] = (pts_world[..., 0] - sx) / scale_x + pts_world[..., 1] = (pts_world[..., 1] - sy) / scale_y + pts_world[..., 2] = pts_world[..., 2] / scale_z + + pts_world[..., 0] += 1 + pts_world[..., 2] += 1 + pts_world[..., 1] = -(pts_world[..., 1] - 120) + + pts_world = (pts_world - offset.transpose(1, 2)) @ torch.linalg.inv(R.transpose(1, 2)) + + pts_world /= 2e5 + pts_world[..., 2] -= 0.5 + + return pts_world + +############################################################################### + +def align_ffhq_with_zoom(pts_camera, params): + R, offset = params['theta'].split([2, 1], dim=2) + crop_box = params['crop_box'] if 'crop_box' in params.keys() and len(params['crop_box']) else None + + if pts_camera.shape[0] != R.shape[0]: + pts_camera = pts_camera.repeat_interleave(R.shape[0], dim=0) + else: + pts_camera = pts_camera.clone() + + pts_camera = pts_camera @ R.transpose(1, 2) + offset.transpose(1, 2) + + # Zoom into face + pts_camera *= 0.6 + + if crop_box is not None: + crop_shift_x = (crop_box[..., 0] + crop_box[..., 2]) / 2 + crop_shift_y = (crop_box[..., 1] + crop_box[..., 3]) / 2 + + pts_camera[..., 0] -= crop_shift_x + pts_camera[..., 1] -= crop_shift_y + + crop_scale_x = (crop_box[..., 2] - crop_box[..., 0]) / 2 + crop_scale_y = (crop_box[..., 3] - crop_box[..., 1]) / 2 + + pts_camera[..., 0] /= crop_scale_x + pts_camera[..., 1] /= crop_scale_y + + return pts_camera + +############################################################################### + +def get_transform_matrix(scale, rotation, translation): + b = scale.shape[0] + dtype = scale.dtype + device = scale.device + + eye_matrix = torch.eye(4, dtype=dtype, device=device)[None].repeat_interleave(b, dim=0) + + # Scale transform + S = eye_matrix.clone() + + if scale.shape[1] == 3: + S[:, 0, 0] = scale[:, 0] + S[:, 1, 1] = scale[:, 1] + S[:, 2, 2] = scale[:, 2] + else: + S[:, 0, 0] = scale[:, 0] + S[:, 1, 1] = scale[:, 0] + S[:, 2, 2] = scale[:, 0] + + # Rotation transform + R = eye_matrix.clone() + + rotation = rotation.clamp(-math.pi/2, math.pi) + + yaw, pitch, roll = rotation.split(1, dim=1) + yaw, pitch, roll = yaw[:, 0], pitch[:, 0], roll[:, 0] # squeeze angles + yaw_cos = yaw.cos() + yaw_sin = yaw.sin() + pitch_cos = 
pitch.cos() + pitch_sin = pitch.sin() + roll_cos = roll.cos() + roll_sin = roll.sin() + + R[:, 0, 0] = yaw_cos * pitch_cos + R[:, 0, 1] = yaw_cos * pitch_sin * roll_sin - yaw_sin * roll_cos + R[:, 0, 2] = yaw_cos * pitch_sin * roll_cos + yaw_sin * roll_sin + + R[:, 1, 0] = yaw_sin * pitch_cos + R[:, 1, 1] = yaw_sin * pitch_sin * roll_sin + yaw_cos * roll_cos + R[:, 1, 2] = yaw_sin * pitch_sin * roll_cos - yaw_cos * roll_sin + + R[:, 2, 0] = -pitch_sin + R[:, 2, 1] = pitch_cos * roll_sin + R[:, 2, 2] = pitch_cos * roll_cos + + # Translation transform + T = eye_matrix.clone() + + T[:, 0, 3] = translation[:, 0] + T[:, 1, 3] = translation[:, 1] + T[:, 2, 3] = translation[:, 2] + + theta = S @ R @ T + + return theta + +def estimate_transform_from_keypoints(keypoints, aligned_keypoints, dilation=True, shear=False): + b, n = keypoints.shape[:2] + device = keypoints.device + dtype = keypoints.dtype + + keypoints = keypoints.to(device) + aligned_keypoints = aligned_keypoints.to(device) + + keypoints = torch.cat([keypoints, torch.ones(b, n, 1, device=device, dtype=dtype)], dim=2) + + if not dilation and not shear: + # scale, yaw, pitch, roll, dx, dy, dz + param = torch.tensor([[1, 0, 0, 0, 0, 0, 0]], device=device, dtype=dtype) + + scale, rotation, translation = param.repeat_interleave(b, dim=0).split([1, 3, 3], dim=1) + params = [scale, rotation, translation] + + elif dilation and not shear: + # scale_x, scale_y, scale_z, yaw, pitch, roll, dx, dy, dz + param = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0]], device=device, dtype=dtype) + + scale, rotation, translation = param.repeat_interleave(b, dim=0).split([3, 3, 3], dim=1) + params = [scale, rotation, translation] + + elif dilation and shear: + # full affine matrix + theta = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], device=device, dtype=dtype) + theta = theta[None].repeat_interleave(b, dim=0) + params = [theta] + + # Solve for a given transform + params = [p.clone().requires_grad_() for p in params] + + opt = optim.LBFGS(params) + + def closure(): + opt.zero_grad() + + if not shear: + theta = get_transform_matrix(*params)[:, :3] + else: + theta = params[0] + + pred_aligned_keypoints = keypoints @ theta.transpose(1, 2) + + loss = ((pred_aligned_keypoints - aligned_keypoints)**2).mean() + loss.backward() + + return loss + + for i in range(5): + opt.step(closure) + + if not shear: + theta = get_transform_matrix(*params).detach() + else: + theta = params[0].detach() + + eye = torch.zeros(b, 4, device=device, dtype=dtype) + eye[:, 2] = 1 + + theta = torch.cat([theta, eye], dim=1) + + return theta, params \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/resblocks_3d.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/resblocks_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..19e13cfa898b6f4d47b4dc3d4228aefe99cf65ff --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/resblocks_3d.py @@ -0,0 +1,63 @@ +import torch +from torch import nn +from torch.nn import functional as F +from . 
import utils +from typing import List, Union + + +class ResBlocks3d(nn.Module): + def __init__(self, + input_channels: int, + conv_layer_type: str, + num_blocks: int, + norm_layer_type: str, + activation_type: str, + channels: Union[List, None], + ) -> None: + super(ResBlocks3d, self).__init__() + # expansion_factor = 4 if block_type == 'bottleneck' else 1 + # hidden_channels = input_channels // expansion_factor + + layers_ = [] + + if channels is None or len(channels)==0: + channels = [input_channels]*num_blocks + + assert len(channels) == num_blocks + + # if norm_layer_type != 'bn': + # norm_3d = norm_layer_type + '_3d' + # else: + # norm_3d = 'bn_3d' if num_gpus < 2 else 'sync_bn' + + norm_3d = norm_layer_type + '_3d' + + input = input_channels + + for i in range(num_blocks): + out = channels[i] + layers_.append(utils.blocks['res']( + in_channels=input, + out_channels=out, + stride=1, + norm_layer_type=norm_3d, + activation_type=activation_type, + conv_layer_type=conv_layer_type)) + input = out + + # layers_.append(layers.blocks[block_type]( + # in_channels=hidden_channels, + # out_channels=hidden_channels, + # num_layers=num_layers, + # expansion_factor=expansion_factor, + # kernel_size=3, + # stride=1, + # norm_layer_type=norm_layer_type, + # activation_type=activation_type, + # conv_layer_type='conv_3d')) + + self.net = nn.Sequential(*layers_) + + def forward(self, x): + + return self.net(x) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/spectral_norm.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e10c31944344c6baf7f5e89590ca0e533e83ffd7 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/spectral_norm.py @@ -0,0 +1,349 @@ +""" +Spectral Normalization from https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize +from typing import Any, Optional, TypeVar +from torch.nn import Module +from torch import nn +from model.head_animation.EMOP import args as args_utils + + +def apply_spectral_norm(module, name='weight', apply_to=['conv2d', ], n_power_iterations=1, eps=1e-12): + # Apply only to modules in apply_to list + module_name = module.__class__.__name__.lower() + if module_name not in apply_to: + return module + + if isinstance(module, nn.ConvTranspose2d): + dim = 1 + else: + dim = 0 + + SpectralNorm.apply(module, name, n_power_iterations, dim, eps, adaptive='adaptive' in module_name) + + return module + + + +def apply_sp_to_nets(obj): + # spn_layers = args_utils.parse_str_to_list(obj.args.spn_layers, sep=',') + # spn_nets_names = args_utils.parse_str_to_list(obj.args.spn_networks, sep=',') + + spn_layers = ['conv2d', 'conv3d', 'linear', 'conv2d_ws', 'conv3d_ws'] + spn_networks='local_encoder_nw, local_encoder_seg_nw, local_encoder_mask_nw, idt_embedder_nw, expression_embedder_nw, xy_generator_nw, uv_generator_nw, warp_embed_head_orig_nw, pose_embed_decode_nw, pose_embed_code_nw, volume_process_nw, volume_source_nw, volume_pred_nw, decoder_nw, backgroung_adding_nw, background_process_nw' + spn_networks='app_encoder, idt_encoder, expression_encoder, decoder, warp_embed_head_orig_nw, src2ref, ref2tgt, volume_source_nw, volume_process_nw' + spn_nets_names = args_utils.parse_str_to_list(spn_networks, sep=',') + + + for net_name in spn_nets_names: + try: + net = getattr(obj, net_name) + net.apply(lambda module: apply_spectral_norm(module, 
apply_to=spn_layers)) + # print(f'SN applied to {net_name}') + except Exception as e: + pass + + + +def remove_spectral_norm(module, name='weight'): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + + return module + + +class SpectralNorm: + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version: int = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + name: str + dim: int + n_power_iterations: int + eps: float + + def __init__(self, name: str = 'weight', n_power_iterations: int = 1, dim: int = 0, eps: float = 1e-12, adaptive: bool = False) -> None: + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + self.adaptive = adaptive + + def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + weight_mat = weight + + if self.adaptive: + assert self.dim == 0 + + height = weight_mat.size(1) + return weight_mat.reshape(weight_mat.shape[0], height, -1) + + else: + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor: + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. 
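As a minimal standalone sketch of the power-iteration update performed below (all names here are illustrative, not part of the module):

import torch
from torch.nn.functional import normalize

W = torch.randn(40, 20)                       # stand-in for a reshaped weight matrix
u = normalize(torch.randn(40), dim=0)         # left singular vector estimate
v = normalize(torch.randn(20), dim=0)         # right singular vector estimate

for _ in range(30):
    v = normalize(torch.mv(W.t(), u), dim=0)  # v <- normalize(W^T u)
    u = normalize(torch.mv(W, v), dim=0)      # u <- normalize(W v)

sigma = torch.dot(u, torch.mv(W, v))          # ~ largest singular value of W
print(sigma, torch.linalg.matrix_norm(W, 2))  # the two values agree closely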
+ if self.adaptive: + weight = getattr(module, 'ada_' + self.name + '_orig') + else: + weight = getattr(module, self.name + '_orig') + + weight_mat = self.reshape_weight_to_matrix(weight) + + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + if self.adaptive: + v = normalize(torch.matmul(u[None, None, :], weight_mat)[:, 0].mean(0), dim=0, eps=self.eps, out=v) + u = normalize(torch.matmul(weight_mat, v[None, :, None])[..., 0].mean(0), dim=0, eps=self.eps, out=u) + + else: + v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) + u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) + + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone(memory_format=torch.contiguous_format) + v = v.clone(memory_format=torch.contiguous_format) + + if self.adaptive: + sigma = torch.mv(torch.matmul(weight_mat, v[None, :, None])[..., 0], u) + + if len(weight.shape) == 6: + sigma = sigma[:, None, None, None, None, None] + else: + sigma = sigma[:, None, None, None, None] + + else: + sigma = torch.dot(u, torch.mv(weight_mat, v)) + + weight = weight / sigma + + return weight + + def remove(self, module: Module) -> None: + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, 'ada_' + self.name if self.adaptive else self.name, self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: float, adaptive: bool) -> 'SpectralNorm': + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError("Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + fn.adaptive = adaptive + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. 
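For reference, the stock PyTorch helper that this class mirrors can be exercised in a few lines, and the buffers registered by apply() below (`weight_orig`, `weight_u`, `weight_v`) show up the same way. This is a usage sketch against the upstream torch API, not this file's adaptive variant:

import torch
from torch import nn
from torch.nn.utils import spectral_norm

layer = spectral_norm(nn.Linear(20, 40))
print(layer.weight_orig.shape)  # torch.Size([40, 20]): the unnormalized parameter
print(layer.weight_u.shape)     # torch.Size([40]): left singular vector buffer
y = layer(torch.randn(8, 20))   # the pre-hook rescales weight by 1/sigma each call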
Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) -> None: + fn = self.fn + version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) + if version is None or version < 1: + weight_key = prefix + fn.name + if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \ + weight_key not in state_dict: + # Detect if it is the updated state dict and just missing metadata. + # This could happen if the users are crafting a state dict themselves, + # so we just pretend that this is the newest. + return + has_missing_keys = False + for suffix in ('_orig', '', '_u'): + key = weight_key + suffix + if key not in state_dict: + has_missing_keys = True + if strict: + missing_keys.append(key) + if has_missing_keys: + return + with torch.no_grad(): + weight_orig = state_dict[weight_key + '_orig'] + weight = state_dict.pop(weight_key) + sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[weight_key + '_u'] + v = fn._solve_v_and_rescale(weight_mat, u, sigma) + state_dict[weight_key + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata) -> None: + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +T_module = TypeVar('T_module', bound=Module) + +def spectral_norm(module: T_module, + name: str = 'weight', + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None) -> T_module: + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. 
If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, (torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/unet_3d.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/unet_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..3755f791d1c93ba5608fad46e5312eec10977025 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/unet_3d.py @@ -0,0 +1,303 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import models +from torch.cuda import amp +from argparse import ArgumentParser +import math +from .utils import GridSample +from . 
import utils +import numpy as np +import copy +from scipy import linalg +import itertools +from .utils import ProjectorConv, ProjectorNorm, assign_adaptive_conv_params,assign_adaptive_norm_params +from dataclasses import dataclass + +class Unet3D(nn.Module): + + @dataclass + class Config: + eps: float + gen_embed_size: int + gen_adaptive_kernel: bool + gen_use_adanorm: bool + gen_use_adaconv: bool + gen_upsampling_type: str + gen_downsampling_type: str + gen_dummy_input_size: int + gen_latent_texture_size: int + gen_latent_texture_depth: int + gen_adaptive_conv_type: str + gen_latent_texture_channels: int + gen_activation_type: str + gen_max_channels: int + warp_norm_grad: bool + warp_block_type: str + image_size: int + norm_layer_type: str + tex_pred_rgb: bool = False + tex_use_skip_resblock: bool = True + + def __init__(self, cfg: Config) -> None: + super().__init__() + self.cfg = cfg + + self.upsample_type = self.cfg.gen_upsampling_type + self.downsample_type = self.cfg.gen_downsampling_type + num_3d_blocks = int(math.log(self.cfg.gen_latent_texture_size // self.cfg.gen_dummy_input_size, 2)) + self.init_depth = self.cfg.gen_latent_texture_depth + self.adaptive_conv_type = self.cfg.gen_adaptive_conv_type + self.upsample_type = self.cfg.gen_upsampling_type + self.output_depth = self.cfg.gen_latent_texture_depth + self.gen_max_channels = self.cfg.gen_max_channels + self.norm_layer_type = self.cfg.norm_layer_type + norm_layer_type = self.cfg.norm_layer_type + + if self.cfg.warp_norm_grad: + self.grid_sample = GridSample(self.cfg.gen_latent_texture_size) + else: + self.grid_sample = lambda inputs, grid: F.grid_sample(inputs.float(), grid.float(), padding_mode='reflection') + + out_channels = self.cfg.gen_latent_texture_channels + + self.blocks_3d_down = nn.ModuleList() + + norm_3d = norm_layer_type + '_3d' + + for _ in range(num_3d_blocks): + in_channels = out_channels + out_channels = min(out_channels * 2, self.cfg.gen_max_channels) + self.blocks_3d_down += [ + utils.blocks['res']( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + norm_layer_type=norm_3d, + activation_type=self.cfg.gen_activation_type, + conv_layer_type='conv_3d')] + + self.downsample = utils.downsampling_layers[self.downsample_type + '_3d'](kernel_size=2, stride=2) + self.downsample_no_depth = utils.downsampling_layers[self.downsample_type + '_3d'](kernel_size=(1, 2, 2), + stride=(1, 2, 2)) + +######################################################################################################################### + + num_blocks = int(math.log(self.cfg.gen_latent_texture_size // self.cfg.gen_dummy_input_size, 2)) + self.num_blocks = num_blocks + out_channels = min(int(self.cfg.gen_latent_texture_channels * 2 ** num_blocks), self.cfg.gen_max_channels) + + self.input_tensor = nn.Parameter( + torch.empty(1, out_channels, self.cfg.gen_dummy_input_size, self.cfg.gen_dummy_input_size, + self.cfg.gen_dummy_input_size)) + nn.init.normal_(self.input_tensor, std=1.0) + + # Initialize net + self.blocks_3d_up = nn.ModuleList() + + if self.cfg.tex_use_skip_resblock: + self.skip_blocks_3d_up = nn.ModuleList() + + # if norm_layer_type == 'bn': + # if self.cfg.num_gpus > 1: + # norm_layer_type = 'sync_' + norm_layer_type + + # if self.cfg.gen_use_adanorm: + # norm_layer_type = 'ada_' + norm_layer_type + # elif self.cfg.num_gpus < 2: + # norm_layer_type += '_3d' + + conv_layer_type = 'conv_3d' + if self.cfg.gen_use_adaconv: + conv_layer_type = 'ada_' + conv_layer_type + + + for i in range(num_blocks - 1, -1, -1): + 
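To make the arithmetic in this constructor concrete, here is the channel schedule it produces under assumed config values (all numbers illustrative): channels double on the way down, capped at gen_max_channels, and mirror back on the way up.

import math

gen_dummy_input_size = 4          # assumed
gen_latent_texture_size = 64      # assumed
gen_latent_texture_channels = 96  # assumed
gen_max_channels = 512            # assumed

num_blocks = int(math.log(gen_latent_texture_size // gen_dummy_input_size, 2))  # 4

down, c = [], gen_latent_texture_channels
for _ in range(num_blocks):
    c = min(c * 2, gen_max_channels)  # double, capped at gen_max_channels
    down.append(c)

up = [min(int(gen_latent_texture_channels * 2 ** i), gen_max_channels)
      for i in range(num_blocks - 1, -1, -1)]

print(down)  # [192, 384, 512, 512]
print(up)    # [512, 384, 192, 96]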
in_channels = out_channels + out_channels = min(int(self.cfg.gen_latent_texture_channels * 2 ** i), self.cfg.gen_max_channels) + + self.blocks_3d_up += [ + utils.blocks['res']( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + norm_layer_type=norm_3d, + activation_type=self.cfg.gen_activation_type, + conv_layer_type=conv_layer_type)] + + if self.cfg.tex_use_skip_resblock: + self.skip_blocks_3d_up += [ + utils.blocks[self.cfg.warp_block_type]( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + # norm_layer_type='bn_3d', + norm_layer_type=norm_3d, + activation_type=self.cfg.gen_activation_type, + conv_layer_type='conv_3d')] + + self.head = nn.Sequential( + utils.norm_layers[norm_3d](out_channels), + # utils.norm_layers['bn_3d'](out_channels), + utils.activations[self.cfg.gen_activation_type](inplace=True), + nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1)) + + if self.cfg.tex_pred_rgb: + # Auxiliary blocks which predict rgb texture + num_rgb_blocks = int(math.log(self.cfg.image_size // self.cfg.gen_latent_texture_size, 2)) + self.blocks_rgb = nn.ModuleList() + + for i in range(num_rgb_blocks): + in_channels = out_channels + out_channels = out_channels // 2 + + self.blocks_rgb += [ + utils.blocks['res']( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + norm_layer_type=norm_layer_type, + activation_type=self.cfg.gen_activation_type, + conv_layer_type=conv_layer_type)] + + self.head_rgb = nn.Sequential( + # utils.norm_layers['bn_3d'](out_channels), + utils.norm_layers[norm_3d](out_channels), + utils.activations[self.cfg.gen_activation_type](inplace=True), + nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1)) + + net_or_nets = [self.blocks_3d_up] + if hasattr(self, 'blocks_rgb'): + net_or_nets += [self.blocks_rgb] + + self.projector = ProjectorNorm(net_or_nets=net_or_nets, eps=self.cfg.eps, + gen_embed_size=self.cfg.gen_embed_size, + gen_max_channels=self.gen_max_channels) + if self.cfg.gen_use_adaconv: + self.projector_conv = ProjectorConv(net_or_nets=net_or_nets, eps=self.cfg.eps, + gen_adaptive_kernel=self.cfg.gen_adaptive_kernel, + gen_max_channels=self.gen_max_channels) + + self.downsample_up = utils.downsampling_layers[self.downsample_type + '_3d'](kernel_size=(2, 1, 1), + stride=(2, 1, 1)) + + def forward(self, warped_feat_3d, embed_dict=None, align_warp=None, blend_weight=None, annealing_alpha=0.0): + + if blend_weight is not None: + b, n = blend_weight.shape[:2] + warped_feat_3d = (warped_feat_3d.view(b, n, *warped_feat_3d.shape[1:]) * blend_weight).sum(1) + + spatial_size = warped_feat_3d.shape[-1] + outputs = warped_feat_3d + feat_ms = [] + size = [self.init_depth, spatial_size, spatial_size] + + for i, block in enumerate(self.blocks_3d_down): + if i < len(self.blocks_3d_down) - 1: + # Calculate block's output size + size[1] //= 2 + size[2] //= 2 + + depth_new = min(size[0] * 2, size[1]) # depth is increasing at first, but does not exceed height and width + + if depth_new > size[0]: + depth_resize_type = self.upsample_type + elif depth_new < size[0]: + depth_resize_type = self.downsample_type + else: + depth_resize_type = 'none' + size[0] = depth_new + + if depth_resize_type == self.upsample_type: + outputs = F.interpolate(outputs, scale_factor=(2, 1, 1), mode=self.upsample_type) + + # print('before:', i, depth_resize_type, outputs.shape) + + outputs = block(outputs) + feat_ms += [outputs] + + if i < len(self.blocks_3d_down) - 1: + if depth_resize_type == 
self.downsample_type: + outputs = self.downsample(outputs) + else: + # print(i) + outputs = self.downsample_no_depth(outputs) + + # print(i, depth_resize_type, outputs.shape) + + # import pdb; pdb.set_trace() + +############################################################################################# + net_or_nets = [self.blocks_3d_up] + if hasattr(self, 'blocks_rgb'): + net_or_nets += [self.blocks_rgb] + + + + if embed_dict is not None: + params_norm = self.projector(embed_dict) + assign_adaptive_norm_params(net_or_nets, params_norm, annealing_alpha) + + if hasattr(self, 'projector_conv'): + params_conv = self.projector_conv(embed_dict) + assign_adaptive_conv_params(net_or_nets, params_conv, self.adaptive_conv_type, annealing_alpha) + + assert len(feat_ms) == len(self.blocks_3d_up) + + feat_ms = feat_ms[::-1] # from low res to high res + + outputs = self.input_tensor.repeat_interleave(feat_ms[0].shape[0], dim=0) + + size = [outputs.shape[2], outputs.shape[3], outputs.shape[4]] + + for i, (block_3d, feat) in enumerate(zip(self.blocks_3d_up, feat_ms), 1): + size[1] *= 2 + size[2] *= 2 + + depth_new = min(self.output_depth * 2 ** (len(self.blocks_3d_up) - i), size[1]) + if depth_new > size[0]: + depth_resize_type = self.upsample_type + elif depth_new < size[0]: + depth_resize_type = self.downsample_type + else: + depth_resize_type = 'none' + + size[0] = depth_new + + # print('\nstart', i, outputs.shape) + if depth_resize_type == self.upsample_type: + outputs = F.interpolate(outputs, scale_factor=2, mode=self.upsample_type) + else: + outputs = F.interpolate(outputs, scale_factor=(1, 2, 2), mode=self.upsample_type) + + + + if hasattr(self, 'skip_blocks_3d_up'): + outputs_skip = self.skip_blocks_3d_up[i - 1](feat) + else: + outputs_skip = feat + + # print('up', i, depth_resize_type, outputs.shape, feat.shape, outputs_skip.shape) + + # print(self.num_blocks, i, depth_new, size, depth_resize_type == self.upsample_type, hasattr(self, 'skip_blocks_3d_up'), outputs.shape, outputs_skip.shape) + outputs = block_3d(outputs + outputs_skip) + + if depth_resize_type == self.downsample_type: + outputs = self.downsample_up(outputs) + + # print('after', outputs.shape) + # print('============') + + # import pdb; pdb.set_trace() + + latent_texture = self.head(outputs) + + return latent_texture + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/utils.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45cd7c2bbcde2eaf1c15d6150dc1680553f1220f --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/utils.py @@ -0,0 +1,832 @@ +import torch +from torch import nn +import torch.nn.functional as F +import math +import functools +from einops import rearrange, repeat +import itertools +import torchvision +from model.head_animation.EMOP import args as args_utils + +def replace_conv_to_ws_conv(module, conv2d=True, conv3d =True): + ''' + Recursively put desired batch norm in nn.module module. + + set module = net to start code. + ''' + # go through all attributes of module nn.module (e.g. 
network or layer) and bn to in + # for attr_str in dir(module): + prev_prev_attr = None + prev_attr = None + for indx, (attr_str, _) in enumerate(module.named_children()): + + if indx == 0: + prev_prev_attr = getattr(module, attr_str) + elif indx == 1: + prev_attr = getattr(module, attr_str) + else: + # print(type(target_attr)) + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Conv2d and conv2d and (type(prev_prev_attr) == torch.nn.GroupNorm or type(prev_attr) == torch.nn.GroupNorm): # + new_conv = Conv2d_ws(target_attr.in_channels, target_attr.out_channels, kernel_size=target_attr.kernel_size, stride = target_attr.stride, padding = target_attr.padding, dilation = target_attr.dilation, + groups=target_attr.groups, bias=True) + setattr(module, attr_str, new_conv) + + if type(target_attr) == torch.nn.Conv3d and conv3d and (type(prev_prev_attr) == AdaptiveGroupNorm or type(prev_attr) == AdaptiveGroupNorm): # + new_conv = Conv3d_ws(target_attr.in_channels, target_attr.out_channels, kernel_size=target_attr.kernel_size, stride = target_attr.stride, padding = target_attr.padding, dilation = target_attr.dilation, + groups=target_attr.groups, bias=True) + setattr(module, attr_str, new_conv) + prev_prev_attr = prev_attr + prev_attr = target_attr + + # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules() + for name, immediate_child_module in module.named_children(): + replace_conv_to_ws_conv(immediate_child_module, name) + + return module + +def apply_ws_to_nets(obj): + # ws_nets_names = args_utils.parse_str_to_list(obj.args.ws_networks, sep=',') + ws_networks='local_encoder_nw, local_encoder_seg_nw, local_encoder_mask_nw, idt_embedder_nw, expression_embedder_nw, xy_generator_nw, uv_generator_nw, warp_embed_head_orig_nw, pose_embed_decode_nw, pose_embed_code_nw, volume_process_nw, volume_source_nw, volume_pred_nw, decoder_nw, backgroung_adding_nw, background_process_nw' + ws_networks='app_encoder, idt_encoder, expression_encoder, decoder, warp_embed_head_orig_nw, src2ref, ref2tgt, volume_source_nw, volume_process_nw' + ws_nets_names = args_utils.parse_str_to_list(ws_networks, sep=',') + + for net_name in ws_nets_names: + try: + net = getattr(obj, net_name) + # import pdb; pdb.set_trace() + new_net = replace_conv_to_ws_conv(net, conv2d=True, conv3d=True) + setattr(obj, net_name, new_net) + # print(f'WS applied to {net_name}') + except Exception as e: + pass + + + +############################################################ +# Definitions for the layers # +############################################################ +class Conv2d_ws(nn.Conv2d): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d_ws, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +class Conv3d_ws(nn.Conv3d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + super(Conv3d_ws, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, 
groups, bias) + + def forward(self, x): + w = self.weight + w_mean = w.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) + w = w - w_mean + std = w.view(w.size(0), -1).std(dim=1).view(-1,1,1,1,1) + 1e-5 + w = w / std.expand_as(w) + return F.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + dilation: int = 1, + groups: int = 1, + conv_layer_type: str = 'conv', + norm_layer_type: str = 'bn', + activation_type: str = 'relu', + resize_layer_type: str = 'none', + efficient_upsampling: bool = False, # place upsampling layer before the second convolution + return_feats: bool = False, # return feats after the first convolution, + ): + """This is a base module for residual blocks""" + super(ResBlock, self).__init__() + # Initialize layers in the block + self.return_feats = return_feats + + m_bias= False + if resize_layer_type in ['nearest', 'bilinear', 'blur']: + self.upsample = lambda inputs: F.interpolate(inputs, scale_factor=stride, mode=resize_layer_type) + self.efficient_upsampling = efficient_upsampling + if resize_layer_type=='blur': + self.upsample = Upsample_sg2(kernel=[1, 3, 3, 1]) + + downsample = resize_layer_type in downsampling_layers and stride > 1 + if downsample: + downsampling_layer = downsampling_layers[resize_layer_type] + + normalize = norm_layer_type != 'none' + if normalize: + norm_layer = norm_layers[norm_layer_type] + + activation = activations[activation_type] + conv_layer = conv_layers[conv_layer_type] + + if '3d' in conv_layer_type: + num_kernel_dims = 3 + else: + num_kernel_dims = 2 + + ### Initialize the layers of the first half of the block ### + layers = [] + + if normalize: + layers += [norm_layer(in_channels)] + + layers += [ + activation(inplace=True), + conv_layer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_size,) * num_kernel_dims, + padding=padding, + dilation=dilation, + groups=groups, + bias=m_bias)] + + if normalize: + layers += [norm_layer(out_channels)] + + layers += [activation(inplace=True)] + + self.block_feats = nn.Sequential(*layers) + + layers = [ + conv_layer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(kernel_size,) * num_kernel_dims, + padding=padding, + dilation=dilation, + groups=groups, + bias=m_bias)] + + if downsample: + layers += [downsampling_layer(stride)] + + self.block = nn.Sequential(*layers) + + ### Initialize a skip connection block, if needed ### + if in_channels != out_channels or downsample: + layers = [] + + if in_channels != out_channels: + layers += [conv_layer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1,) * num_kernel_dims, + bias=m_bias)] + + if downsample: + layers += [downsampling_layer(stride)] + + self.skip = nn.Sequential(*layers) + + def forward(self, inputs): + outputs = inputs + + if hasattr(self, 'upsample') and not self.efficient_upsampling: + outputs = self.upsample(inputs) + + feats = self.block_feats(outputs) + outputs = feats + + if hasattr(self, 'upsample') and self.efficient_upsampling: + outputs = self.upsample(feats) + + outputs_main = self.block(outputs) + + outputs_skip = inputs + + if hasattr(self, 'upsample'): + outputs_skip = self.upsample(inputs) + + if hasattr(self, 'skip'): + outputs_skip = self.skip(outputs_skip) + + outputs = outputs_main + outputs_skip + + if 
self.return_feats: + return outputs, feats + else: + return outputs + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + dilation: int = 1, + groups: int = 1, + conv_layer_type: str = 'conv', + norm_layer_type: str = 'none', + activation_type: str = 'relu', + resize_layer_type: str = 'none', + efficient_upsampling: bool = False, # + return_feats: bool = False, + ): + """This is a base module for residual blocks""" + super(ConvBlock, self).__init__() + # Initialize layers in the block + self.return_feats = return_feats + m_bias = False + + if resize_layer_type in ['nearest', 'bilinear'] and stride > 1: + self.upsample = lambda inputs: F.interpolate(inputs, scale_factor=stride, mode=resize_layer_type) + + downsample = resize_layer_type in downsampling_layers and stride > 1 + if downsample: + downsampling_layer = downsampling_layers[resize_layer_type] + + normalize = norm_layer_type != 'none' + if normalize: + norm_layer = norm_layers[norm_layer_type] + + activation = activations[activation_type] + conv_layer = conv_layers[conv_layer_type] + + if '3d' in conv_layer_type: + num_kernel_dims = 3 + else: + num_kernel_dims = 2 + + ### Initialize the layers of the first half of the block ### + layers = [ + conv_layer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_size,) * num_kernel_dims, + padding=padding, + dilation=dilation, + groups=groups, + bias=m_bias)] + + if normalize: + layers += [norm_layer(out_channels)] + + layers += [activation(inplace=True)] + + self.block = nn.Sequential(*layers) + + if downsample: + self.downsample = downsampling_layer(stride) + + def assign_spade_feats(self, feats): + for m in self.modules(): + if m.__class__.__name__ == 'AdaptiveSPADE': + m.feats = feats + + def forward(self, inputs, spade_feats=None): + if spade_feats is not None: + self.assign_spade_feats(spade_feats) + + if hasattr(self, 'upsample'): + outputs = self.upsample(inputs) + else: + outputs = inputs + + feats = self.block(outputs) + + if hasattr(self, 'downsample'): + outputs = self.downsample(feats) + else: + outputs = feats + + if self.return_feats: + return outputs, feats + else: + return outputs + + +class PixelUnShuffle(nn.Module): + def __init__(self, upscale_factor): + super(PixelUnShuffle, self).__init__() + self.upscale_factor = upscale_factor + + def forward(self, inputs): + batch_size, channels, in_height, in_width = inputs.size() + + out_height = in_height // self.upscale_factor + out_width = in_width // self.upscale_factor + + input_view = inputs.contiguous().view( + batch_size, channels, out_height, self.upscale_factor, + out_width, self.upscale_factor) + + channels *= self.upscale_factor ** 2 + unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() + return unshuffle_out.view(batch_size, channels, out_height, out_width) + + def extra_repr(self): + return 'upscale_factor={}'.format(self.upscale_factor) + + +class AdaptiveConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3), + stride=1, padding=0, dilation=1, groups=1, bias=False): + super(AdaptiveConv, self).__init__() + # Set options + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + assert not bias, 'bias == True is not supported for AdaptiveConv' + self.bias = None + + self.kernel_numel = kernel_size[0] 
* kernel_size[1] + if len(kernel_size) == 3: + self.kernel_numel *= kernel_size[2] + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) + + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + self.ada_weight = None # assigned externally + + if len(kernel_size) == 2: + self.conv_func = F.conv2d + elif len(kernel_size) == 3: + self.conv_func = F.conv3d + + def forward(self, inputs): + # Cast parameters into inputs.dtype + if inputs.type() != self.ada_weight.type(): + weight = self.ada_weight.type(inputs.type()) + else: + weight = self.ada_weight + + # Conv is applied to the inputs grouped by t frames + B = weight.shape[0] + T = inputs.shape[0] // B + assert inputs.shape[0] == B * T, 'Wrong shape of weight' + + if self.kernel_numel > 1: + if weight.shape[0] == 1: + # No need to iterate through batch, can apply conv to the whole batch + outputs = self.conv_func(inputs, weight[0], None, self.stride, self.padding, self.dilation, self.groups) + + else: + outputs = [] + for b in range(B): + outputs += [self.conv_func(inputs[b * T:(b + 1) * T], weight[b], None, self.stride, self.padding, + self.dilation, self.groups)] + outputs = torch.cat(outputs, 0) + + else: + if weight.shape[0] == 1: + if len(inputs.shape) == 5: + weight = weight[..., None, None, None] + else: + weight = weight[..., None, None] + + outputs = self.conv_func(inputs, weight[0], None, self.stride, self.padding, self.dilation, self.groups) + else: + # 1x1(x1) adaptive convolution is a simple bmm + if len(weight.shape) == 6: + weight = weight[..., 0, 0, 0] + else: + weight = weight[..., 0, 0] + + outputs = torch.bmm(weight, inputs.view(B * T, inputs.shape[1], -1)).view(B, -1, *inputs.shape[2:]) + + return outputs + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != 0: + s += ', padding={padding}' + if self.dilation != 1: + s += ', dilation={dilation}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + return s.format(**self.__dict__) + + + +def replace_bn_to_in(module, name): + ''' + Recursively put desired batch norm in nn.module module. + + set module = net to start code. + ''' + # go through all attributes of module nn.module (e.g. network or layer) and bn to in + for attr_str, _ in module.named_children(): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.BatchNorm2d: + # print('replaced: ', name, attr_str) + new_bn = torch.nn.InstanceNorm2d(target_attr.num_features, target_attr.eps, + target_attr.momentum, target_attr.affine, + track_running_stats=False) + setattr(module, attr_str, new_bn) + + # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules() + for name, immediate_child_module in module.named_children(): + replace_bn_to_in(immediate_child_module, name) + + return module + + +def replace_bn_to_gn(module, name): + ''' + Recursively put desired batch norm in nn.module module. + + set module = net to start code. + ''' + # go through all attributes of module nn.module (e.g. 
network or layer) and bn to in + # for attr_str in dir(module): + for attr_str, _ in module.named_children(): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.BatchNorm2d or type(target_attr) == torch.nn.InstanceNorm2d: + new_bn = torch.nn.GroupNorm(32, target_attr.num_features, target_attr.eps, target_attr.affine) + setattr(module, attr_str, new_bn) + + # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules() + for name, immediate_child_module in module.named_children(): + replace_bn_to_gn(immediate_child_module, name) #TODO поменять на GN + + return module + +def replace_bn_to_bcn(module, name): + ''' + Recursively put desired batch norm in nn.module module. + + set module = net to start code. + ''' + # go through all attributes of module nn.module (e.g. network or layer) and bn to in + # for attr_str in dir(module): + for attr_str, _ in module.named_children(): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.BatchNorm2d or type(target_attr) == torch.nn.InstanceNorm2d: + new_bn = BCNorm(32, target_attr.num_features, target_attr.eps) + setattr(module, attr_str, new_bn) + + # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules() + for name, immediate_child_module in module.named_children(): + replace_bn_to_bcn(immediate_child_module, name) #TODO поменять на GN + + return module + +class ProjectorNorm(nn.Module): + def __init__(self, net_or_nets, + eps, + gen_embed_size, + gen_max_channels): + super(ProjectorNorm, self).__init__() + self.eps = eps + + # Matrices that perform a lowrank matrix decomposition W = U E V + self.u = nn.ParameterList() + self.v = nn.ParameterList() + + if isinstance(net_or_nets, list): + modules = itertools.chain(*[net.modules() for net in net_or_nets]) + else: + modules = net_or_nets.modules() + + for m in modules: + if m.__class__.__name__ in ['AdaptiveBatchNorm', 'AdaptiveSyncBatchNorm', 'AdaptiveInstanceNorm', 'AdaptiveGroupNorm', 'AdaptiveBCNorm'] : + self.u += [nn.Parameter(torch.empty(m.num_features, gen_max_channels))] + self.v += [nn.Parameter(torch.empty(gen_embed_size ** 2, 2))] + + nn.init.uniform_(self.u[-1], a=-math.sqrt(3 / gen_max_channels), + b=math.sqrt(3 / gen_max_channels)) + nn.init.uniform_(self.v[-1], a=-math.sqrt(3 / gen_embed_size ** 2), + b=math.sqrt(3 / gen_embed_size ** 2)) + + + def forward(self, embed_dict, iter=0): + params = [] + + for u, v in zip(self.u, self.v): + # print(u.shape, v.shape) + embed = embed_dict['orig'] + + param = u[None].matmul(embed).matmul(v[None]) + weight, bias = param.split(1, dim=2) + + params += [(weight[..., 0], bias[..., 0])] + + # import pdb; pdb.set_trace() + return params + + +class ProjectorConv(nn.Module): + def __init__(self, net_or_nets, + eps, + gen_adaptive_kernel, + gen_max_channels): + super(ProjectorConv, self).__init__() + self.eps = eps + self.adaptive_kernel = gen_adaptive_kernel + + # Matrices that perform a lowrank matrix decomposition W = U E V + self.u = nn.ParameterList() + self.v = nn.ParameterList() + self.kernel_size = [] + + if isinstance(net_or_nets, list): + modules = itertools.chain(*[net.modules() for net in net_or_nets]) + else: + modules = net_or_nets.modules() + + for m in modules: + if m.__class__.__name__ == 'AdaptiveConv': + # Assumes that adaptive conv layers have no bias + kernel_numel = m.kernel_size[0] * m.kernel_size[1] + if len(m.kernel_size) == 3: + kernel_numel *= m.kernel_size[2] + + if 
kernel_numel == 1: + self.u += [nn.Parameter(torch.empty(m.out_channels, gen_max_channels // 2))] + self.v += [nn.Parameter(torch.empty(gen_max_channels // 2, m.in_channels))] + + elif kernel_numel > 1: + self.u += [nn.Parameter(torch.empty(m.out_channels, gen_max_channels // 2))] + self.v += [nn.Parameter(torch.empty(m.in_channels, gen_max_channels // 2))] + + self.kernel_size += [m.kernel_size] + + bound = math.sqrt(3 / (gen_max_channels // 2)) + nn.init.uniform_(self.u[-1], a=-bound, b=bound) + nn.init.uniform_(self.v[-1], a=-bound, b=bound) + + def forward(self, embed_dict): + params = [] + + for u, v, kernel_size in zip(self.u, self.v, self.kernel_size): + kernel_numel = kernel_size[0] * kernel_size[1] + if len(kernel_size) == 3: + kernel_numel *= kernel_size[2] + + if kernel_numel == 1: + embed = embed_dict['fc'] + else: + if self.adaptive_kernel: + if kernel_numel == 9: + embed = embed_dict['conv2d'] + elif kernel_numel == 27: + embed = embed_dict['conv3d'] + embed = embed.view(embed.shape[0], embed.shape[1], -1, kernel_numel) + else: + embed = embed_dict['fc'][..., None] + + if kernel_numel == 1: + # AdaptiveConv with kernel size = 1 + weight = u[None].matmul(embed).matmul(v[None]) + weight = weight.view(*weight.shape, *kernel_size) # B x C_out x C_in x 1 ... + else: + # AdaptiveConv with kernel size > 1 + if self.adaptive_kernel: + kernel_numel_ = kernel_numel + kernel_size_ = kernel_size + else: + kernel_numel_ = 1 + kernel_size_ = (1,) * len(kernel_size) + + param = embed.view(*embed.shape[:2], -1) + param = u[None].matmul(param) # B x C_out x C_emb/2 + b, c_out = param.shape[:2] + param = param.view(b, c_out, -1, kernel_numel_) + param = v[None].matmul(param) # B x C_out x C_in x kernel_numel + weight = param.view(*param.shape[:3], *kernel_size_) + + params += [weight] + + return params + + +def assign_adaptive_conv_params(net_or_nets, params, adaptive_conv_type, alpha_conv=1.0): + if isinstance(net_or_nets, list): + modules = itertools.chain(*[net.modules() for net in net_or_nets]) + else: + modules = net_or_nets.modules() + + for m in modules: + m_name = m.__class__.__name__ + if m_name == 'AdaptiveConv': + attr_name = 'weight_orig' if hasattr(m, 'weight_orig') else 'weight' + weight = getattr(m, attr_name) + ada_weight = params.pop(0) + + if adaptive_conv_type == 'sum': + ada_weight = weight[None] + ada_weight * alpha_conv + elif adaptive_conv_type == 'mul': + ada_weight = weight[None] * (torch.sigmoid(ada_weight) * alpha_conv + (1 - alpha_conv)) + + setattr(m, 'ada_' + attr_name, ada_weight) + +def assign_adaptive_norm_params(net_or_nets, params, alpha_norm=1.0): + if isinstance(net_or_nets, list): + modules = itertools.chain(*[net.modules() for net in net_or_nets]) + else: + modules = net_or_nets.modules() + + for m in modules: + m_name = m.__class__.__name__ + if m_name in ['AdaptiveBatchNorm', 'AdaptiveSyncBatchNorm', 'AdaptiveInstanceNorm', 'AdaptiveGroupNorm', 'AdaptiveBCNorm']: #TODO разобраться + ada_weight, ada_bias = params.pop(0) + + m.ada_weight = m.weight[None] + ada_weight * alpha_norm + m.ada_bias = m.bias[None] + ada_bias * alpha_norm + + +class AdaptiveGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_features, eps=1e-5, affine=True): + super(AdaptiveGroupNorm, self).__init__(num_groups, num_features, eps, False) + self.num_features = num_features + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + # These tensors are assigned externally + self.ada_weight = None + self.ada_bias = None + + 
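The forward pass that follows boils down to a parameter-free GroupNorm plus a per-sample affine whose weight and bias were assigned externally by ProjectorNorm. A minimal sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

B, T, C, H, W = 2, 3, 64, 8, 8
x = torch.randn(B * T, C, H, W)  # T frames per identity, flattened into the batch
ada_weight = torch.randn(B, C)   # would come from ProjectorNorm
ada_bias = torch.randn(B, C)

y = F.group_norm(x, num_groups=32)  # normalize without affine parameters
y = y.view(B, T, C, H, W)           # broadcast the per-sample affine across T, H, W
y = y * ada_weight[:, None, :, None, None] + ada_bias[:, None, :, None, None]
y = y.view(B * T, C, H, W)          # back to the flattened layout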
def forward(self, inputs): + outputs = super(AdaptiveGroupNorm, self).forward(inputs) + B = self.ada_weight.shape[0] + T = inputs.shape[0] // B + + outputs = outputs.view(B, T, *outputs.shape[1:]) + # Broadcast weight and bias accross T and spatial size of outputs + if len(outputs.shape) == 5: + outputs = outputs * self.ada_weight[:, None, :, None, None] + self.ada_bias[:, None, :, None, None] + else: + outputs = outputs * self.ada_weight[:, None, :, None, None, None] + self.ada_bias[:, None, :, None, None, + None] + outputs = outputs.view(B * T, *outputs.shape[2:]) + # print(inputs.shape, outputs.shape) + return outputs + + def _check_input_dim(self, input): + pass + + def extra_repr(self) -> str: + return '{num_groups}, {num_features}, eps={eps}, ' \ + 'affine=True'.format(**self.__dict__) + + +class ProjectorNormLinear(nn.Module): + def __init__(self, net_or_nets, + eps, + gen_embed_size, + gen_max_channels, + emb_v_exp=False, + no_detach_frec=1, + key_emb = 'orig'): + super(ProjectorNormLinear, self).__init__() + self.eps = eps + self.emb_v_exp = emb_v_exp + # Matrices that perform a lowrank matrix decomposition W = U E V + self.u = nn.ParameterList() + self.v = nn.ParameterList() + self.no_detach_frec = no_detach_frec + + self.key_emb = key_emb + + input_n = 512 if emb_v_exp else 512*16 + self.fc = nn.Sequential( + nn.Linear(input_n, 512, bias=False), + nn.ReLU(), + nn.Linear(512, 512*2, bias=False)) + + if isinstance(net_or_nets, list): + modules = itertools.chain(*[net.modules() for net in net_or_nets]) + else: + modules = net_or_nets.modules() + + for m in modules: + if m.__class__.__name__ in ['AdaptiveBatchNorm', 'AdaptiveSyncBatchNorm', 'AdaptiveInstanceNorm', 'AdaptiveGroupNorm', 'AdaptiveBCNorm'] : + self.u += [nn.Parameter(torch.empty(m.num_features, 512))] + self.v += [nn.Parameter(torch.empty(2, 2))] + + nn.init.uniform_(self.u[-1], a=-math.sqrt(3 / 512), + b=math.sqrt(3 / 512)) + nn.init.uniform_(self.v[-1], a=-math.sqrt(3 / 2 ), + b=math.sqrt(3 / 2)) + + def forward(self, embed_dict, iter=0): + params = [] + if self.emb_v_exp: + embed = embed_dict['ada_v'].detach() + else: + embed = embed_dict[self.key_emb].view(-1, 512*16) if iter%self.no_detach_frec==0 else embed_dict[self.key_emb].view(-1, 512*16).detach() + + + embed = self.fc(embed).view(-1, 512, 2) + + for u, v in zip(self.u, self.v): + + param = u[None].matmul(embed).matmul(v[None]) + weight, bias = param.split(1, dim=2) + + params += [(weight[..., 0], bias[..., 0])] + + return params + + +class Upsample_sg2(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class GridSample(nn.Module): + def __init__(self, size): + super(GridSample, self).__init__() + self.size = size + self.register_backward_hook(scale_warp_grad_norm) + + def forward(self, inputs, grid, padding_mode='reflection'): + return F.grid_sample(inputs, grid, padding_mode=padding_mode) + + +# Supported blocks +blocks = { + 'res': ResBlock, + 'conv': ConvBlock +} + +# Supported downsampling layers +downsampling_layers = { + 'avgpool': nn.AvgPool2d, + 'maxpool': nn.MaxPool2d, + 'avgpool_3d': nn.AvgPool3d, + 'maxpool_3d': nn.MaxPool3d, + 'pixelunshuffle': PixelUnShuffle} + +# Supported 
+# Supported normalization layers
+norm_layers = {
+    'in': lambda num_features, affine=True: nn.InstanceNorm2d(num_features=num_features, affine=affine),
+    'bn': lambda num_features: nn.BatchNorm2d(num_features=num_features, momentum=MOMENTUM),
+    'bn_3d': lambda num_features: nn.BatchNorm3d(num_features=num_features, momentum=MOMENTUM),
+    'in_3d': lambda num_features, affine=True: nn.InstanceNorm3d(num_features=num_features, affine=affine),
+    'sync_bn': lambda num_features: nn.SyncBatchNorm(num_features=num_features, momentum=MOMENTUM),
+    'ada_in': lambda num_features, affine=True: AdaptiveInstanceNorm(num_features=num_features, affine=affine),
+    'ada_bn': lambda num_features: AdaptiveBatchNorm(num_features=num_features, momentum=MOMENTUM),
+    'ada_sync_bn': lambda num_features: AdaptiveSyncBatchNorm(num_features=num_features, momentum=MOMENTUM),
+    'gn': lambda num_features, affine=True: nn.GroupNorm(num_groups=32, num_channels=num_features, affine=affine),
+    'bcn': lambda num_features, affine=True: BCNorm(num_channels=num_features, num_groups=32, estimate=True),
+    'bcn_3d': lambda num_features, affine=True: BCNorm(num_channels=num_features, num_groups=32, estimate=True),
+    'gn_24': lambda num_features, affine=True: nn.GroupNorm(num_groups=24, num_channels=num_features, affine=affine),
+    'gn_3d': lambda num_features, affine=True: nn.GroupNorm(num_groups=32, num_channels=num_features, affine=affine),
+    'ada_gn': lambda num_features, affine=True: AdaptiveGroupNorm(num_groups=32, num_features=num_features, affine=affine),
+    # 'ada_gn': lambda num_features, affine=True: AdaptiveInstanceNorm(num_features=num_features, affine=affine),
+    # 'ada_bcn': lambda num_features, affine=True: AdaptiveGroupNorm(num_groups=32, num_features=num_features, affine=affine),
+    'ada_bcn': lambda num_features, affine=True: AdaptiveBCNorm(num_groups=32, num_features=num_features, estimate=True)
+}
+
+# Supported activations
+activations = {
+    'relu': nn.ReLU,
+    # 'relu': functools.partial(nn.LeakyReLU, negative_slope=0.04),
+    'lrelu': functools.partial(nn.LeakyReLU, negative_slope=0.2)}
+
+# Supported conv layers
+conv_layers = {
+    'conv': nn.Conv2d,
+    # 'conv': Conv2d_ws,
+    'conv_3d': nn.Conv3d,
+    # 'conv_3d': Conv3d_ws,
+    'ada_conv': AdaptiveConv,
+    'ada_conv_3d': AdaptiveConv}
\ No newline at end of file
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/vpn_resblocks.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/vpn_resblocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..5daedc21f5cbbd2e1fdbb14afbb6c66ef7f48bed
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/vpn_resblocks.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torchvision import models
+from torch.cuda import amp
+from argparse import ArgumentParser
+import math
+from .utils import GridSample
+# from . 
import utils +import numpy as np +import copy +from scipy import linalg +import itertools +from .utils import ProjectorConv, ProjectorNorm, assign_adaptive_conv_params,assign_adaptive_norm_params +from .resblocks_3d import ResBlocks3d +from typing import List, Union + +from dataclasses import dataclass, field + + +class VPN_ResBlocks(nn.Module): + @dataclass + class Config: + num_gpus: int + norm_layer_type: str + input_channels: int + num_blocks: int + activation_type: str + conv_layer_type: str = 'conv_3d' + channels: list = field(default_factory=list) + + + def __init__(self, cfg: Config): + super(VPN_ResBlocks, self).__init__() + self.cfg = cfg + self.net = ResBlocks3d( + # num_gpus=self.cfg.num_gpus, + norm_layer_type=self.cfg.norm_layer_type, + input_channels=self.cfg.input_channels, + num_blocks=self.cfg.num_blocks, + activation_type=self.cfg.activation_type, + conv_layer_type='conv_3d', + channels=self.cfg.channels, + ) + + def forward(self, x): + + return self.net(x) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/warp_generator_resnet.py b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/warp_generator_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6fb5eff0d4e6aeb479446c332bf95ff13b028d --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/EMOP/warp_generator_resnet.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +import torch.nn.functional as F +import math +from torch.cuda import amp +from . import utils +import itertools +from .utils import ProjectorConv, ProjectorNorm, assign_adaptive_conv_params,assign_adaptive_norm_params +from dataclasses import dataclass + +class WarpGenerator(nn.Module): + + @dataclass + class Config: + eps: int + # num_gpus: int + gen_adaptive_conv_type: str + gen_activation_type: str + gen_upsampling_type: str + gen_downsampling_type: str + gen_dummy_input_size: int + gen_latent_texture_depth: int + gen_latent_texture_size: int + gen_max_channels: int + gen_num_channels: int + gen_use_adaconv: bool + gen_adaptive_kernel: bool + gen_embed_size: int + warp_output_size: int + warp_channel_mult: int + warp_block_type: str + norm_layer_type: str + input_channels: int + pred_blend_weight: bool = False + + + + def __init__(self, cfg): + super(WarpGenerator, self).__init__() + + self.cfg = cfg + self.adaptive_conv_type = self.cfg.gen_adaptive_conv_type + self.gen_activation_type = self.cfg.gen_activation_type + self.upsample_type = self.cfg.gen_upsampling_type + self.downsample_type = self.cfg.gen_downsampling_type + self.input_size = self.cfg.gen_dummy_input_size + self.output_depth = self.cfg.gen_latent_texture_depth + self.output_size = self.cfg.gen_latent_texture_size + # self.pred_blend_weight = pred_blend_weight + self.warp_output_size = self.cfg.warp_output_size + self.gen_num_channels = self.cfg.gen_num_channels + self.warp_channel_mult = self.cfg.warp_channel_mult + self.gen_max_channels = self.cfg.gen_max_channels + self.warp_block_type = self.cfg.warp_block_type + self.norm_layer_type = self.cfg.norm_layer_type + + + num_blocks = int(math.log(self.warp_output_size // self.input_size, 2)) + self.num_depth_resize_blocks = int(math.log(self.output_size // self.input_size, 2)) + + out_channels = (min(int(self.gen_num_channels * self.warp_channel_mult * 2**num_blocks), self.gen_max_channels))//32*32 + + + # if self.norm_layer_type=='bn': + # if self.cfg.num_gpus > 1: + # norm_layer_type = 'sync_' + 
self.norm_layer_type
+
+        norm_layer_type = 'ada_' + self.norm_layer_type
+
+        # Initialize net
+        self.first_conv = nn.Conv2d(self.cfg.input_channels, out_channels * self.input_size, 1, bias=False)
+
+        self.blocks_3d = nn.ModuleList()
+
+        for i in range(num_blocks - 1, -1, -1):
+            in_channels = out_channels
+            out_channels = (min(int(self.gen_num_channels * self.warp_channel_mult * 2**i), self.gen_max_channels)) // 32 * 32
+
+            self.blocks_3d += [
+                utils.blocks[self.warp_block_type](
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    stride=1,
+                    norm_layer_type=norm_layer_type,
+                    conv_layer_type=('ada_' if self.cfg.gen_use_adaconv else '') + 'conv_3d',
+                    activation_type=self.gen_activation_type)]
+
+        if self.warp_block_type == 'res':
+            # if self.norm_layer_type != 'bn':
+            #     norm_3d = self.norm_layer_type + '_3d'
+            # else:
+            #     norm_3d = 'bn_3d' if self.cfg.num_gpus < 2 else 'sync_bn'
+
+            norm_3d = self.norm_layer_type + '_3d'
+
+            self.pre_head = nn.Sequential(
+                # utils.norm_layers['bn_3d'](out_channels),
+                utils.norm_layers[norm_3d](out_channels),
+                utils.activations[self.gen_activation_type](inplace=True))
+
+        self.head = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv3d(
+                    in_channels=out_channels,
+                    out_channels=3,
+                    kernel_size=3,
+                    padding=1),
+                nn.Tanh())])
+
+        self.projector = ProjectorNorm(net_or_nets=self.blocks_3d, eps=self.cfg.eps, gen_embed_size=self.cfg.gen_embed_size,
+                                       gen_max_channels=self.gen_max_channels)
+        if self.cfg.gen_use_adaconv:
+            self.projector_conv = ProjectorConv(net_or_nets=self.blocks_3d, eps=self.cfg.eps,
+                                                gen_adaptive_kernel=self.cfg.gen_adaptive_kernel,
+                                                gen_max_channels=self.gen_max_channels)
+
+        self.downsample = utils.downsampling_layers[f'{self.downsample_type}_3d'](kernel_size=(2, 1, 1), stride=(2, 1, 1))
+
+        # Create a meshgrid, which is used for warping calculation from deltas
+        grid_s = torch.linspace(-1, 1, self.warp_output_size)
+        grid_z = torch.linspace(-1, 1, self.output_depth)
+        w, v, u = torch.meshgrid(grid_z, grid_s, grid_s)
+        self.register_buffer('identity_grid', torch.stack([u, v, w], 0)[None])
+
+        # import pdb; pdb.set_trace()
+
+    def forward(self, embed_dict, annealing_alpha=0.0):  # TODO: remove annealing_alpha altogether
+
+        # import pdb; pdb.set_trace()
+
+        params_norm = self.projector(embed_dict)
+        # import pdb; pdb.set_trace()
+
+        assign_adaptive_norm_params(self.blocks_3d, params_norm)
+
+        if hasattr(self, 'projector_conv'):
+            params_conv = self.projector_conv(embed_dict)
+            assign_adaptive_conv_params(self.blocks_3d, params_conv, self.adaptive_conv_type, annealing_alpha)
+
+        b = embed_dict['orig'].shape[0]
+        inputs = embed_dict['orig'].view(b, -1, self.input_size, self.input_size)
+
+        size = [self.input_size, self.input_size, self.input_size]
+        outputs = self.first_conv(inputs).view(b, -1, *size)
+
+        for i, block in enumerate(self.blocks_3d, 1):
+            size[1] *= 2
+            size[2] *= 2
+
+            # Calc the new depth and whether it is upsampled or downsampled
+            if i < self.num_depth_resize_blocks:
+                depth_new = min(self.output_depth * 2**(self.num_depth_resize_blocks - i), size[1])
+            else:
+                depth_new = self.output_depth
+
+            if depth_new > size[0]:
+                depth_resize_type = self.upsample_type
+            elif depth_new < size[0]:
+                depth_resize_type = self.downsample_type
+            else:
+                depth_resize_type = 'none'
+
+            size[0] = depth_new
+
+            if depth_resize_type == self.upsample_type:
+                outputs = F.interpolate(outputs, scale_factor=2, mode=self.upsample_type)
+            else:
+                outputs = F.interpolate(outputs, scale_factor=(1, 2, 2), mode=self.upsample_type)
+
+            outputs = 
block(outputs) + + if depth_resize_type == self.downsample_type: + outputs = self.downsample(outputs) + + # import pdb; pdb.set_trace() + + if hasattr(self, 'pre_head'): + outputs = self.pre_head(outputs.float()) + + + + deltas = self.head[0](outputs) + + warp = (self.identity_grid + deltas).permute(0, 2, 3, 4, 1) + + results = [warp, deltas] + + # import pdb; pdb.set_trace() + + # if self.pred_blend_weight: + # results += [self.head[1](outputs)] + + return results + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__init__.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1409282efe5fe81acee2f041a5c3a15f7ab85c62 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04d972fbfab96ce713d61e7ca96f73f3fd442f72 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a060d57b9bd17e06b29fef1859620a641c83bf7b Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c037a13f7cc17afbed09233d10c0612afa8180b8 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/discriminator.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55fc6186ae30b62617a174d0a222652ab9474e27 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df64a78a14f2ab9d16963ad1355bf553f7b343c Binary files /dev/null and 
b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_encoder.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea22487181e3ba0c1882f13bdc6fcd4dd54af106 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator_spade.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator_spade.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92e3bed1d5d74d5fb72704825bd1664ddd290778 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/face_generator_spade.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7562dc7c1f0e387c1cfbe3d1a962a428a3fad4da Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f971a861d45f8bd888a9be4bef3de8341c39b5e Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/flow_estimator.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/loss.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..020e09ef330b94516df1c8dd62e96c2860901089 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/loss.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ae53e593fda706671c5ea13b3a4fcdcee49a50e Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..150c9c7a772a433a7341c298e40c6a96706ca3a7 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-311.pyc differ diff --git 
a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec1617af149b4bf7d841fbbfe220a71ae64a094 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/modules.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e0834cb22af22e5d1df906da5a47816b5e8d9b8 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15b7e53c071f0ee9c11f4ddb935d387e7cfb3ce6 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63a79fd0fd8e68952662cc46eb352de7661ec059 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/motion_encoder.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e6db616a81b0775173ae6018390ac85e6d493e0 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_FaceEncDec.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_FaceEncDec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e9c388fe3a0d3c0f7fbae0b78249fe9681d1209 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_FaceEncDec.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d07e9f2b9211bb4d4843e355427d47a9953d85e Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map.cpython-312.pyc differ diff --git 
a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map_FaceEncDec.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map_FaceEncDec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68f22293afd9c0eafcebd2b47fbfb53fec69d52 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/scaling_lia_Map_FaceEncDec.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c5bed99598e172175e4c4c646b3997747d74707 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..845fd6d0941afa67940993cd478ec65c1de67bbc Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/__pycache__/util.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/discriminator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..708457da20b5e56a2dd746fb12943a271471cf72 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/discriminator.py @@ -0,0 +1,72 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.size = size + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + log_size = int(math.log(size, 2)) + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + batch, channel, height, width = out.shape + + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + 
out = self.final_linear(out) + + return out + +if __name__ == "__main__": + from torchsummaryX import summary + + IMAGE_SIZE = 512 + encoder = Discriminator(size=IMAGE_SIZE) + summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe6d2987a057c689697dd57dccb2e986ad9d6b1 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_encoder.py @@ -0,0 +1,100 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path +import torch + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * +from model.head_animation.LIA.util import SameBlock2d, DownBlock2d, ResBlock3d + +class FaceEncoder(nn.Module): + def __init__(self, output_channels, size=512): + super(FaceEncoder, self).__init__() + + channel = [32, 64, 128, 256, 512, 512, 512, output_channels] + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channel[0], 1)) + + in_channel = channel[0] + for i in range(1, len(channel)): + out_channel = channel[i] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs = nn.Sequential(*self.convs) + + def forward(self, source_image): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[source_image], + use_reentrant=False) + else: + return self.manual_forward(*[source_image]) + + def manual_forward(self, source_image): + res = [] + h = source_image + for conv in self.convs: + h = conv(h) + res.append(h) + # TODO(wei): remove the last layer? + feats = res[::-1][1:] # from 8x8 to 512x512 + return feats + + + +""" +Appearance extractor(F) defined in LivePortrait, which maps the source image s to a 3D appearance feature volume. 
+""" +class FaceEncoder3D(nn.Module): + def __init__(self, image_size, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): + super(FaceEncoder3D, self).__init__() + self.image_size = image_size + self.image_channel = image_channel + self.block_expansion = block_expansion + self.num_down_blocks = num_down_blocks + self.max_features = max_features + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + def forward(self, source_image): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[source_image], + use_reentrant=False) + else: + return self.manual_forward(*[source_image]) + + def manual_forward(self, source_image): + if source_image.size(-1) != self.image_size: + source_image = F.interpolate(source_image, size=(self.image_size, self.image_size), mode='bilinear') + + out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 + + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape # ->Bx512x64x64 + + f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 + f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 + return f_s \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..949fb8720c1ab3b22fd44be88fee7840e3def26b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator.py @@ -0,0 +1,120 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class FaceGenerator(nn.Module): + def __init__(self, size, latent_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1): + super(FaceGenerator, self).__init__() + + self.size = size + self.latent_dim = latent_dim + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) # 512, 4, 4 + self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, latent_dim, blur_kernel=blur_kernel) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.to_flows = nn.ModuleList() + + in_channel = self.channels[4] + + for i in range(3, 
self.log_size + 1):
+            out_channel = self.channels[2 ** i]
+            # print(i, 2 ** i, in_channel, out_channel)
+            # import pdb; pdb.set_trace()
+            self.convs.append(StyledConv(in_channel, out_channel, 3, latent_dim, upsample=True, blur_kernel=blur_kernel))
+            self.convs.append(StyledConv(out_channel, out_channel, 3, latent_dim, blur_kernel=blur_kernel))
+            self.to_rgbs.append(ToRGB(out_channel, latent_dim))
+
+            self.to_flows.append(ToFlow(out_channel, latent_dim))
+
+            in_channel = out_channel
+
+        self.n_latent = self.log_size * 2 - 2
+
+    def forward(self, tgt_latent, ref_feats):
+        if self.training:
+            return torch.utils.checkpoint.checkpoint( \
+                self.manual_forward, *[tgt_latent, ref_feats],
+                use_reentrant=False)
+        else:
+            return self.manual_forward(*[tgt_latent, ref_feats])
+
+    def manual_forward(self, tgt_latent, ref_feats):
+        bs = tgt_latent.size(0)
+
+        inject_index = self.n_latent
+        latent = tgt_latent.unsqueeze(1).repeat(1, inject_index, 1).contiguous()
+
+        out = self.input(latent)
+        out = self.conv1(out, latent[:, 0])
+        # print('0', out.shape)
+
+        i = 1
+
+        # gradient checkpoint ---------------------------
+        # torch.utils.checkpoint.checkpoint(ckpt_wrapper(self.audio_proj), *audio_proj_args, use_reentrant=False)
+        def ckpt_wrapper(conv1, conv2, to_flow, to_rgb):
+            def ckpt_forward(out, latent, feat, skip_flow, skip, i):
+                out = conv1(out, latent[:, i])
+                out = conv2(out, latent[:, i + 1])
+                if out.size(2) == 8:
+                    out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat)
+                    skip = to_rgb(out_warp)
+                else:
+                    out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow)
+                    skip = to_rgb(out_warp, skip)
+                return out, skip, skip_flow
+
+            return ckpt_forward
+        # gradient checkpoint ---------------------------
+        skip_flow, skip = None, None
+        for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs,
+                                                       self.to_flows, ref_feats):
+            # gradient checkpoint ---------------------------
+            input_args = [out, latent, feat, skip_flow, skip, i]
+            if self.training:
+                out, skip, skip_flow = torch.utils.checkpoint.checkpoint( \
+                    ckpt_wrapper(conv1, conv2, to_flow, to_rgb), *input_args,
+                    use_reentrant=False)
+            else:
+                out, skip, skip_flow = ckpt_wrapper(conv1, conv2, to_flow, to_rgb)(*input_args)
+            i += 2
+            # gradient checkpoint ---------------------------
+
+            # out = conv1(out, latent[:, i])
+            # out = conv2(out, latent[:, i + 1])
+            # if out.size(2) == 8:
+            #     out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat)
+            #     skip = to_rgb(out_warp)
+            # else:
+            #     out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow)
+            #     skip = to_rgb(out_warp, skip)
+
+            # print(i, out.shape, skip.shape, feat.shape)
+
+        img = skip
+        # import pdb; pdb.set_trace()
+        return img
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator_spade.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator_spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..134f21cf94ffd6a37d3aad8fd2362cc3cf1aa922
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/face_generator_spade.py
@@ -0,0 +1,225 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+import sys
+from pathlib import Path
+import math
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from model.basic_model.basic_block import USE_BIAS, ReshapeTo3DLayer, WSConv3d, ReshapeTo2DLayer
+from model.head_animation.LIA.util import *
+
+class 
AdaptiveGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_features, eps=1e-5, affine=True): + super(AdaptiveGroupNorm, self).__init__(num_groups, num_features, eps, False) + self.num_features = num_features + + gen_max_channels, gen_embed_size = 512, 4 + self.u = nn.Parameter(torch.empty(num_features, gen_max_channels)) + self.v = nn.Parameter(torch.empty(gen_embed_size ** 2, 2)) + + nn.init.uniform_(self.u, a=-math.sqrt(3 / gen_max_channels), b=math.sqrt(3 / gen_max_channels)) + nn.init.uniform_(self.v, a=-math.sqrt(3 / gen_embed_size ** 2), b=math.sqrt(3 / gen_embed_size ** 2)) + + def forward(self, inputs, condition_emb): + outputs = super(AdaptiveGroupNorm, self).forward(inputs) + + param = self.u[None].matmul(condition_emb).matmul(self.v[None]) + ada_weight, ada_bias = param.split(1, dim=2) + + outputs = outputs * ada_weight[:, :, :, None, None] + ada_bias[:, :, :, None, None] + return outputs + +class ResBlock3dStar(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, condition_dim: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.agn1 = AdaptiveGroupNorm(in_channels // num_channels_per_group, in_channels) + self.conv1 = WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.agn2 = AdaptiveGroupNorm(out_channels // num_channels_per_group, out_channels) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, inp, condition): + x = self.relu(self.agn1(inp, condition)) + x = self.conv1(x) + x = self.relu(self.agn2(x, condition)) + x = self.conv2(x) + x = self.skip_layer(inp) + x + return x + + +class FaceGenerator(nn.Module): + def __init__(self, size, reshape_channel, group_norm_channel, latent_dim, + blur_kernel=[1, 3, 3, 1], + channel_multiplier=1, outputsize=512, + flag_estimate_occlusion_map=False, + final_activation=True, + ): + super(FaceGenerator, self).__init__() + + self.size = size + self.latent_dim = latent_dim + + # settings for warping_generator + self.group_norm_channel = group_norm_channel + self.feature_volume_size = (16, 64, 64) + self.app_fea_size = (latent_dim, 4, 4) + + # settings for image_generator + self.reshape_channel = reshape_channel + self.flag_estimate_occlusion_map = flag_estimate_occlusion_map + self.final_activation = final_activation + self.input_channels = 256 # channel of projected feature volume + self.norm_G = 'spadespectralinstance' + self.label_num_channels = self.input_channels + self.out_channels = 64 # output channel of final SPADEResnetBlock + + ## warping field generator + self.init_warping_generator() + + ### generator + self.init_image_generator() + + def init_warping_generator(self): + num_channels_per_group = self.group_norm_channel + input_dim = self.app_fea_size[0] + + self.extend_layer = nn.Linear(self.latent_dim, self.app_fea_size[0] * self.app_fea_size[1] ** 2, bias=USE_BIAS) + self.conv1 = nn.Conv2d(self.latent_dim, 2048, kernel_size=1, bias=USE_BIAS) + self.reshap3d = ReshapeTo3DLayer(out_depth=4) + self.resblock1 = ResBlock3dStar(512, 256, num_channels_per_group, input_dim) + self.resblock2 = ResBlock3dStar(256, 128, num_channels_per_group, input_dim) + self.resblock3 = ResBlock3dStar(128, 64, num_channels_per_group, input_dim) + self.resblock4 = ResBlock3dStar(64, 32, 
num_channels_per_group, input_dim) + self.gn = nn.GroupNorm(32 // num_channels_per_group, 32, affine=not USE_BIAS) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(32, 3, kernel_size=3, padding=1, bias=USE_BIAS) + + self.upsample = nn.Upsample(scale_factor=(2, 2, 2), mode="nearest") + self.upsample2 = nn.Upsample(scale_factor=(1, 2, 2), mode="nearest") + + grids = torch.meshgrid( + torch.linspace(-1, 1, self.feature_volume_size[0]), + torch.linspace(-1, 1, self.feature_volume_size[1]), + torch.linspace(-1, 1, self.feature_volume_size[2]), + indexing="ij" + ) + self.identity_grid = torch.stack(grids, dim=-1).flip(-1) + + def init_image_generator(self): + # Projection layers + self.projection = nn.Sequential( + ReshapeTo2DLayer(), + SameBlock2d(self.reshape_channel * self.feature_volume_size[0], 256, kernel_size=(3, 3), padding=(1, 1), lrelu=True), + nn.Conv2d(256, 256, kernel_size=1, stride=1) + ) + + self.fc = nn.Conv2d(self.input_channels, 2 * self.input_channels, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.G_middle_1 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.G_middle_2 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.G_middle_3 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.G_middle_4 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.G_middle_5 = SPADEResnetBlock(2 * self.input_channels, 2 * self.input_channels, self.norm_G, self.label_num_channels) + self.up_0 = SPADEResnetBlock(2 * self.input_channels, self.input_channels, self.norm_G, self.label_num_channels) + self.up_1 = SPADEResnetBlock(self.input_channels, self.out_channels, self.norm_G, self.label_num_channels) + self.up = nn.Upsample(scale_factor=2) + + self.conv_img = nn.Sequential( + nn.Conv2d(self.out_channels, 3 * (2 * 2), kernel_size=3, padding=1), + nn.PixelShuffle(upscale_factor=2) + ) + + if self.final_activation: + self.final_activation_fn = nn.Tanh() + else: + self.final_activation_fn = nn.Identity() + + if self.flag_estimate_occlusion_map: + self.occlusion = nn.Conv2d(self.reshape_channel*self.feature_volume_size[0], 1, kernel_size=7, padding=3) + + + def flow_field_generation(self, tgt_latent, ref_feats): + batch_size = tgt_latent.size(0) + z_emb = self.extend_layer(tgt_latent).view(batch_size, -1, self.app_fea_size[1], self.app_fea_size[2]) + + _, c, h, w = z_emb.shape + condition = z_emb.view(-1, c, h * w).clone() + + z = self.conv1(z_emb) + z = self.reshap3d(z) + + z = self.upsample(z) + z = self.resblock1(z, condition) + + z = self.upsample(z) + z = self.resblock2(z, condition) + + z = self.upsample2(z) + z = self.resblock3(z, condition) + + z = self.upsample2(z) + z = self.resblock4(z, condition) + + z = self.gn(z) + z = self.relu(z) + z = self.conv2(z) + deltas = F.tanh(z).permute(0, 2, 3, 4, 1) + + warping_field = self.identity_grid[None].to(tgt_latent.device) + deltas + warping_feature_volume = F.grid_sample(ref_feats, warping_field, mode="bilinear", align_corners=False) + + return warping_feature_volume + + def image_generation(self, warping_feature_volume): + seg = self.projection(warping_feature_volume) # Bx256x64x64 + + if self.flag_estimate_occlusion_map: + batch_size, _, d, h, w = warping_feature_volume.shape + 
+            warping_feature_volume_reshape = warping_feature_volume.view(batch_size, -1, h, w)
+            occlusion_map = torch.sigmoid(self.occlusion(warping_feature_volume_reshape))  # Bx1x64x64
+            seg = seg * occlusion_map
+
+        x = self.fc(seg)  # Bx512x64x64
+        x = self.G_middle_0(x, seg)
+        x = self.G_middle_1(x, seg)
+        x = self.G_middle_2(x, seg)
+        x = self.G_middle_3(x, seg)
+        x = self.G_middle_4(x, seg)
+        x = self.G_middle_5(x, seg)
+
+        x = self.up(x)         # Bx512x64x64 -> Bx512x128x128
+        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
+        x = self.up(x)         # Bx256x128x128 -> Bx256x256x256
+        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256
+
+        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
+        x = self.final_activation_fn(x)           # Bx3xHxW
+
+        return x
+
+    def forward(self, tgt_latent, ref_feats):
+        if self.training:
+            return torch.utils.checkpoint.checkpoint( \
+                self.manual_forward, *[tgt_latent, ref_feats],
+                use_reentrant=False)
+        else:
+            return self.manual_forward(*[tgt_latent, ref_feats])
+
+    def manual_forward(self, tgt_latent, ref_feats):
+        # generate warping field
+        warping_feature_volume = self.flow_field_generation(tgt_latent, ref_feats)
+
+        # decoding
+        img = self.image_generation(warping_feature_volume)
+
+        return img
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/flow_estimator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/flow_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cdadf6ef707dfcb497ca2b9493cb321ad8f7668
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/flow_estimator.py
@@ -0,0 +1,57 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from model.head_animation.LIA.motion_encoder import EqualLinear
+
+class Direction(nn.Module):
+    def __init__(self, latent_dim, num_direction):
+        super(Direction, self).__init__()
+
+        self.weight = nn.Parameter(torch.randn(latent_dim, num_direction))
+
+    def forward(self, input):
+        weight = self.weight + 1e-8
+        Q, R = torch.qr(weight)  # get orthonormal basis vectors [n1, n2, n3, n4]
+
+        if input is None:
+            return Q
+        else:
+            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
+            out = torch.matmul(input_diag, Q.T)
+            out = torch.sum(out, dim=1)
+            return out
+
+
+class FlowEstimator(nn.Module):
+    def __init__(self, latent_dim, motion_space=20, num_fc=3):
+        super(FlowEstimator, self).__init__()
+
+        fc = [EqualLinear(latent_dim, latent_dim)]
+        for i in range(num_fc):
+            fc.append(EqualLinear(latent_dim, latent_dim))
+        fc.append(EqualLinear(latent_dim, motion_space))
+        self.fc = nn.Sequential(*fc)
+
+        self.direction = Direction(latent_dim, motion_space)
+
+    def forward(self, ref_fea, tgt_fea):
+        if self.training:
+            return torch.utils.checkpoint.checkpoint( \
+                self.manual_forward, *[ref_fea, tgt_fea],
+                use_reentrant=False)
+        else:
+            return self.manual_forward(*[ref_fea, tgt_fea])
+
+    def manual_forward(self, ref_fea, tgt_fea):
+        feats = self.fc(tgt_fea.view(tgt_fea.size(0), -1))
+
+        ref2tgt_mapping = self.direction(feats)
+        tgt_latent = ref_fea + ref2tgt_mapping  # reference latent code -> target latent code
+
+        return tgt_latent
\ No newline at end of file
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/lia_origin_model.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/lia_origin_model.py
new file mode 100644
index 
0000000000000000000000000000000000000000..bc4e3a795056bce83e72d32fe467caeae78a1d24 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/lia_origin_model.py @@ -0,0 +1,259 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class MotionEncoder(nn.Module): + def __init__(self, latent_dim, size=512): + super(MotionEncoder, self).__init__() + + # self.input_size = size + # channel = [32, 64, 128, 256, 512, 512, 512, 512] + # 128, 128, 128 -> 256, 64, 64 -> 512, 16, 16 -> 512, 8, 8 + # -> 512, 4, 4 -> 512, 2, 2 -> 512, 1, 1 -> 512 + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, latent_dim, 4, padding=0, bias=False)) + self.convs = nn.Sequential(*self.convs) + + def forward(self, x): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[x], + use_reentrant=False) + else: + return self.manual_forward(*[x]) + + def manual_forward(self, x): + res = [] + h = x + res = [] + for conv in self.convs: + h = conv(h) + res.append(h) + res = res[::-1] + feats = res[2:] # from 8x8 to 512x512 + latent_code = res[0] + # [B * T, D] + latent_code = latent_code.view(x.size(0), -1) + return latent_code, feats + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + +class FaceEncoder(nn.Module): + def __init__(self, output_channels, size=512): + super(FaceEncoder, self).__init__() + + channel = [32, 64, 128, 256, 512, 512, 512, output_channels] + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channel[0], 1)) + + in_channel = channel[0] + for i in range(1, len(channel)): + out_channel = channel[i] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs = nn.Sequential(*self.convs) + + def forward(self, x): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[x], + use_reentrant=False) + else: + return self.manual_forward(*[x]) + + def manual_forward(self, x): + h = x + res = [] + for conv in self.convs: + h = conv(h) + res.append(h) + feats = res[::-1][1:] # from 8x8 to 512x512 + return feats + + +class FaceGenerator(nn.Module): + def __init__(self, size, latent_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1): + super(FaceGenerator, self).__init__() + + self.size = size + self.latent_dim = latent_dim + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) # 512, 4, 4 + self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, latent_dim, blur_kernel=blur_kernel) + + self.log_size = int(math.log(size, 2)) + 
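The layer-count arithmetic at this point in the constructor is easier to verify with concrete numbers. A quick worked check, assuming the default `size=512` (`num_layers` and `n_latent` are computed in the lines that follow):

```python
import math

size = 512
log_size = int(math.log(size, 2))      # 9
num_layers = (log_size - 2) * 2 + 1    # 15 styled convs above the 4x4 constant input
n_latent = log_size * 2 - 2            # 16 latent slots consumed per forward pass
print(log_size, num_layers, n_latent)  # 9 15 16
```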
self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.to_flows = nn.ModuleList() + + in_channel = self.channels[4] + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + # print(i, 2 ** i, in_channel, out_channel) + # import pdb; pdb.set_trace() + self.convs.append(StyledConv(in_channel, out_channel, 3, latent_dim, upsample=True, blur_kernel=blur_kernel)) + self.convs.append(StyledConv(out_channel, out_channel, 3, latent_dim, blur_kernel=blur_kernel)) + self.to_rgbs.append(ToRGB(out_channel, latent_dim)) + + self.to_flows.append(ToFlow(out_channel, latent_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def forward(self, tgt_latent, ref_feats): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[tgt_latent, ref_feats], + use_reentrant=False) + else: + return self.manual_forward(*[tgt_latent, ref_feats]) + # return self.manual_forward(*[tgt_latent, ref_feats]) + + def manual_forward(self, tgt_latent, ref_feats): + bs = tgt_latent.size(0) + + inject_index = self.n_latent + latent = tgt_latent.unsqueeze(1).repeat(1, inject_index, 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0]) + # print('0', out.shape) + + i = 1 + for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs, + self.to_flows, ref_feats): + out = conv1(out, latent[:, i]) + out = conv2(out, latent[:, i + 1]) + if out.size(2) == 8: + out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat) + skip = to_rgb(out_warp) + else: + out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow) + skip = to_rgb(out_warp, skip) + i += 2 + # print(i, out.shape, skip.shape, feat.shape) + + img = skip + # import pdb; pdb.set_trace() + return img + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.size = size + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + log_size = int(math.log(size, 2)) + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[input], + use_reentrant=False) + else: + return self.manual_forward(*[input]) + # return self.manual_forward(*[input]) + + def manual_forward(self, input): + out = self.convs(input) + batch, channel, height, width = out.shape + + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = 
self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/loss.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..59bce67101a5277158f77121856b2da23b18f951 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/loss.py @@ -0,0 +1,186 @@ +from torch import nn +import torch +from torchvision import models +import numpy as np +import torch.nn.functional as F + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 + """ + + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + + return out_dict + + +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss. See Sec 3.3. 
+ """ + + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + + vgg_model = models.vgg19(pretrained=True) + # vgg_model.load_state_dict(torch.load('./vgg19-dcbb9e9d.pth')) + vgg_pretrained_features = vgg_model.features + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = X.clamp(-1, 1) + X = X / 2 + 0.5 + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + # out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + out = [h_relu1, h_relu2, h_relu3, h_relu4] + + return out + + +class VGGLoss(nn.Module): + def __init__(self): + super(VGGLoss, self).__init__() + + # self.scales = [1, 0.5, 0.25, 0.125] + self.scales = [1, 0.5, 0.25] + self.pyramid = ImagePyramide(self.scales, 3).cuda() + + self.mask_scales = [1, 0.5, 0.25, 0.125, 0.0625, 0.0625/2] + self.mask_pyramid = ImagePyramide(self.mask_scales, 1).cuda() + + # vgg loss + self.vgg = Vgg19().cuda().eval() + # self.weights = (10, 10, 10, 10, 10) + self.weights = (10, 10, 10, 10) + + def forward(self, img_recon, img_real, facial_mask=None): + + # vgg loss + pyramid_real = self.pyramid(img_real) + pyramid_recon = self.pyramid(img_recon) + + if facial_mask is not None: + pyramid_mask = self.mask_pyramid(facial_mask) + pyramid_mask_new = {} + for k, v in pyramid_mask.items(): pyramid_mask_new[f'prediction_{v.size(-1)}'] = v + pyramid_mask = pyramid_mask_new + + vgg_loss = 0 + all_loss_dict = {} + for scale in self.scales: + real_vgg = self.vgg(pyramid_real['prediction_' + str(scale)]) + recon_vgg = self.vgg(pyramid_recon['prediction_' + str(scale)]) + + loss_list = [] + for i, weight in enumerate(self.weights): + if facial_mask is not None: + feat_mask = pyramid_mask[f'prediction_{real_vgg[i].size(-1)}'] + real_vgg_map = real_vgg[i] * feat_mask + recon_vgg_map = recon_vgg[i] * feat_mask + value = torch.abs(recon_vgg_map - real_vgg_map.detach()).sum() / feat_mask.sum() + else: + real_vgg_map = real_vgg[i] + recon_vgg_map = recon_vgg[i] + value = torch.abs(recon_vgg_map - real_vgg_map.detach()).mean() + + loss = value * weight + vgg_loss += loss + + loss_list.append(loss) + + all_loss_dict[str(scale)] = loss_list + + return vgg_loss, all_loss_dict + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/modules.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc9ca983260873d219908bd3828a0c6d1aa6b49 --- /dev/null +++ 
b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/modules.py @@ -0,0 +1,438 @@ +import torch +from torch import nn +from torch.nn import functional as F +import math +import numpy as np + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = 
nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +class MotionPixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=2, keepdim=True) + 1e-8) + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + +class ModulatedConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, + downsample=False, blur_kernel=[1, 3, 3, 1], ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 
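+            # FIR blur applied after the transposed conv; padding derived from the kernel length and the upsample factor (StyleGAN2 convention)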
+ + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size, + self.kernel_size) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + + if noise is None: + return image + else: + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], + demodulate=True): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + out = self.activate(out) + + return out + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + 
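# 1x1 conv to RGB; coarser skips are blur-upsampled and accumulated across resolutions +        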
self.conv = ConvLayer(in_channel, 3, 1) +        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + +    def forward(self, input, skip=None): +        out = self.conv(input) +        out = out + self.bias + +        if skip is not None: +            skip = self.upsample(skip) +            out = out + skip + +        return out + + +class ToFlow(nn.Module): +    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): +        super().__init__() + +        if upsample: +            self.upsample = Upsample(blur_kernel) + +        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) +        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + +    def forward(self, input, style, feat, skip=None): +        out = self.conv(input, style) +        out = out + self.bias + +        # warping +        xs = np.linspace(-1, 1, input.size(2)) +        xs = np.meshgrid(xs, xs) +        xs = np.stack(xs, 2) + +        xs = torch.tensor(xs, requires_grad=False).float().unsqueeze(0).repeat(input.size(0), 1, 1, 1).to(input.device) + +        if skip is not None: +            skip = self.upsample(skip) +            out = out + skip + +        sampler = torch.tanh(out[:, 0:2, :, :]) +        mask = torch.sigmoid(out[:, 2:3, :, :]) +        flow = sampler.permute(0, 2, 3, 1) + xs + +        feat_warp = F.grid_sample(feat, flow) * mask + +        return feat_warp, feat_warp + input * (1.0 - mask), out \ No newline at end of file
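`ToFlow` is where the actual motion transfer happens: a modulated 1x1 conv predicts three channels, split into a tanh-bounded two-channel offset field and a sigmoid occlusion mask; the offsets are added to an identity grid and the reference features are resampled with `grid_sample`. A minimal, self-contained sketch of that resampling step (all shapes and the random fields here are illustrative only, not values from a trained model):

```python
import torch
import torch.nn.functional as F

# Illustrative reference features: 4 channels at 32x32.
feat = torch.randn(1, 4, 32, 32)

# Identity sampling grid in [-1, 1], shape [1, H, W, 2] with (x, y)
# ordering, as grid_sample expects.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 32),
                        torch.linspace(-1, 1, 32), indexing="ij")
identity = torch.stack([xs, ys], dim=-1).unsqueeze(0)

# Stand-ins for the network outputs: a bounded offset field and a soft mask.
offset = torch.tanh(torch.randn(1, 2, 32, 32)).permute(0, 2, 3, 1)
mask = torch.sigmoid(torch.randn(1, 1, 32, 32))

# Warp the features along the flow and gate them with the mask, mirroring
# `feat_warp = F.grid_sample(feat, flow) * mask` above.
feat_warp = F.grid_sample(feat, identity + offset, align_corners=False) * mask
print(feat_warp.shape)  # torch.Size([1, 4, 32, 32])
```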
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/motion_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..51ba1d68b7b064df987250667f8b085049efde36 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/motion_encoder.py @@ -0,0 +1,97 @@ +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * +from torchvision.models import resnet18  # assumed source of the resnet18 backbone used by MotionEncoderLight below + +class MotionEncoder(nn.Module): +    def __init__(self, latent_dim, size=512): +        super(MotionEncoder, self).__init__() + +        self.size = size + +        if self.size==256: +            channel = [64, 128, 256, 512, 512, 512, 512] + +        elif self.size==512: +            channel = [32, 64, 128, 256, 512, 512, 512, 512] + +        self.convs = nn.ModuleList() +        self.convs.append(ConvLayer(3, channel[0], 1)) + +        in_channel = channel[0] +        for i in range(1, len(channel)): +            out_channel = channel[i] +            self.convs.append(ResBlock(in_channel, out_channel)) +            in_channel = out_channel + +        self.convs.append(EqualConv2d(in_channel, latent_dim, 4, padding=0, bias=False)) +        self.convs = nn.Sequential(*self.convs) + +    def forward(self, x): +        if self.training: +            return torch.utils.checkpoint.checkpoint( \ +                self.manual_forward, *[x], +                use_reentrant=False) +        else: +            return self.manual_forward(*[x]) + +    def manual_forward(self, x): + +        if x.size(-1) != self.size: +            x = F.interpolate(x, size=(self.size, self.size), mode='bilinear') + +        res = [] +        h = x +        # gradient checkpoint --------------------------- +        def ckpt_wrapper(convs): +            def ckpt_forward(h): +                res = [] +                for conv in convs: +                    h = conv(h) +                    res.append(h) +                return res +            return ckpt_forward + +        if self.training: +            res = torch.utils.checkpoint.checkpoint( \ +                ckpt_wrapper(self.convs), *[h], +                use_reentrant=False) +        else: +            res = ckpt_wrapper(self.convs)(*[h]) +        # gradient checkpoint --------------------------- +        # res = [] +        # for conv in self.convs: +        #     h = conv(h) +        #     res.append(h) +        res = res[::-1] +        feats = res[2:]  # from 8x8 to 512x512 +        latent_code = res[0] +        # [B * T, D] +        latent_code = latent_code.view(x.size(0), -1) +        return latent_code, feats + + +class MotionEncoderLight(nn.Module): +    def __init__(self, latent_dim, size=512): +        super().__init__() + +        self.size = size +        self.layers = resnet18(pretrained=False, num_classes=latent_dim) # 11.4M + +    def forward(self, x): +        if self.training: +            return torch.utils.checkpoint.checkpoint( \ +                self.manual_forward, *[x], +                use_reentrant=False) +        else: +            return self.manual_forward(*[x]) + +    def manual_forward(self, x): +        if x.size(-1) != self.size: +            x = F.interpolate(x, size=(self.size, self.size), mode='bilinear') +        out = self.layers(x) +        return out, None diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map.py new file mode 100644 index 0000000000000000000000000000000000000000..dca7e6fda36a54ad4def7b2e1747b514352a2f80 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map.py @@ -0,0 +1,200 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class MotionEncoder(nn.Module): +    def __init__(self, latent_dim, size=512): +        super(MotionEncoder, self).__init__() + +        self.input_size = size +        channel = [32, 64, 128, 256, 512, 512, 512, 512] +        # 3, 512, 512 -> 32, 512, 512 -> 64, 256, 256 -> +        # 128, 128, 128 -> 256, 64, 64 -> 512, 32, 32 -> 512, 16, 16 -> 512, 8, 8 +        # -> 512, 4, 4 -> 512, 1, 1 -> 512 + +        self.convs = nn.ModuleList() +        self.convs.append(ConvLayer(3, channel[0], 1)) + +        in_channel = channel[0] +        for i in range(1, len(channel)): +            out_channel = channel[i] +            self.convs.append(ResBlock(in_channel, out_channel)) +            in_channel = out_channel + +        self.convs.append(EqualConv2d(in_channel, latent_dim, 4, padding=0, bias=False)) +        self.convs = nn.Sequential(*self.convs) + +    def forward(self, x): +        res = [] +        h = x +        # gradient checkpoint --------------------------- +        def ckpt_wrapper(convs): +            def ckpt_forward(h): +                res = [] +                for conv in convs: +                    h = conv(h) +                    res.append(h) +                return res +            return ckpt_forward + +        if self.training: +            res = torch.utils.checkpoint.checkpoint( \ +                ckpt_wrapper(self.convs), *[h], +                use_reentrant=False) +        else: +            res = ckpt_wrapper(self.convs)(*[h]) +        # gradient checkpoint --------------------------- +        # res = [] +        # for conv in self.convs: +        #     h = conv(h) +        #     res.append(h) +        res = res[::-1] +        feats = res[2:]  # from 8x8 to 512x512 +        latent_code = res[0] +        # [B * T, D] +        latent_code = latent_code.view(x.size(0), -1) +        return latent_code, feats + +class ConstantInput(nn.Module): +    def __init__(self, channel, size=4): +        super().__init__() + +        self.input = nn.Parameter(torch.randn(1, channel, size, size)) + +    def forward(self, input): +        batch = input.shape[0] +        out = self.input.repeat(batch, 1, 1, 1) + +        return out + +class FaceEncoder(nn.Module): +    def __init__(self, output_channels, size=512): +        super(FaceEncoder, self).__init__() + +        channel = [32, 64, 128, 256, 512, output_channels] +        # 3, 512, 512 -> 32, 512, 512 -> 64, 256, 256 -> +        # 128, 128, 128 -> 256, 64, 64 -> 512, 32, 32 -> 512, 16, 16 + +        self.convs = nn.ModuleList() +        self.convs.append(ConvLayer(3, channel[0], 1)) + +        in_channel = channel[0] +        for i in range(1, len(channel)): +            out_channel = channel[i] +            self.convs.append(ResBlock(in_channel, out_channel)) +            in_channel = out_channel + +        self.convs =
nn.Sequential(*self.convs) + +    def forward(self, x): +        h = x +        # gradient checkpoint --------------------------- +        def ckpt_wrapper(convs): +            def ckpt_forward(h): +                res = [] +                for conv in convs: +                    h = conv(h) +                    res.append(h) +                return res +            return ckpt_forward + +        if self.training: +            res = torch.utils.checkpoint.checkpoint( \ +                ckpt_wrapper(self.convs), *[h], +                use_reentrant=False) +        else: +            res = ckpt_wrapper(self.convs)(*[h]) +        # gradient checkpoint --------------------------- +        # res = [] +        # for conv in self.convs: +        #     h = conv(h) +        #     res.append(h) +        feats = res[::-1][1:]  # from 8x8 to 512x512 +        return feats + +class FaceGenerator(nn.Module): +    def __init__(self, size, latent_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1): +        super(FaceGenerator, self).__init__() + +        self.size = size +        self.latent_dim = latent_dim + +        self.channels = { +            # 4: 512, +            # 8: 512, +            16: 512, +            32: 512, +            64: 256 * channel_multiplier, +            128: 128 * channel_multiplier, +            256: 64 * channel_multiplier, +            512: 32 * channel_multiplier, +            1024: 16 * channel_multiplier, +        } + +        self.input = ConstantInput(self.channels[16], 16) # 512, 16, 16 +        self.conv1 = StyledConv(self.channels[16], self.channels[16], 3, latent_dim, blur_kernel=blur_kernel) + +        self.log_size = int(math.log(size, 2)) + +        self.convs = nn.ModuleList() +        self.upsamples = nn.ModuleList() +        self.to_rgbs = nn.ModuleList() +        self.to_flows = nn.ModuleList() + +        in_channel = self.channels[16] + +        for i in range(5, self.log_size + 1): +            out_channel = self.channels[2 ** i] +            # print(i, 2 ** i, in_channel, out_channel) +            # import pdb; pdb.set_trace() +            self.convs.append(StyledConv(in_channel, out_channel, 3, latent_dim, upsample=True, blur_kernel=blur_kernel)) +            self.convs.append(StyledConv(out_channel, out_channel, 3, latent_dim, blur_kernel=blur_kernel)) +            self.to_rgbs.append(ToRGB(out_channel, latent_dim)) + +            self.to_flows.append(ToFlow(out_channel, latent_dim)) + +            in_channel = out_channel + +        self.n_latent = 16 * 16 + +    def forward(self, tgt_latent, ref_feats): +        if self.training: +            return torch.utils.checkpoint.checkpoint( \ +                self.manual_forward, *[tgt_latent, ref_feats], +                use_reentrant=False) +        else: +            return self.manual_forward(*[tgt_latent, ref_feats]) + +    def manual_forward(self, tgt_latent, ref_feats): +        bs = tgt_latent.size(0) + +        inject_index = self.n_latent +        latent = tgt_latent.unsqueeze(1).repeat(1, inject_index, 1) + +        out = self.input(latent) +        out = self.conv1(out, latent[:, 0]) +        # print('0', out.shape) + +        i = 1 +        for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs, +                                                       self.to_flows, ref_feats): +            out = conv1(out, latent[:, i]) +            out = conv2(out, latent[:, i + 1]) +            if out.size(2) == 32: +                out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat) +                skip = to_rgb(out_warp) +            else: +                out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow) +                skip = to_rgb(out_warp, skip) +            i += 2 +            # print(i, out.shape, skip.shape, feat.shape) + +        img = skip +        # import pdb; pdb.set_trace() +        return img
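Taken together, the three classes in this file form the reenactment pipeline: the motion encoder turns the driving frame into a latent code, the face encoder extracts multi-scale appearance features from the reference frame, and the generator warps and decodes them. A hypothetical wiring sketch (untrained weights; the 512-dim latent and 512x512 resolution are assumptions, not a released configuration):

```python
import torch

# Hypothetical instantiation of the classes defined in this file.
latent_dim = 512
motion_enc = MotionEncoder(latent_dim, size=512).eval()
face_enc = FaceEncoder(output_channels=512, size=512).eval()
generator = FaceGenerator(size=512, latent_dim=latent_dim).eval()

src = torch.randn(1, 3, 512, 512)  # reference frame (appearance)
drv = torch.randn(1, 3, 512, 512)  # driving frame (motion)

with torch.no_grad():
    tgt_latent, _ = motion_enc(drv)         # [1, 512] motion code
    ref_feats = face_enc(src)               # multi-scale appearance features
    img = generator(tgt_latent, ref_feats)  # [1, 3, 512, 512] reenacted frame
```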
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map_FaceEncDec.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map_FaceEncDec.py new file mode 100644 index 0000000000000000000000000000000000000000..d821751a3f4f47b23e24de23aadef294ae94b66b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/scaling_lia_Map_FaceEncDec.py @@ -0,0 +1,282 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class MotionEncoder(nn.Module): +    def __init__(self, latent_dim, size=512): +        super(MotionEncoder, self).__init__() + +        self.input_size = size +        channel = [32, 64, 128, 256, 512, 512, 512, 512] +        # 3, 512, 512 -> 32, 512, 512 -> 64, 256, 256 -> +        # 128, 128, 128 -> 256, 64, 64 -> 512, 32, 32 -> 512, 16, 16 -> 512, 8, 8 +        # -> 512, 4, 4 -> 512, 1, 1 -> 512 + +        self.convs = nn.ModuleList() +        self.convs.append(ConvLayer(3, channel[0], 1)) + +        in_channel = channel[0] +        for i in range(1, len(channel)): +            out_channel = channel[i] +            self.convs.append(ResBlock(in_channel, out_channel)) +            in_channel = out_channel + +        self.convs.append(EqualConv2d(in_channel, latent_dim, 4, padding=0, bias=False)) +        self.convs = nn.Sequential(*self.convs) + +    def forward(self, x): +        res = [] +        h = x +        # gradient checkpoint --------------------------- +        def ckpt_wrapper(convs): +            def ckpt_forward(h): +                res = [] +                for conv in convs: +                    h = conv(h) +                    res.append(h) +                return res +            return ckpt_forward + +        if self.training: +            res = torch.utils.checkpoint.checkpoint( \ +                ckpt_wrapper(self.convs), *[h], +                use_reentrant=False) +        else: +            res = ckpt_wrapper(self.convs)(*[h]) +        # gradient checkpoint --------------------------- +        # res = [] +        # for conv in self.convs: +        #     h = conv(h) +        #     res.append(h) +        res = res[::-1] +        feats = res[2:]  # from 8x8 to 512x512 +        latent_code = res[0] +        # [B * T, D] +        latent_code = latent_code.view(x.size(0), -1) +        return latent_code, feats + +class ConstantInput(nn.Module): +    def __init__(self, channel, size=4): +        super().__init__() + +        self.input = nn.Parameter(torch.randn(1, channel, size, size)) + +    def forward(self, input): +        batch = input.shape[0] +        out = self.input.repeat(batch, 1, 1, 1) + +        return out + +class FaceEncoder(nn.Module): +    def __init__(self, output_channels, size=512): +        super(FaceEncoder, self).__init__() + +        self.channels = [ +            (32, 64, True),    # 32, 512, 512 -> 64, 256, 256 +            (64, 128, True),   # 64, 256, 256 -> 128, 128, 128 +            (128, 256, True),  # 128, 128, 128 -> 256, 64, 64 +            (256, 256, False), # 256, 64, 64 -> 256, 64, 64 +            (256, 256, False), # 256, 64, 64 -> 256, 64, 64 +            (256, 256, False), # 256, 64, 64 -> 256, 64, 64 +            (256, 512, True),  # 256, 64, 64 -> 512, 32, 32 +            (512, output_channels, True), # 512, 32, 32 -> 512, 16, 16 +        ] +        # 3, 512, 512 -> 32, 512, 512 -> 64, 256, 256 -> +        # 128, 128, 128 -> 256, 64, 64 -> 512, 32, 32 -> 512, 16, 16 -> 1024, 8, 8 -> 2048, 4, 4 + +        self.convs = nn.ModuleList() +        self.convs.append(ConvLayer(3, 32, 1)) + +        for in_channel, out_channel, downsample in self.channels: +            self.convs.append(ResBlock(in_channel, out_channel, downsample=downsample)) + +        self.convs = nn.Sequential(*self.convs) + +    def forward(self, x): +        h = x +        # gradient checkpoint --------------------------- +        def ckpt_wrapper(convs): +            def ckpt_forward(h): +                res = [] +                for conv in convs: +                    h = conv(h) +                    res.append(h) +                return res +            return ckpt_forward + +        if self.training: +            res = torch.utils.checkpoint.checkpoint( \ +                ckpt_wrapper(self.convs), *[h], +                use_reentrant=False) +        else: +            res = ckpt_wrapper(self.convs)(*[h]) +        # gradient checkpoint --------------------------- +        # res = [] +        # for conv in self.convs: +        #     h = conv(h) +        #     res.append(h) +        feats = res[::-1][1:]  # from 8x8 to 512x512 +        return feats + +class FaceGenerator(nn.Module): +    def __init__(self, size, latent_dim,
blur_kernel=[1, 3, 3, 1], channel_multiplier=1): + super(FaceGenerator, self).__init__() + + self.size = size + self.latent_dim = latent_dim + + # self.channels = { + # # 4: 512, + # # 8: 512, + # 16: 512, + # 32: 512, + # 64: 256 * channel_multiplier, + # 128: 128 * channel_multiplier, + # 256: 64 * channel_multiplier, + # 512: 32 * channel_multiplier, + # 1024: 16 * channel_multiplier, + # } + + self.channels = [ + (latent_dim, 512, True), # 512, 16, 16 -> 512, 32, 32 + (512, 256, True), # 512, 32, 32 -> 256, 64, 64 + (256, 256, False), # 256, 64, 64 -> 256, 64, 64 + (256, 256, False), # 256, 64, 64 -> 256, 64, 64 + (256, 256, False), # 256, 64, 64 -> 256, 64, 64 + (256, 128, True), # 256, 64, 64 -> 128, 128, 128 + (128, 64, True), # 128, 128, 128 -> 64, 256, 256 + (64, 32, True), # 64, 256, 256 -> 32, 512, 512 + ] + + self.input = ConstantInput(latent_dim, 16) # 512, 4, 4 + self.conv1 = StyledConv(latent_dim, latent_dim, 3, latent_dim, blur_kernel=blur_kernel) + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.to_flows = nn.ModuleList() + + for in_channel, out_channel, upsample in self.channels: + self.convs.append(StyledConv(in_channel, out_channel, 3, latent_dim, upsample=upsample, blur_kernel=blur_kernel)) + self.convs.append(StyledConv(out_channel, out_channel, 3, latent_dim, blur_kernel=blur_kernel)) + self.to_rgbs.append(ToRGB(out_channel, latent_dim, upsample=upsample)) + + self.to_flows.append(ToFlow(out_channel, latent_dim, upsample=upsample)) + + in_channel = out_channel + + self.n_latent = 16 * 16 + + def forward(self, tgt_latent, ref_feats): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[tgt_latent, ref_feats], + use_reentrant=False) + else: + return self.manual_forward(*[tgt_latent, ref_feats]) + + def manual_forward(self, tgt_latent, ref_feats): + bs = tgt_latent.size(0) + + inject_index = self.n_latent + latent = tgt_latent.unsqueeze(1).repeat(1, inject_index, 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0]) + # print('0', out.shape) + + i = 1 + for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs, + self.to_flows, ref_feats): + out = conv1(out, latent[:, i]) + out = conv2(out, latent[:, i + 1]) + if out.size(2) == 32: + out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat) + skip = to_rgb(out_warp) + else: + out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow) + skip = to_rgb(out_warp, skip) + i += 2 + # print(i, out.shape, skip.shape, feat.shape) + + img = skip + # import pdb; pdb.set_trace() + return img + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ConvLayer(in_channel, 3, 1) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, skip=None): + out = self.conv(input) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) if hasattr(self, "upsample") else skip + out = out + skip + + return out + + +class ToFlow(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, feat, skip=None): + out 
= self.conv(input, style) + out = out + self.bias + + # warping + xs = np.linspace(-1, 1, input.size(2)) + xs = np.meshgrid(xs, xs) + xs = np.stack(xs, 2) + + xs = torch.tensor(xs, requires_grad=False).float().unsqueeze(0).repeat(input.size(0), 1, 1, 1).to(input.device) + + if skip is not None: + skip = self.upsample(skip) if hasattr(self, "upsample") else skip + out = out + skip + + sampler = torch.tanh(out[:, 0:2, :, :]) + mask = torch.sigmoid(out[:, 2:3, :, :]) + flow = sampler.permute(0, 2, 3, 1) + xs + + feat_warp = F.grid_sample(feat, flow) * mask + + return feat_warp, feat_warp + input * (1.0 - mask), out + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/util.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/util.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2f174dd68a735fc73df3cd54774ca32e214f16 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA/util.py @@ -0,0 +1,477 @@ +# coding: utf-8 + +""" +This file defines various neural network modules and utility functions, including convolutional and residual blocks, +normalizations, and functions for spatial transformation and tensor manipulation. +""" + +from torch import nn +import torch.nn.functional as F +import torch +import torch.nn.utils.spectral_norm as spectral_norm +import math +import warnings +import collections.abc +from itertools import repeat + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp + + coordinate_grid = make_coordinate_grid(spatial_size, mean) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, ref, **kwargs): + d, h, w = spatial_size + x = torch.arange(w).type(ref.dtype).to(ref.device) + y = torch.arange(h).type(ref.dtype).to(ref.device) + z = torch.arange(d).type(ref.dtype).to(ref.device) + + # NOTE: must be right-down-in + x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right + y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom + z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ConvT2d(nn.Module): + """ + Upsampling block for use in decoder. 
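+    ConvTranspose2d followed by InstanceNorm2d and a leaky ReLU activation.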
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1): + super(ConvT2d, self).__init__() + + self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride, + padding=padding, output_padding=output_padding) + self.norm = nn.InstanceNorm2d(out_features) + + def forward(self, x): + out = self.convT(x) + out = self.norm(out) + out = F.leaky_relu(out) + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.norm1 = nn.BatchNorm3d(in_features, affine=True) + self.norm2 = nn.BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. 
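+    Conv2d -> BatchNorm2d -> ReLU (LeakyReLU when lrelu=True).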
+ """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = nn.BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. 
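+    Chains the 3D Encoder and Decoder above; out_filters = block_expansion + in_features.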
+ """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + + +class GRN(nn.Module): + """ GRN (Global Response Normalization) layer + """ + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
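+    With channels_first the statistics are computed manually over dim 1; channels_last defers to F.layer_norm.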
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def drop_path(x, drop_prob=0., training=False, scale_by_keep=True): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
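+    Thin module wrapper around drop_path() above; paths are dropped only in training mode.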
+ """ + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + +to_2tuple = _ntuple(2) + + +def compute_grid_points(feature_volume: torch.Tensor): + d, h, w = feature_volume.shape[-3:] + grids = torch.meshgrid( + torch.linspace(-1, 1, d), + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(feature_volume.device).flip(-1) + +def compute_2d_grid_points(h, w, device=torch.device("cpu")): + grids = torch.meshgrid( + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(device).flip(-1) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__init__.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb60f4d9ae36330378526929a2139ac4368dcfa Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4b6a0d4a7ad366aee2436cc744a5aa032635cd0 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1da198dfbd3f976c48eef65ca57ed265a14fd0b Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/__init__.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06f654281e0c8a5a98f829bdb5d0789f045fe67e Binary files 
/dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..778659a301a1903b2cf467ec573f7ddcb9045b71 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9dc1613358815ba47b0a6875e67a98f2fba431 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b58a3654cffd716cf90eb9570d61fbe6a69e1ce0 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_encoder.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f2655b254a58e4326f4c7954e348f8f6fcfea02 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d2812915f73ac2b53dfb0bfa8f7b4248edf4cd4 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0bb80e3dca53e4392f53af4c53c51a8103fbd70 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469a88739b2bd1279e7e95c0258f9ae67417eee4 Binary files /dev/null and 
b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/face_generator.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ab61c109a6cec0616951e2e76f8d47fdf1556a Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed6040dfc3b31e832460a80cc92e1379f34616fc Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4aed9b9aab931c5f707d9c4e7f2481647e05fad Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..504651429ab902c50ef5bc913858164193c9df7f Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/flow_estimator.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4998ea83815b62a4222bb94c082462ab9a19e34 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ee0e33699030e1ddc2bd94612554538f7a8be3 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25f1b4114e17d9f555311362bbda1878ce61905a Binary files /dev/null and 
b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f244616be6b300dba7b9b17b5a7d4c190081a1c4 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/motion_encoder.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61e982c41f549e64000c0f6a2c66556570780bcf Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd10797b0d59c2b12be2e7c738f9d7eb49c8eab Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df67c0ff82f8c312ef7a346ee765711ced2c2122 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/__pycache__/util.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ed3c5cae7eb4d84228d480a25f4e64bfe02f43 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_encoder.py @@ -0,0 +1,60 @@ +# coding: utf-8 + +""" +Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume. 
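+The 2D encoder output is reshaped into a 3D feature volume (Bx32x16x64x64 in the shape comments below) and refined by 3D residual blocks.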
+""" + +import torch +from torch import nn +from .util import SameBlock2d, DownBlock2d, ResBlock3d +from torch.nn import functional as F + +class FaceEncoder(nn.Module): + + def __init__(self, image_size, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): + super(FaceEncoder, self).__init__() + self.image_size = image_size + self.image_channel = image_channel + self.block_expansion = block_expansion + self.num_down_blocks = num_down_blocks + self.max_features = max_features + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + def forward(self, source_image): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[source_image], + use_reentrant=False) + else: + return self.manual_forward(*[source_image]) + + def manual_forward(self, source_image): + if source_image.size(-1) != self.image_size: + source_image = F.interpolate(source_image, size=(self.image_size, self.image_size), mode='bilinear') + + out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 + + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape # ->Bx512x64x64 + + f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 + f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 + return f_s diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e248476220a6ed8a8b5043740e50504ebba81731 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/face_generator.py @@ -0,0 +1,199 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBlock3d, ReshapeTo3DLayer, WSConv3d, ReshapeTo2DLayer +import math +# from .util import * +from model.head_animation.LIA_3d.util import * + +class AdaptiveGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_features, eps=1e-5, affine=True): + super(AdaptiveGroupNorm, self).__init__(num_groups, num_features, eps, False) + self.num_features = num_features + + gen_max_channels, gen_embed_size = 512, 4 + self.u = nn.Parameter(torch.empty(num_features, gen_max_channels)) + self.v = nn.Parameter(torch.empty(gen_embed_size ** 2, 2)) + + nn.init.uniform_(self.u, a=-math.sqrt(3 / gen_max_channels), b=math.sqrt(3 / gen_max_channels)) + nn.init.uniform_(self.v, a=-math.sqrt(3 / gen_embed_size ** 2), b=math.sqrt(3 / gen_embed_size ** 2)) + + def 
forward(self, inputs, condition_emb): + outputs = super(AdaptiveGroupNorm, self).forward(inputs) + + param = self.u[None].matmul(condition_emb).matmul(self.v[None]) + ada_weight, ada_bias = param.split(1, dim=2) + + outputs = outputs * ada_weight[:, :, :, None, None] + ada_bias[:, :, :, None, None] + return outputs + +class ResBlock3dStar(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, condition_dim: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.agn1 = AdaptiveGroupNorm(in_channels // num_channels_per_group, in_channels) + self.conv1 = WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.agn2 = AdaptiveGroupNorm(out_channels // num_channels_per_group, out_channels) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, inp, condition): + x = self.relu(self.agn1(inp, condition)) + x = self.conv1(x) + x = self.relu(self.agn2(x, condition)) + x = self.conv2(x) + x = self.skip_layer(inp) + x + return x + + +class FaceGenerator(nn.Module): + def __init__(self, size, reshape_channel, group_norm_channel, latent_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1, outputsize=512, flag_estimate_occlusion_map=False): + super(FaceGenerator, self).__init__() + + self.size = size + self.latent_dim = latent_dim + self.flag_estimate_occlusion_map = flag_estimate_occlusion_map + + ## warping field generator + num_channels_per_group = group_norm_channel + app_fea_size = (512, 4, 4) + input_dim = app_fea_size[0] + + self.extend_layer = nn.Linear(latent_dim, app_fea_size[0] * app_fea_size[1] ** 2, bias=USE_BIAS) + self.conv1 = nn.Conv2d(latent_dim, 2048, kernel_size=1, bias=USE_BIAS) + self.reshap3d = ReshapeTo3DLayer(out_depth=4) + self.resblock1 = ResBlock3dStar(512, 256, num_channels_per_group, input_dim) + self.resblock2 = ResBlock3dStar(256, 128, num_channels_per_group, input_dim) + self.resblock3 = ResBlock3dStar(128, 64, num_channels_per_group, input_dim) + self.resblock4 = ResBlock3dStar(64, 32, num_channels_per_group, input_dim) + self.gn = nn.GroupNorm(32 // num_channels_per_group, 32, affine=not USE_BIAS) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(32, 3, kernel_size=3, padding=1, bias=USE_BIAS) + + self.upsample = nn.Upsample(scale_factor=(2, 2, 2), mode="nearest") + self.upsample2 = nn.Upsample(scale_factor=(1, 2, 2), mode="nearest") + + self.extend_layer = nn.Linear(input_dim, app_fea_size[0] * app_fea_size[1] ** 2, bias=USE_BIAS) + self.warp_layer = nn.Conv2d(in_channels=app_fea_size[0], out_channels=app_fea_size[0], kernel_size=(1, 1), bias=USE_BIAS) + + d, h, w = 16, 64, 64 + grids = torch.meshgrid( + torch.linspace(-1, 1, d), + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + self.identity_grid = torch.stack(grids, dim=-1).flip(-1) + + if self.flag_estimate_occlusion_map: + self.occlusion = nn.Conv2d(reshape_channel*d, 1, kernel_size=7, padding=3) + + ### generator + # Projection layers + self.projection = nn.Sequential( + ReshapeTo2DLayer(), + SameBlock2d(reshape_channel * d, 256, kernel_size=(3, 3), padding=(1, 1), lrelu=True), + nn.Conv2d(256, 256, kernel_size=1, stride=1) + ) + + input_channels = 256 + norm_G = 'spadespectralinstance' + label_num_channels = input_channels # 256 + out_channels = 64 + + 
self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels) + self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels) + self.up = nn.Upsample(scale_factor=2) + + self.conv_img = nn.Sequential( + nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1), + nn.PixelShuffle(upscale_factor=2) + ) + self.final_activation = nn.Tanh() + + def forward(self, tgt_latent, ref_feats): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[tgt_latent, ref_feats], + use_reentrant=False) + else: + return self.manual_forward(*[tgt_latent, ref_feats]) + + def manual_forward(self, tgt_latent, ref_feats): + bs = tgt_latent.size(0) + + # generate warping field + z_emb = self.extend_layer(tgt_latent).view(tgt_latent.size(0), -1, 4, 4) + + batch_size, c, h, w = z_emb.shape + condition = z_emb.view(-1, c, h * w).clone() + + z = self.conv1(z_emb) + z = self.reshap3d(z) + + z = self.upsample(z) + z = self.resblock1(z, condition) + + z = self.upsample(z) + z = self.resblock2(z, condition) + + z = self.upsample2(z) + z = self.resblock3(z, condition) + + z = self.upsample2(z) + z = self.resblock4(z, condition) + + z = self.gn(z) + z = self.relu(z) + z = self.conv2(z) + deltas = F.tanh(z).permute(0, 2, 3, 4, 1) + + warping_field = self.identity_grid[None].to(tgt_latent.device) + deltas + warping_feature_volume = F.grid_sample(ref_feats, warping_field, mode="bilinear", align_corners=False) + + # decoding + seg = self.projection(warping_feature_volume) # Bx256x64x64 + + if self.flag_estimate_occlusion_map: + bs, _, d, h, w = warping_feature_volume.shape + warping_feature_volume_reshape = warping_feature_volume.view(bs, -1, h, w) + occlusion_map = torch.sigmoid(self.occlusion(warping_feature_volume_reshape)) # Bx1x64x64 + seg = seg * occlusion_map + + + x = self.fc(seg) # Bx512x64x64 + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + + x = self.up(x) # Bx512x64x64 -> Bx512x128x128 + x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 + x = self.up(x) # Bx256x128x128 -> Bx256x256x256 + x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW + x = self.final_activation(x) # Bx3xHxW + + return x diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/flow_estimator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/flow_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..f5839a2736df875fd76067eb0bde393099fba004 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/flow_estimator.py @@ -0,0 +1,49 @@ +import math +import 
torch +from torch import nn +from torch.nn import functional as F + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.motion_encoder import EqualLinear + +class Direction(nn.Module): + def __init__(self, latent_dim, num_direction): + super(Direction, self).__init__() + + self.weight = nn.Parameter(torch.randn(latent_dim, num_direction)) + + def forward(self, input): + weight = self.weight + 1e-8 + Q, R = torch.linalg.qr(weight) # orthonormal motion directions [n1, n2, ...] via QR decomposition + + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class FlowEstimator(nn.Module): + def __init__(self, latent_dim, motion_space=20): + super(FlowEstimator, self).__init__() + + fc = [EqualLinear(latent_dim, latent_dim)] + for i in range(3): + fc.append(EqualLinear(latent_dim, latent_dim)) + fc.append(EqualLinear(latent_dim, motion_space)) + self.fc = nn.Sequential(*fc) + + self.direction = Direction(latent_dim, motion_space) + + def forward(self, ref_fea, tgt_fea): + feats = self.fc(tgt_fea.view(tgt_fea.size(0), -1)) + + ref2tgt_mapping = self.direction(feats) + tgt_latent = ref_fea + ref2tgt_mapping # reference latent code -> target latent code + + return tgt_latent \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/motion_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2a11650049324b831334aa5b6d63cad9513c586a --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/motion_encoder.py @@ -0,0 +1,61 @@ +import torch +from torch import nn +from torch.nn import functional as F +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.LIA.modules import * + +class MotionEncoder(nn.Module): + def __init__(self, latent_dim, size=512): + super(MotionEncoder, self).__init__() + + self.size = size + + if self.size == 256: + channel = [64, 128, 256, 512, 512, 512, 512] + + elif self.size == 512: + channel = [32, 64, 128, 256, 512, 512, 512, 512] + + else: + raise ValueError(f"Unsupported size: {self.size}, expected 256 or 512") + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channel[0], 1)) + + in_channel = channel[0] + for i in range(1, len(channel)): + out_channel = channel[i] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, latent_dim, 4, padding=0, bias=False)) + self.convs = nn.Sequential(*self.convs) + + def forward(self, x): + if self.training: + return torch.utils.checkpoint.checkpoint( \ + self.manual_forward, *[x], + use_reentrant=False) + else: + return self.manual_forward(*[x]) + + def manual_forward(self, x): + + if x.size(-1) != self.size: + x = F.interpolate(x, size=(self.size, self.size), mode='bilinear') + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + res = res[::-1] + feats = res[2:] # from 8x8 to 512x512 + latent_code = res[0] + # [B * T, D] + latent_code = latent_code.view(x.size(0), -1) + return latent_code, feats + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/util.py b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/util.py new file mode 100644 index
0000000000000000000000000000000000000000..3c2f174dd68a735fc73df3cd54774ca32e214f16 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/LIA_3d/util.py @@ -0,0 +1,477 @@ +# coding: utf-8 + +""" +This file defines various neural network modules and utility functions, including convolutional and residual blocks, +normalizations, and functions for spatial transformation and tensor manipulation. +""" + +from torch import nn +import torch.nn.functional as F +import torch +import torch.nn.utils.spectral_norm as spectral_norm +import math +import warnings +import collections.abc +from itertools import repeat + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp + + coordinate_grid = make_coordinate_grid(spatial_size, mean) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, ref, **kwargs): + d, h, w = spatial_size + x = torch.arange(w).type(ref.dtype).to(ref.device) + y = torch.arange(h).type(ref.dtype).to(ref.device) + z = torch.arange(d).type(ref.dtype).to(ref.device) + + # NOTE: must be right-down-in + x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right + y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom + z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ConvT2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1): + super(ConvT2d, self).__init__() + + self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride, + padding=padding, output_padding=output_padding) + self.norm = nn.InstanceNorm2d(out_features) + + def forward(self, x): + out = self.convT(x) + out = self.norm(out) + out = F.leaky_relu(out) + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.norm1 = nn.BatchNorm3d(in_features, affine=True) + self.norm2 = nn.BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. 
+ """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = nn.BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + + +class GRN(nn.Module): + """ GRN (Global Response Normalization) layer + """ + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def drop_path(x, drop_prob=0., training=False, scale_by_keep=True): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + +to_2tuple = _ntuple(2) + + +def compute_grid_points(feature_volume: torch.Tensor): + d, h, w = feature_volume.shape[-3:] + grids = torch.meshgrid( + torch.linspace(-1, 1, d), + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(feature_volume.device).flip(-1) + +def compute_2d_grid_points(h, w, device=torch.device("cpu")): + grids = torch.meshgrid( + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(device).flip(-1) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__init__.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..197599b2e60076a87ee32f74b2b72c1bc14768f8 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a646d7f324e0303361bbdba21cc016e75d39ce5 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/building_blocks.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/building_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b181416dc5c0c617470dff9ee4183c0536237569 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/building_blocks.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/building_blocks.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/__pycache__/building_blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a225f41cf924788fd58c2348f6732d09c5ac8fd 
diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/building_blocks.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/building_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e6422294bceed665823526822cc103924699888a --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/building_blocks.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.parametrizations import spectral_norm + + +USE_BIAS = False + + +# https://github.com/joe-siyuan-qiao/WeightStandardization?tab=readme-ov-file#pytorch +class WSConv2d(nn.Conv2d): + def __init__(self, *args, **kwargs): + super(WSConv2d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class WSConv3d(nn.Conv3d): + def __init__(self, *args, **kwargs): + super(WSConv3d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv3d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ResBlock2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, use_spectral_norm: bool = False): + super().__init__() + + norm_func = lambda x: x + if use_spectral_norm: + norm_func = spectral_norm + + if in_channels != out_channels: + self.skip_layer = norm_func(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(WSConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + nn.GroupNorm(out_channels // num_channels_per_group,
out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBlock3d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBasic(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + self.layers(inp)) + + +class ResBottleneck(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda x: x + + temp_out_channels = out_channels // 4 + self.layers = nn.Sequential( + WSConv2d(in_channels, temp_out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, temp_out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + 
self.layers(inp)) + + +class ReshapeTo3DLayer(nn.Module): + def __init__(self, out_depth: int): + super().__init__() + + self.out_depth = out_depth + + def forward(self, inp: torch.Tensor): + batch_size, channels, height, width = inp.shape + return inp.view(batch_size, channels // self.out_depth, self.out_depth, height, width) + + +class ReshapeTo2DLayer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inp: torch.Tensor): + batch_size, channels, depth, height, width = inp.shape + return inp.view(batch_size, channels * depth, height, width) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/discriminator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..f6670e50120249195665ad6734a35c6246524f5b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/discriminator.py @@ -0,0 +1,48 @@ +# Adapted from https://github.com/eriklindernoren/PyTorch-GAN/tree/master, which is MIT license + +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import spectral_norm + + + +class DiscriminatorBlock(nn.Module): + def __init__(self, in_filters, out_filters, normalization=True): + super(DiscriminatorBlock, self).__init__() + + self.layers = nn.Sequential() + self.layers.append( + spectral_norm(nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)) + ) + if normalization: + self.layers.append(nn.InstanceNorm2d(out_filters)) + self.layers.append(nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, inp): + return self.layers(inp) + + +class Discriminator(nn.Module): + def __init__(self, in_channels=3): + super(Discriminator, self).__init__() + + self.down_blocks = nn.ModuleList( + [ + DiscriminatorBlock(in_channels, 64, normalization=False), + DiscriminatorBlock(64, 128), + DiscriminatorBlock(128, 256), + DiscriminatorBlock(256, 512), + ] + ) + self.final_layer = nn.Sequential( + nn.ZeroPad2d((1, 0, 1, 0)), + spectral_norm(nn.Conv2d(512, 1, 4, padding=1, bias=False)) + ) + + def forward(self, inp): + feature_maps = [] + for block in self.down_blocks: + inp = block(inp) + feature_maps.append(inp) + + return feature_maps, self.final_layer(inp) diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a3b0876b1c8ca9597ed41fd5f1481afca3f3ae --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_encoder.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import sys +from pathlib import Path +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBlock2d, ResBlock3d, ReshapeTo3DLayer +from model.head_animation.VASA1.resnet50 import Resnet50 + +class FaceEncoder(nn.Module): + def __init__(self, latent_dim: int, normalize_output: bool): + super().__init__() + + num_channels_per_group = 8 + self.VolumetricFieldEncoder = nn.Sequential( + # 2D conv layers + nn.Conv2d(3, 64, kernel_size=7, padding=3, bias=USE_BIAS), + ResBlock2d(64, 128, num_channels_per_group), + nn.AvgPool2d(2, 2), + ResBlock2d(128, 256, num_channels_per_group), + nn.AvgPool2d(2, 2), + ResBlock2d(256, 512, num_channels_per_group), + nn.AvgPool2d(2, 2), + # Prepare for reshaping + nn.GroupNorm(512 // num_channels_per_group, 512, affine=not USE_BIAS), 
+ nn.ReLU(inplace=True), + nn.Conv2d(512, 96 * 16, kernel_size=1, bias=USE_BIAS), + # Reshape 2D tensor as a 3D tensor with depth 16. + ReshapeTo3DLayer(16), + # 3D conv layers + ResBlock3d(96, 96, num_channels_per_group), + ResBlock3d(96, 96, num_channels_per_group), + ResBlock3d(96, 96, num_channels_per_group), + ) + + self.global_descriptor_encoder = Resnet50(input_dim=3, output_dim=latent_dim, normalize_output=normalize_output) + + def forward(self, inp: torch.Tensor): + feature_volume = self.VolumetricFieldEncoder(inp) + global_descriptor = self.global_descriptor_encoder(inp) + return [feature_volume, global_descriptor] diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9154a2d4523d19a28c66a1971504b28f6be1a4 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/face_generator.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import spectral_norm + +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBlock2d, ReshapeTo2DLayer + +class Generator(nn.Module): + def __init__(self): + super().__init__() + + num_channels_per_group = 8 + self.layers = nn.Sequential( + # Projection layers + ReshapeTo2DLayer(), + spectral_norm(nn.Conv2d(96 * 16, 512, kernel_size=1, bias=USE_BIAS)), + # Residual blocks + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + # Upsample layers + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(512, 256, num_channels_per_group, use_spectral_norm=True), + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(256, 128, num_channels_per_group, use_spectral_norm=True), + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(128, 64, num_channels_per_group, use_spectral_norm=True), + # Final layers + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + spectral_norm(nn.Conv2d(64, 3, kernel_size=3, padding=1, bias=USE_BIAS)), + ) + + def forward(self, inp: torch.Tensor): + return self.layers(inp) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/flow_estimator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/flow_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..520c575048092f6fbcce5df0dd7bdff3dd3a33b8 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/flow_estimator.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.head_animation.VASA1.nonrigid_pose_encoder import NonrigidPoseEncoder, compute_warping_grid +from model.head_animation.VASA1.util import compute_grid_points + + +class FlowEstimator(nn.Module): + def __init__(self, latent_dim): + super().__init__() + self.latent_dim = 
latent_dim + self.nonrigid_pose_encoder = NonrigidPoseEncoder(input_dim=latent_dim * 2) + + def generate_warping_feature_volume(self, feature_volume, global_descriptor, expression_code, rigid_pose, inverse): + + nonrigid_pose = self.nonrigid_pose_encoder(torch.cat((expression_code, global_descriptor), dim=1)) + + # We warp the source volume to construct the canonical feature volume. + identity_grid = compute_grid_points(feature_volume) + + warping_grid = compute_warping_grid( + rotation=rigid_pose["rotation"], + translation=rigid_pose["translation"], + nonrigid_pose=nonrigid_pose, + identity_grid=identity_grid, + inverse=inverse, + ) + + warping_feature_volume = F.grid_sample( + feature_volume, + warping_grid, + mode="bilinear", + align_corners=False, + ) + + #### TODO: Add volume_refiner ? + # if self.use_canonical_volume_refiner: + # canonical_feature_volume = self.canonical_volume_refiner(canonical_feature_volume) + + return warping_feature_volume + + def forward(self, src_dict, tgt_dict): + + global_descriptor = src_dict["global_descriptor"] + + ## generate_canonical_volume + canonical_feature_volume = self.generate_warping_feature_volume(src_dict["feature_volume"], global_descriptor, src_dict["expression_code"], src_dict["rigid_pose"], inverse=False) + + ## warp canonical_feature_volume + warped_driving_feature_volume = self.generate_warping_feature_volume(canonical_feature_volume, global_descriptor, tgt_dict["expression_code"], tgt_dict["rigid_pose"], inverse=True) + + return canonical_feature_volume, warped_driving_feature_volume \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/hopenet.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/hopenet.py new file mode 100644 index 0000000000000000000000000000000000000000..98301fad4a653cb13dd854c8ea961235acd00376 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/hopenet.py @@ -0,0 +1,165 @@ +# Copied from https://github.com/natanielruiz/deep-head-pose/blob/master/code/hopenet.py +# Apache license 2.0: https://github.com/natanielruiz/deep-head-pose/tree/master?tab=License-1-ov-file#readme + +import torch +import torch.nn as nn +import math +import cv2 + + +def headposeprediction_to_euler(yaw: torch.Tensor, pitch: torch.Tensor, roll: torch.Tensor): + # Conversion from bins to radians is adapted from: + # https://github.com/natanielruiz/deep-head-pose/blob/f7bbb9981c2953c2eca67748d6492a64c8243946/code/test_hopenet.py#L120 + # Apache 2.0 license. 
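+ # Hopenet discretizes each Euler angle into 66 bins of 3 degrees anchored at -99 degrees; + # the decoding below takes the expected value over the softmax bin distribution rather than the argmax.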
+ bin_indices = torch.arange(0, 66, device=yaw.device, dtype=yaw.dtype) + + # From bins to angles in [-99, 96] degrees + yaw = (yaw.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + pitch = (pitch.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + roll = (roll.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + + # Angles to radians + yaw = yaw * (torch.pi / 180) + pitch = pitch * (torch.pi / 180) + roll = roll * (torch.pi / 180) + + return yaw, pitch, roll + + +def euler_to_rotation_matrix(yaw: torch.Tensor, pitch: torch.Tensor, roll: torch.Tensor): + # Yaw is around y-axis + cos_yaw = torch.cos(yaw) + sin_yaw = torch.sin(yaw) + zeros = torch.zeros_like(yaw) + ones = torch.ones_like(yaw) + roty = torch.stack( + [ + torch.cat([cos_yaw, zeros, sin_yaw], dim=-1), + torch.cat([zeros, ones, zeros], dim=-1), + torch.cat([-sin_yaw, zeros, cos_yaw], dim=-1), + ], + dim=1, + ) + + # Pitch is around x-axis + cos_pitch = torch.cos(pitch) + sin_pitch = torch.sin(pitch) + rotx = torch.stack( + [ + torch.cat([ones, zeros, zeros], dim=-1), + torch.cat([zeros, cos_pitch, -sin_pitch], dim=-1), + torch.cat([zeros, sin_pitch, cos_pitch], dim=-1), + ], + dim=1, + ) + + # Roll is around z-axis + cos_roll = torch.cos(roll) + sin_roll = torch.sin(roll) + rotz = torch.stack( + [ + torch.cat([cos_roll, -sin_roll, zeros], dim=-1), + torch.cat([sin_roll, cos_roll, zeros], dim=-1), + torch.cat([zeros, zeros, ones], dim=-1), + ], + dim=1, + ) + + return rotx @ roty @ rotz + + +def draw_axis(img, rotmat, tdx=None, tdy=None, size=100): + if tdx is None or tdy is None: + height, width = img.shape[:2] + tdx = width / 2 + tdy = height / 2 + + # X-Axis + x1 = size * rotmat[0, 0] + tdx + y1 = size * rotmat[1, 0] + tdy + + # Y-Axis + x2 = size * rotmat[0, 1] + tdx + y2 = size * rotmat[1, 1] + tdy + + # Z-Axis + x3 = size * rotmat[0, 2] + tdx + y3 = size * rotmat[1, 2] + tdy + + cv2.line(img, (int(tdx), int(tdy)), (int(x1),int(y1)), (1,0,0), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x2),int(y2)), (0,1,0), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x3),int(y3)), (0,0,1), 3) + + return img + + +class Hopenet(nn.Module): + # Hopenet with 3 output layers for yaw, pitch and roll + # Predicts Euler angles by binning and regression with the expected value + def __init__(self, block, layers, num_bins): + self.inplanes = 64 + super(Hopenet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) + self.fc_pitch = nn.Linear(512 * block.expansion, num_bins) + self.fc_roll = nn.Linear(512 * block.expansion, num_bins) + + # Alex: remove unused layer, necessary for DDP + # Vestigial layer from previous experiments + # self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + pre_yaw = self.fc_yaw(x) + pre_pitch = self.fc_pitch(x) + pre_roll = self.fc_roll(x) + + return pre_yaw, pre_pitch, pre_roll \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/loss.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9da7c43bee5fcf306b66dddb434b9d8443deced4 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/loss.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import lpips + + +def compute_multiscale_vgg_loss(pred: torch.Tensor, gt: torch.Tensor, vgg_net: torch.nn.Module): + loss = 0 + downscale_factors = [1, 2, 4, 8] + for downscale_factor in downscale_factors: + if downscale_factor == 1: + level_pred = pred + level_gt = gt + else: + level_pred = torch.nn.functional.interpolate(pred, scale_factor=1 / downscale_factor, mode="bilinear") + level_gt = torch.nn.functional.interpolate(gt, scale_factor=1 / downscale_factor, mode="bilinear") + + loss = loss + vgg_net(level_pred, level_gt, normalize=True) + + return loss / len(downscale_factors) + + +def crop_and_resize(images: torch.Tensor, bboxes: torch.Tensor, size: int): + batch_size = images.shape[0] + + output_images = [] + for i in range(batch_size): + bbox = bboxes[i] + output_images.append( + F.interpolate( + images[i:i+1, :, bbox[0, 1]:bbox[1, 1], bbox[0, 0]:bbox[1, 0]], + size=(size, size), + mode="area", + ) + ) + + return torch.cat(output_images, dim=0) + + +def compute_face_embedding_loss(face_bboxes: torch.Tensor, pred: torch.Tensor, gt: torch.Tensor, face_net: torch.nn.Module): + # https://github.com/timesler/facenet-pytorch?tab=readme-ov-file#pretrained-models + + # TODO: If this proves to be useful, we can precompute these embeddings. 
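+ # Crops are resized to 160x160, the input resolution of the pretrained facenet-pytorch models, + # and rescaled from [0, 1] to the [-1, 1] range the embedding network expects.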
+ with torch.no_grad(): + gt_embedding = face_net(crop_and_resize(gt, face_bboxes, 160) * 2 - 1) + + return torch.abs(face_net(crop_and_resize(pred, face_bboxes, 160) * 2 - 1) - gt_embedding) + + +class VGGLoss(nn.Module): + def __init__(self): + super(VGGLoss, self).__init__() + + self.vgg_net = lpips.LPIPS(net="vgg").eval() + for param in self.vgg_net.parameters(): + param.requires_grad = False + + def forward(self, img_recon, img_real, facial_mask=None): + if img_real.min() < 0: + img_recon = (img_recon + 1) / 2 + img_real = (img_real + 1) / 2 + + vgg_loss = compute_multiscale_vgg_loss(img_recon, img_real, self.vgg_net) + + return vgg_loss, None \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/motion_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d6531eeccb9063d1c15fbb72cdec253ab3cbca --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/motion_encoder.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +from pathlib import Path +import torchvision +from model.head_animation.VASA1.resnet18 import Resnet18 +from model.head_animation.VASA1.rigid_pose_encoder import RigidPoseEncoder +from model.head_animation.VASA1.hopenet import Hopenet, euler_to_rotation_matrix, headposeprediction_to_euler + +class MotionEncoder(nn.Module): + def __init__(self, latent_dim: int, size: int, normalize_output: bool, hopenet_checkpoint_path: str, use_gt_rotation: bool): + super().__init__() + + self.expression_encoder = Resnet18(input_dim=3, output_dim=latent_dim, normalize_output=normalize_output) + + # Hopenet + self.hopenet = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) + self.hopenet.load_state_dict( + torch.load(Path(hopenet_checkpoint_path)), + strict=False, # Don't need vestigial layer + ) + self.hopenet.eval() + for param in self.hopenet.parameters(): + param.requires_grad = False + + self.hopenet_test_transformations = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(224), # only support 224 + torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + ) + def gt_rotation_callback(inp): + with torch.no_grad(): + # Important: need [0, 1] + if inp.min() <= -0.5: + inp = (inp + 1) / 2 # need [0, 1] + return euler_to_rotation_matrix(*headposeprediction_to_euler(*self.hopenet(self.hopenet_test_transformations(inp)))) + + self.rigid_pose_encoder = RigidPoseEncoder(use_gt_rotation=use_gt_rotation, gt_rotation_callback=gt_rotation_callback) + + def forward(self, inp: torch.Tensor): + bs = inp.size(0) + expression_code = self.expression_encoder(inp) # [batch_size, latent_dim] + rigid_pose = self.rigid_pose_encoder(inp) # rotation [batch_size, 3, 3] and translation [batch_size, 3] + + # so motion latent is 128 + 9 + 3 = 140? 
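+ # expression_code contributes latent_dim values, the flattened 3x3 rotation 9, and the translation 3, + # i.e. latent_dim + 12 dimensions (140 when latent_dim=128).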
+ motion_latent = torch.cat((expression_code, rigid_pose["rotation"].view(bs, -1), rigid_pose["translation"]), dim=1) + + return motion_latent, rigid_pose \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/nonrigid_pose_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/nonrigid_pose_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7653c49961b33494a3a291101bc62404a1e121 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/nonrigid_pose_encoder.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBlock3d, ReshapeTo3DLayer + + +def compute_warping_grid( + rotation: torch.Tensor, # (bs, 3, 3) + translation: torch.Tensor, # (bs, 3) + nonrigid_pose: torch.Tensor, # (bs, 3, 16, 16, 16) + identity_grid: torch.Tensor, # (16, 32, 32, 3) for a 256x256 input image. + inverse: bool, +): + batch_size = rotation.shape[0] + depth, height, width = identity_grid.shape[:3] + + nonrigid_pose_sampled = F.grid_sample( + input=nonrigid_pose, + grid=identity_grid.unsqueeze(0).expand(batch_size, -1, -1, -1, -1), + mode="bilinear", + align_corners=False, + ).permute(0, 2, 3, 4, 1) # (bs, 16, 32, 32, 3) + + identity_grid = identity_grid.view(-1, 3).unsqueeze(0) + nonrigid_pose_sampled = nonrigid_pose_sampled.reshape(batch_size, -1, 3) + + # Map grid points in the driving frame to the source frame. + if inverse: + warping_grid = (rotation.transpose(1, 2).unsqueeze(1) @ (identity_grid - nonrigid_pose_sampled - translation.unsqueeze(1)).unsqueeze(-1)).squeeze(-1) + else: + warping_grid = (rotation.unsqueeze(1) @ identity_grid.unsqueeze(-1)).squeeze(-1) + translation.unsqueeze(1) + nonrigid_pose_sampled + + return warping_grid.view(batch_size, depth, height, width, 3) + + +class NonrigidPoseEncoder(nn.Module): + def __init__(self, input_dim: int): + super().__init__() + + num_channels_per_group = 8 + self.layers = nn.Sequential( + # 2D conv layers + nn.Conv2d(input_dim, 2048, kernel_size=1, bias=USE_BIAS), + ReshapeTo3DLayer(out_depth=4), + ResBlock3d(512, 256, num_channels_per_group), + nn.Upsample(scale_factor=(2, 2, 2), mode="nearest"), + ResBlock3d(256, 128, num_channels_per_group), + nn.Upsample(scale_factor=(2, 2, 2), mode="nearest"), + ResBlock3d(128, 64, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + ResBlock3d(64, 32, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + nn.GroupNorm(32 // num_channels_per_group, 32, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv3d(32, 3, kernel_size=3, padding=1, bias=USE_BIAS), + ) + + def forward(self, inp: torch.Tensor): + if len(inp.shape) == 2: + inp = inp.unsqueeze(-1).unsqueeze(-1) + + return F.tanh(self.layers(inp)) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/pipeline_unittest.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/pipeline_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..baad79cb6cdc9ec9beef591913ac4ed234fd5ba9 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/pipeline_unittest.py @@ -0,0 +1,93 @@ +import unittest +import torch +import torchvision +import os +from face_encoder import FaceEncoder +from motion_encoder import MotionEncoder +from flow_estimator import 
FlowEstimator +from face_generator import Generator + + +class TestPipeline(unittest.TestCase): + def setUp(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.img_size = 512 + self.latent_dim = 128 # this value is not reported in the MegaPortraits paper + self.use_gt_rotation = True + self.hopenet_checkpoint_path = "/mnt/weka/real_time_data/model/hopenet_robust_alpha1.pkl" + self.normalize_output = True + # self.resize_motion_encoder_input = False + + ## Eapp part (appearance encoder) + self.face_encoder = FaceEncoder(self.latent_dim, self.normalize_output).to(self.device) + + ## Emtn part (motion encoder) + self.motion_encoder = MotionEncoder(latent_dim=self.latent_dim, size=self.img_size, normalize_output=self.normalize_output, use_gt_rotation=self.use_gt_rotation, hopenet_checkpoint_path=self.hopenet_checkpoint_path).to(self.device) + + # flow estimator + self.flow_estimator = FlowEstimator(latent_dim=self.latent_dim).to(self.device) + + # face generator + self.generator = Generator().to(self.device) + + def test_face_encoder(self): + """Test if the output shape is correct for a given input shape""" + batch_size = 2 + channels = 3 + + # Create a random input tensor + input_tensor = torch.randn(batch_size, channels, self.img_size, self.img_size).to(self.device) + + # Get output from the model + with torch.no_grad(): + [feature_volume, global_descriptor] = self.face_encoder(input_tensor) + print('Feature_volume shape:', feature_volume.shape) + print('Global_descriptor shape:', global_descriptor.shape) + + def test_motion_encoder(self): + """Test if the output shape is correct for a given input shape""" + batch_size = 2 + channels = 3 + + # Create a random input tensor + input_tensor = torch.randn(batch_size, channels, self.img_size, self.img_size).to(self.device) + + # Get output from the model + with torch.no_grad(): + motion_code, rigid_pose = self.motion_encoder(input_tensor) + + print('Motion code shape:', motion_code.shape) + for k, v in rigid_pose.items(): + print(f'Rigid pose output: {k}', v.shape) + + def test_flow_estimator(self): + """Test if the output shape is correct for a given input shape""" + batch_size = 2 + src_dict = {} + src_dict["feature_volume"] = torch.randn(batch_size, 96, 16, 64, 64).to(self.device) + src_dict["global_descriptor"] = torch.randn(batch_size, self.latent_dim).to(self.device) + src_dict["expression_code"] = torch.randn(batch_size, self.latent_dim).to(self.device) + src_dict["rigid_pose"] = {"rotation": torch.randn(batch_size, 3, 3).to(self.device), "translation": torch.randn(batch_size, 3).to(self.device)} + + tgt_dict = {} + tgt_dict["expression_code"] = torch.randn(batch_size, self.latent_dim).to(self.device) + tgt_dict["rigid_pose"] = {"rotation": torch.randn(batch_size, 3, 3).to(self.device), "translation": torch.randn(batch_size, 3).to(self.device)} + + canonical_feature_volume, warped_driving_feature_volume = self.flow_estimator(src_dict, tgt_dict) + + print('Canonical feature volume shape:', canonical_feature_volume.shape) + print('Warped driving feature volume shape:', warped_driving_feature_volume.shape) + + def test_generator(self): + """ + Test if the output shape is correct for a given input shape + """ + batch_size = 2 + warped_driving_feature_volume = torch.randn(batch_size, 96, 16, 64, 64).to(self.device) + predicted_image = self.generator(warped_driving_feature_volume) + print('Predicted image shape:', predicted_image.shape) + + assert predicted_image.shape == (batch_size, 3, self.img_size, self.img_size) + +if __name__ == '__main__': + 
unittest.main() \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet18.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet18.py new file mode 100644 index 0000000000000000000000000000000000000000..3de8fc823189d9600558fa1359327cd1f1865f6b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet18.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBasic + + +class Resnet18(nn.Module): + def __init__(self, input_dim: int, output_dim: int, normalize_output: bool = False): + super().__init__() + + num_channels_per_group = 8 + + # The following architecture follows Resnet-18. + self.layers = nn.Sequential( + # Initial layers + nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=USE_BIAS), + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + # Layer 1 + ResBasic(64, 64, 1, num_channels_per_group), + ResBasic(64, 64, 1, num_channels_per_group), + # Layer 2 + ResBasic(64, 128, 2, num_channels_per_group), + ResBasic(128, 128, 1, num_channels_per_group), + # Layer 3 + ResBasic(128, 256, 2, num_channels_per_group), + ResBasic(256, 256, 1, num_channels_per_group), + # Layer 4 + ResBasic(256, 512, 2, num_channels_per_group), + ResBasic(512, 512, 1, num_channels_per_group), + # Global average pooling. + nn.AdaptiveAvgPool2d((1, 1)), + # Flatten. + nn.Flatten(start_dim=1), + # Final layer + nn.Linear(512, output_dim), + ) + self.normalize_output = normalize_output + + def forward(self, inp: torch.Tensor): + out = self.layers(inp) + if self.normalize_output: + out = F.normalize(out, dim=1) + + return out \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet50.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..28d3640445d8d1aacecbe4921f69df6be831a5cd --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/resnet50.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBottleneck + + +class Resnet50(nn.Module): + def __init__(self, input_dim: int, output_dim: int, normalize_output: bool = False): + super().__init__() + + num_channels_per_group = 8 + + # The following architecture follows Resnet-50.
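+ # Each ResBottleneck squeezes to out_channels // 4 internally; GroupNorm plus weight-standardized convs stand in for the usual BatchNorm (see building_blocks.py).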
+ self.layers = nn.Sequential( + # Initial layers + nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=USE_BIAS), + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + # Layer 1 + ResBottleneck(64, 256, 1, num_channels_per_group), + ResBottleneck(256, 256, 1, num_channels_per_group), + ResBottleneck(256, 256, 1, num_channels_per_group), + # Layer 2 + ResBottleneck(256, 512, 2, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + # Layer 3 + ResBottleneck(512, 1024, 2, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + # Layer 4 + ResBottleneck(1024, 2048, 2, num_channels_per_group), + ResBottleneck(2048, 2048, 1, num_channels_per_group), + ResBottleneck(2048, 2048, 1, num_channels_per_group), + # Global average pooling. + nn.AdaptiveAvgPool2d((1, 1)), + # Flatten. + nn.Flatten(start_dim=1), + # Final layer + nn.Linear(2048, output_dim), + ) + self.normalize_output = normalize_output + + def forward(self, inp: torch.Tensor): + out = self.layers(inp) + if self.normalize_output: + out = F.normalize(out, dim=1) + + return out \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/rigid_pose_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/rigid_pose_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..da3f5bcfbc6f32758a1939f35026eb06ec9308dc --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/rigid_pose_encoder.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.head_animation.VASA1.resnet18 import Resnet18 +from model.head_animation.VASA1.hopenet import euler_to_rotation_matrix, headposeprediction_to_euler + + +def rotation_6d_to_matrix(rotation_6d: torch.Tensor): + # Via Gram-Schmidt orthogonalization. + a1, a2 = rotation_6d[..., 0:3], rotation_6d[..., 3:6] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + + return torch.stack((b1, b2, b3), dim=-2) + + +class RigidPoseEncoder(nn.Module): + def __init__(self, use_gt_rotation, gt_rotation_callback): + super().__init__() + + self.use_gt_rotation = use_gt_rotation + if use_gt_rotation: + self.gt_rotation_callback = gt_rotation_callback + self.net = Resnet18(input_dim=3, output_dim=3) # 3D translation. + else: + self.net = Resnet18(input_dim=3, output_dim=6 + 3) # 6D rotation and 3D translation. 
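+ # The 6D parameterization consumed by rotation_6d_to_matrix above is the continuous rotation representation of Zhou et al. (CVPR 2019); Gram-Schmidt guarantees an orthonormal output. Sanity sketch (illustrative, not part of the model): + # R = rotation_6d_to_matrix(torch.randn(2, 6)) + # assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(2, -1, -1), atol=1e-5)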
+ + def forward(self, inp: torch.Tensor): + out = self.net(inp) + + if not self.use_gt_rotation: + rot_out = rotation_6d_to_matrix(out[:, 0:6]) + trans_out = F.tanh(out[:, 6:9]) + else: + rot_out = self.gt_rotation_callback(inp) + trans_out = F.tanh(out[:, 0:3]) + + return { + "rotation": rot_out, + "translation": trans_out, + } \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/util.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b8278564078dd61d58dffe8cdb3c7d6d8206fda8 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA1/util.py @@ -0,0 +1,25 @@ +import torch + +def compute_grid_points(feature_volume: torch.Tensor): + d, h, w = feature_volume.shape[-3:] + grids = torch.meshgrid( + torch.linspace(-1, 1, d), + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(feature_volume.device).flip(-1) + +def compute_2d_grid_points(h, w, device=torch.device("cpu")): + grids = torch.meshgrid( + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(device).flip(-1) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/building_blocks.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/building_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e6422294bceed665823526822cc103924699888a --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/building_blocks.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.parametrizations import spectral_norm + + +USE_BIAS = False + + +# https://github.com/joe-siyuan-qiao/WeightStandardization?tab=readme-ov-file#pytorch +class WSConv2d(nn.Conv2d): + def __init__(self, *args, **kwargs): + super(WSConv2d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class WSConv3d(nn.Conv3d): + def __init__(self, *args, **kwargs): + super(WSConv3d, self).__init__(*args, **kwargs) + + def forward(self, inp): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv3d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ResBlock2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, use_spectral_norm: bool = False): + super().__init__() + + norm_func = lambda x: x + if 
use_spectral_norm: + norm_func = spectral_norm + + if in_channels != out_channels: + self.skip_layer = norm_func(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(WSConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + norm_func(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS)), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBlock3d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // num_channels_per_group, in_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS), + ) + + def forward(self, inp: torch.Tensor): + return self.skip_layer(inp) + self.layers(inp) + + +class ResBasic(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda x: x + + self.layers = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + self.layers(inp)) + + +class ResBottleneck(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int, num_channels_per_group: int): + super().__init__() + + if stride != 1 and stride != 2: + raise NotImplementedError(f"Stride can be only 1 or 2 but '{stride}' is passed.") + + if in_channels != out_channels or stride != 1: + self.skip_layer = nn.Sequential( + WSConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + else: + self.skip_layer = lambda x: x + + temp_out_channels = out_channels // 4 + self.layers = nn.Sequential( + WSConv2d(in_channels, temp_out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), 
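+ # Remaining bottleneck stages: the 3x3 conv carries the stride, then a final 1x1 conv expands back to out_channels.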
+ nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, temp_out_channels, kernel_size=3, stride=stride, padding=1, bias=USE_BIAS), + nn.GroupNorm(temp_out_channels // num_channels_per_group, temp_out_channels, affine=not USE_BIAS), + nn.ReLU(inplace=True), + WSConv2d(temp_out_channels, out_channels, kernel_size=1, bias=USE_BIAS), + nn.GroupNorm(out_channels // num_channels_per_group, out_channels, affine=not USE_BIAS), + ) + + + def forward(self, inp: torch.Tensor): + return F.relu(self.skip_layer(inp) + self.layers(inp)) + + +class ReshapeTo3DLayer(nn.Module): + def __init__(self, out_depth: int): + super().__init__() + + self.out_depth = out_depth + + def forward(self, inp: torch.Tensor): + batch_size, channels, height, width = inp.shape + return inp.view(batch_size, channels // self.out_depth, self.out_depth, height, width) + + +class ReshapeTo2DLayer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inp: torch.Tensor): + batch_size, channels, depth, height, width = inp.shape + return inp.view(batch_size, channels * depth, height, width) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/canoncial_volume_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/canoncial_volume_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3be23f3372f55e6222b482f590997bdfdb6a1f9e --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/canoncial_volume_generator.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +from model.head_animation.VASA3.building_blocks import USE_BIAS, ResBlock3d + +class CanonicalVolumeGenerator(nn.Module): + def __init__(self): + super().__init__() + + num_channels_per_group = 32 + self.downblock1 = nn.Sequential( + ResBlock3d(96, 192, num_channels_per_group), + nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)), + ) + self.skip1 = nn.Sequential( + ResBlock3d(192, 192, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + ) + self.downblock2 = nn.Sequential( + ResBlock3d(192, 384, num_channels_per_group), + nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)), + ) + self.bottleneck1 = ResBlock3d(384, 512, num_channels_per_group) + self.bottleneck2 = ResBlock3d(512, 512, num_channels_per_group) + self.upblock1 = nn.Sequential( + ResBlock3d(512, 384, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + ) + self.skip2 = nn.Sequential( + ResBlock3d(384, 384, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + ) + + self.upblock2 = nn.Sequential( + ResBlock3d(384, 192, num_channels_per_group), + nn.Upsample(scale_factor=(1, 2, 2), mode="nearest"), + ) + + self.final_layer = nn.Sequential( + ResBlock3d(192, 96, num_channels_per_group), + nn.GroupNorm(96 // num_channels_per_group, 96, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv3d(96, 96, kernel_size=3, padding=1, bias=USE_BIAS), + ) + + def forward(self, inp: torch.Tensor): + down_out1 = self.downblock1(inp) + down_out2 = self.downblock2(down_out1) + + bottleneck_out1 = self.bottleneck1(down_out2) + bottleneck_out2 = self.bottleneck2(bottleneck_out1) + + up_out1 = self.upblock1(bottleneck_out2 + bottleneck_out1) + up_out2 = self.upblock2(up_out1 + self.skip2(down_out2)) + + + out = self.final_layer(up_out2 + self.skip1(down_out1)) + + # print(down_out2.shape, up_out1.shape) + # print(down_out1.shape, up_out2.shape) + + + return out \ No newline at 
end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/discriminator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..f6670e50120249195665ad6734a35c6246524f5b --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/discriminator.py @@ -0,0 +1,48 @@ +# Adapted from https://github.com/eriklindernoren/PyTorch-GAN/tree/master, which is MIT license + +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import spectral_norm + + + +class DiscriminatorBlock(nn.Module): + def __init__(self, in_filters, out_filters, normalization=True): + super(DiscriminatorBlock, self).__init__() + + self.layers = nn.Sequential() + self.layers.append( + spectral_norm(nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)) + ) + if normalization: + self.layers.append(nn.InstanceNorm2d(out_filters)) + self.layers.append(nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, inp): + return self.layers(inp) + + +class Discriminator(nn.Module): + def __init__(self, in_channels=3): + super(Discriminator, self).__init__() + + self.down_blocks = nn.ModuleList( + [ + DiscriminatorBlock(in_channels, 64, normalization=False), + DiscriminatorBlock(64, 128), + DiscriminatorBlock(128, 256), + DiscriminatorBlock(256, 512), + ] + ) + self.final_layer = nn.Sequential( + nn.ZeroPad2d((1, 0, 1, 0)), + spectral_norm(nn.Conv2d(512, 1, 4, padding=1, bias=False)) + ) + + def forward(self, inp): + feature_maps = [] + for block in self.down_blocks: + inp = block(inp) + feature_maps.append(inp) + + return feature_maps, self.final_layer(inp) diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9b05b318abd0d0d6e69bbab042e477607eef80 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_encoder.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import sys +from pathlib import Path +import torch.nn.functional as F +from model.head_animation.VASA3.building_blocks import USE_BIAS, ResBlock2d, ResBlock3d, ReshapeTo3DLayer +from model.head_animation.VASA3.resnet50 import Resnet50 + +class FaceEncoder(nn.Module): + def __init__(self, latent_dim: int, resize: bool, freeze: bool): + super().__init__() + + depth, channel = 16, 96 + self.resize = resize + + num_channels_per_group = 32 + self.VolumetricFieldEncoder = nn.Sequential( + # 2D conv layers + nn.Conv2d(3, 64, kernel_size=7, padding=3, bias=USE_BIAS), + ResBlock2d(64, 128, num_channels_per_group), + nn.AvgPool2d(2, 2), + ResBlock2d(128, 256, num_channels_per_group), + nn.AvgPool2d(2, 2), + ResBlock2d(256, 512, num_channels_per_group), + nn.AvgPool2d(2, 2), + # Prepare for reshaping + nn.GroupNorm(512 // num_channels_per_group, 512, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.Conv2d(512, channel * depth, kernel_size=1, bias=USE_BIAS), + # Reshape 2D tensor as a 3D tensor with depth 16. 
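+ # ReshapeTo3DLayer views (B, channel * depth, H, W) as (B, channel, depth, H, W), so the 1x1 conv above sets up a 96-channel, depth-16 feature volume.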
+ ReshapeTo3DLayer(depth), + # 3D conv layers + ResBlock3d(channel, channel, num_channels_per_group), + ResBlock3d(channel, channel, num_channels_per_group), + ResBlock3d(channel, channel, num_channels_per_group), + ) + + self.global_descriptor_encoder = Resnet50(input_dim=3, output_dim=latent_dim) + + def forward(self, inp: torch.Tensor): + feature_volume = self.VolumetricFieldEncoder(inp) + + if self.resize: + inp = F.interpolate(inp, size=(256, 256), mode='bilinear') + global_descriptor = self.global_descriptor_encoder(inp) + return [feature_volume, global_descriptor] diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_generator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6113d625adb1651387393b23c9e46ac7e8e89696 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/face_generator.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import spectral_norm + +from model.head_animation.VASA3.building_blocks import USE_BIAS, ResBlock2d, ReshapeTo2DLayer + +class Generator(nn.Module): + def __init__(self, freeze: bool): + super().__init__() + self.freeze = freeze + + num_channels_per_group = 32 + self.layers = nn.Sequential( + # Projection layers + ReshapeTo2DLayer(), + spectral_norm(nn.Conv2d(96 * 16, 512, kernel_size=1, bias=USE_BIAS)), + # Residual blocks + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + ResBlock2d(512, 512, num_channels_per_group, use_spectral_norm=True), + # Upsample layers + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(512, 256, num_channels_per_group, use_spectral_norm=True), + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(256, 128, num_channels_per_group, use_spectral_norm=True), + nn.Upsample(scale_factor=(2, 2), mode="nearest"), + ResBlock2d(128, 64, num_channels_per_group, use_spectral_norm=True), + # Final layers + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + spectral_norm(nn.Conv2d(64, 3, kernel_size=3, padding=1, bias=USE_BIAS)), + nn.Tanh() + ) + def forward(self, inp: torch.Tensor): + return self.layers(inp) + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/flow_estimator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/flow_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..33c94322963906230a592487ca4ac8f11ee52cb8 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/flow_estimator.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.head_animation.VASA3.nonrigid_pose_encoder import NonrigidPoseEncoder +from model.head_animation.VASA3.util import compute_grid_points +from model.head_animation.VASA3.canoncial_volume_generator import CanonicalVolumeGenerator + + +class FlowEstimator(nn.Module): + def __init__(self, latent_dim): + 
super().__init__() + self.latent_dim = latent_dim + self.src_nonrigid_pose_encoder = NonrigidPoseEncoder(input_dim=latent_dim) + self.tgt_nonrigid_pose_encoder = NonrigidPoseEncoder(input_dim=latent_dim) + self.canonical_volume_generator = CanonicalVolumeGenerator() + + def generate_canonical_feature_volume(self, feature_volume, global_descriptor, expression_code, rigid_pose, inverse, batch_idx): + + nonrigid_deformation = self.src_nonrigid_pose_encoder(expression_code, global_descriptor) + rigid_transformation = self.compute_rigid_transformation(feature_volume, rigid_pose["rotation"], rigid_pose["translation"], inverse=inverse) + + # We warp the source volume to construct the canonical feature volume. + exp_feature_volume = F.grid_sample(feature_volume.float(), rigid_transformation.float(), padding_mode='zeros') + warping_feature_volume = F.grid_sample(exp_feature_volume.float(), nonrigid_deformation.float(), padding_mode='zeros') + + canonical_feature_volume = self.canonical_volume_generator(warping_feature_volume) + + return canonical_feature_volume, rigid_transformation, nonrigid_deformation + + + def generate_warping_feature_volume(self, feature_volume, global_descriptor, expression_code, rigid_pose, inverse, batch_idx): + + nonrigid_deformation = self.tgt_nonrigid_pose_encoder(expression_code, global_descriptor) + rigid_transformation = self.compute_rigid_transformation(feature_volume, rigid_pose["rotation"], rigid_pose["translation"], inverse=inverse) + + # We warp the canonical volume to construct the driving feature volume. + exp_feature_volume = F.grid_sample(feature_volume.float(), nonrigid_deformation.float(), padding_mode='zeros') + warping_feature_volume = F.grid_sample(exp_feature_volume.float(), rigid_transformation.float(), padding_mode='zeros') + + # if batch_idx == 200: + # import pdb; pdb.set_trace() + + return warping_feature_volume, rigid_transformation, nonrigid_deformation + + + def compute_rigid_transformation(self, feature_volume, rotation, translation, inverse): + + identity_grid = compute_grid_points(feature_volume) + identity_grid = identity_grid.view(-1, 3).unsqueeze(0) + + if inverse: + rigid_transformation = (rotation.transpose(1, 2).unsqueeze(1) @ (identity_grid - translation.unsqueeze(1)).unsqueeze(-1)).squeeze(-1) + else: + rigid_transformation = (rotation.unsqueeze(1) @ (identity_grid).unsqueeze(-1)).squeeze(-1) + translation.unsqueeze(1) + + batch_size, _, depth, height, width = feature_volume.shape + rigid_transformation = rigid_transformation.view(batch_size, depth, height, width, 3) + + return rigid_transformation + + def forward(self, src_dict, tgt_dict, batch_idx=None): + + global_descriptor = src_dict["global_descriptor"] + + ## generate_canonical_volume + canonical_feature_volume, src_rigid_grid, src_delta_grid = self.generate_canonical_feature_volume(src_dict["feature_volume"], global_descriptor, src_dict["expression_code"], src_dict["rigid_pose"], inverse=True, batch_idx=batch_idx) + + ## warp canonical_feature_volume + warped_driving_feature_volume, tgt_rigid_grid, tgt_delta_grid = self.generate_warping_feature_volume(canonical_feature_volume, global_descriptor, tgt_dict["expression_code"], tgt_dict["rigid_pose"], inverse=False, batch_idx=batch_idx) + + return canonical_feature_volume, warped_driving_feature_volume, src_delta_grid, tgt_delta_grid \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/hopenet.py 
b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/hopenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e045aaef2797651671c1505a683abbfad34c3332 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/hopenet.py @@ -0,0 +1,167 @@ +# Copied from https://github.com/natanielruiz/deep-head-pose/blob/master/code/hopenet.py +# Apache license 2.0: https://github.com/natanielruiz/deep-head-pose/tree/master?tab=License-1-ov-file#readme + +import torch +import torch.nn as nn +import math +import cv2 + + +def headposeprediction_to_euler(yaw: torch.Tensor, pitch: torch.Tensor, roll: torch.Tensor): + # Conversion from bins to radians is adapted from: + # https://github.com/natanielruiz/deep-head-pose/blob/f7bbb9981c2953c2eca67748d6492a64c8243946/code/test_hopenet.py#L120 + # Apache 2.0 license. + bin_indices = torch.arange(0, 66, device=yaw.device, dtype=yaw.dtype) + + # From bins to angles between [-99, +95] + yaw = (yaw.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + pitch = (pitch.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + roll = (roll.softmax(1) * bin_indices).sum(1, keepdim=True) * 3 - 99 + + # Angles to radians + yaw = yaw * (torch.pi / 180) + pitch = pitch * (torch.pi / 180) + roll = roll * (torch.pi / 180) + + return yaw, pitch, roll + +def euler_to_rotation_matrix(yaw: torch.Tensor, pitch: torch.Tensor, roll: torch.Tensor): + # Yaw is around y-axis + cos_yaw = torch.cos(yaw) + sin_yaw = torch.sin(yaw) + zeros = torch.zeros_like(yaw) + ones = torch.ones_like(yaw) + roty = torch.stack( + [ + torch.cat([cos_yaw, zeros, sin_yaw], dim=-1), + torch.cat([zeros, ones, zeros], dim=-1), + torch.cat([-sin_yaw, zeros, cos_yaw], dim=-1), + ], + dim=1, + ) + + # Pitch is around x-axis + cos_pitch = torch.cos(pitch) + sin_pitch = torch.sin(pitch) + rotx = torch.stack( + [ + torch.cat([ones, zeros, zeros], dim=-1), + torch.cat([zeros, cos_pitch, -sin_pitch], dim=-1), + torch.cat([zeros, sin_pitch, cos_pitch], dim=-1), + ], + dim=1, + ) + + # Roll is around z-axis + cos_roll = torch.cos(roll) + sin_roll = torch.sin(roll) + rotz = torch.stack( + [ + torch.cat([cos_roll, -sin_roll, zeros], dim=-1), + torch.cat([sin_roll, cos_roll, zeros], dim=-1), + torch.cat([zeros, zeros, ones], dim=-1), + ], + dim=1, + ) + + return rotx @ roty @ rotz + + +def draw_axis(img, rotmat, tdx=None, tdy=None, size = 100): + if tdx != None and tdy != None: + tdx = tdx + tdy = tdy + else: + height, width = img.shape[:2] + tdx = width / 2 + tdy = height / 2 + + # X-Axis + x1 = size * rotmat[0, 0] + tdx + y1 = size * rotmat[1, 0] + tdy + + # Y-Axis + x2 = size * rotmat[0, 1] + tdx + y2 = size * rotmat[1, 1] + tdy + + # Z-Axis + x3 = size * rotmat[0, 2] + tdx + y3 = size * rotmat[1, 2] + tdy + + cv2.line(img, (int(tdx), int(tdy)), (int(x1),int(y1)), (1,0,0), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x2),int(y2)), (0,1,0), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x3),int(y3)), (0,0,1), 3) + + return img + + +class Hopenet(nn.Module): + # Hopenet with 3 output layers for yaw, pitch and roll + # Predicts Euler angles by binning and regression with the expected value + def __init__(self, block, layers, num_bins): + self.inplanes = 64 + super(Hopenet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, 
layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) + self.fc_pitch = nn.Linear(512 * block.expansion, num_bins) + self.fc_roll = nn.Linear(512 * block.expansion, num_bins) + + # Alex: remove unused layer, necessary for DDP + # Vestigial layer from previous experiments + # self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + pre_yaw = self.fc_yaw(x) + pre_pitch = self.fc_pitch(x) + pre_roll = self.fc_roll(x) + + return pre_yaw, pre_pitch, pre_roll + + + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/loss.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8132bc442656e1972e3e1b269a4b53665fefa5 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/loss.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import lpips + + +def compute_multiscale_vgg_loss(pred: torch.Tensor, gt: torch.Tensor, vgg_net: torch.nn.Module): + loss = 0 + downscale_factors = [1, 2, 4, 8] + for downscale_factor in downscale_factors: + if downscale_factor == 1: + level_pred = pred + level_gt = gt + else: + level_pred = torch.nn.functional.interpolate(pred, scale_factor=1 / downscale_factor, mode="bilinear") + level_gt = torch.nn.functional.interpolate(gt, scale_factor=1 / downscale_factor, mode="bilinear") + + loss = loss + vgg_net(level_pred, level_gt, normalize=True) + + return loss / len(downscale_factors) + + +def crop_and_resize(images: torch.Tensor, bboxes: torch.Tensor, size: int): + batch_size = images.shape[0] + + output_images = [] + for i in range(batch_size): + bbox = bboxes[i] + output_images.append( + F.interpolate( + images[i:i+1, :, bbox[0, 1]:bbox[1, 1], bbox[0, 0]:bbox[1, 0]], + size=(size, size), + mode="area", + ) + ) + + return torch.cat(output_images, dim=0) + + +def compute_face_embedding_loss(face_bboxes: torch.Tensor, pred: torch.Tensor, gt: torch.Tensor, face_net: torch.nn.Module): + # https://github.com/timesler/facenet-pytorch?tab=readme-ov-file#pretrained-models + + # TODO: If this proves to be useful, we can precompute these embeddings. 
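+ # Assumption: face_net is facenet-pytorch's InceptionResnetV1, which expects 160x160 crops scaled to roughly [-1, 1]; with pred/gt in [0, 1], the * 2 - 1 below performs that mapping.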
+ with torch.no_grad(): + gt_embedding = face_net(crop_and_resize(gt, face_bboxes, 160) * 2 - 1) + + return torch.abs(face_net(crop_and_resize(pred, face_bboxes, 160) * 2 - 1) - gt_embedding) + + +class VGGLoss(nn.Module): + def __init__(self): + super(VGGLoss, self).__init__() + + self.vgg_net = lpips.LPIPS(net="vgg").eval() + for param in self.vgg_net.parameters(): + param.requires_grad = False + + def forward(self, img_recon, img_real, facial_mask=None): + + vgg_loss = compute_multiscale_vgg_loss(img_recon, img_real, self.vgg_net) + + return vgg_loss, None \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/motion_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..729a6c9362814cb373ee9aebf0da0b33ca342034 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/motion_encoder.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +from pathlib import Path +import torchvision +from model.head_animation.VASA3.resnet18 import Resnet18 +from model.head_animation.VASA3.rigid_pose_encoder import RigidPoseEncoder +from model.head_animation.VASA3.hopenet import Hopenet, euler_to_rotation_matrix, headposeprediction_to_euler + +class MotionEncoder(nn.Module): + def __init__(self, latent_dim: int, size: int, resize: bool, hopenet_checkpoint_path: str, use_gt_rotation: bool): + super().__init__() + self.resize=resize + self.expression_encoder = Resnet18(input_dim=3, output_dim=latent_dim, add_fc=True) + + # Hopenet + self.hopenet = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) + self.hopenet.load_state_dict( + torch.load(Path(hopenet_checkpoint_path)), + strict=False, # Don't need vestigial layer + ) + self.hopenet.eval() + for param in self.hopenet.parameters(): + param.requires_grad = False + + self.hopenet_test_transformations = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(224), # only support 224 + torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + ) + def gt_rotation_callback(inp): + with torch.no_grad(): + # Important: need [0, 1] + if inp.min() <= -0.5: + inp_org = (inp + 1) / 2 # need [0, 1] + rot = euler_to_rotation_matrix(*headposeprediction_to_euler(*self.hopenet(self.hopenet_test_transformations(inp_org)))) + # img_org = inp.permute(0, 2,3, 1) + # img_org = torch.clamp(img_org, min=0.0, max=1.0) * 255 + # img_org = img_org.cpu().numpy().astype('uint8') + # cv2.imwrite(f'debug/debug_ref_callback.png', img_org[0][:,:,::-1]) + + return rot + + self.rigid_pose_encoder = RigidPoseEncoder(use_gt_rotation=use_gt_rotation, gt_rotation_callback=gt_rotation_callback) + + def forward(self, masked_inp, inp): + bs = inp.size(0) + + if self.resize: + masked_inp = F.interpolate(masked_inp, size=(256, 256), mode='bilinear') + + rigid_pose = self.rigid_pose_encoder(masked_inp, inp) # rotation [batch_size, 3, 3] and translation [batch_size, 3] + expression_code = self.expression_encoder(masked_inp) # [batch_size, latent_dim] + + # so motion latent is 128 + 9 + 3 = 140? 
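+ # As in the VASA1 encoder: latent_dim expression dims + 9 flattened rotation entries + 3 translation entries (128 + 9 + 3 = 140 with the default latent_dim).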
+ motion_latent = torch.cat((expression_code, rigid_pose["rotation"].view(bs, -1), rigid_pose["translation"]), dim=1) + + return motion_latent, rigid_pose \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/nonrigid_pose_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/nonrigid_pose_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..587b264c483bafe3cd7c6794f8b602d1269bac46 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/nonrigid_pose_encoder.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA3.building_blocks import USE_BIAS, ResBlock3d, ReshapeTo3DLayer, WSConv3d +import math + +class AdaptiveGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_features, eps=1e-5, affine=True): + super(AdaptiveGroupNorm, self).__init__(num_groups, num_features, eps, False) + self.num_features = num_features + + gen_max_channels, gen_embed_size = 512, 4 + self.u = nn.Parameter(torch.empty(num_features, gen_max_channels)) + self.v = nn.Parameter(torch.empty(gen_embed_size ** 2, 2)) + + nn.init.uniform_(self.u, a=-math.sqrt(3 / gen_max_channels), b=math.sqrt(3 / gen_max_channels)) + nn.init.uniform_(self.v, a=-math.sqrt(3 / gen_embed_size ** 2), b=math.sqrt(3 / gen_embed_size ** 2)) + + def forward(self, inputs, condition_emb): + outputs = super(AdaptiveGroupNorm, self).forward(inputs) + + param = self.u[None].matmul(condition_emb).matmul(self.v[None]) + ada_weight, ada_bias = param.split(1, dim=2) + + outputs = outputs * ada_weight[:, :, :, None, None] + ada_bias[:, :, :, None, None] + return outputs + + +class ResBlock3dStar(nn.Module): + def __init__(self, in_channels: int, out_channels: int, num_channels_per_group: int, condition_dim: int): + super().__init__() + + if in_channels != out_channels: + self.skip_layer = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + else: + self.skip_layer = lambda x: x + + self.agn1 = AdaptiveGroupNorm(in_channels // num_channels_per_group, in_channels) + self.conv1 = WSConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.agn2 = AdaptiveGroupNorm(out_channels // num_channels_per_group, out_channels) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=USE_BIAS) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, inp, condition): + x = self.relu(self.agn1(inp, condition)) + x = self.conv1(x) + x = self.relu(self.agn2(x, condition)) + x = self.conv2(x) + x = self.skip_layer(inp) + x + return x + + +class NonrigidPoseEncoder(nn.Module): + def __init__(self, input_dim: int): + super().__init__() + + num_channels_per_group = 32 + app_fea_size = (512, 4, 4) + + self.conv1 = nn.Conv2d(app_fea_size[0], 2048, kernel_size=1, bias=USE_BIAS) + self.reshap3d = ReshapeTo3DLayer(out_depth=4) + self.resblock1 = ResBlock3dStar(512, 256, num_channels_per_group, input_dim) + self.resblock2 = ResBlock3dStar(256, 128, num_channels_per_group, input_dim) + self.resblock3 = ResBlock3dStar(128, 64, num_channels_per_group, input_dim) + self.resblock4 = ResBlock3dStar(64, 32, num_channels_per_group, input_dim) + self.gn = nn.GroupNorm(32 // num_channels_per_group, 32, affine=not USE_BIAS) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(32, 3, kernel_size=3, padding=1, bias=USE_BIAS) + + self.upsample = nn.Upsample(scale_factor=(2, 2, 2), 
mode="nearest") + self.upsample2 = nn.Upsample(scale_factor=(1, 2, 2), mode="nearest") + + self.extend_layer = nn.Linear(input_dim, app_fea_size[0] * app_fea_size[1] ** 2, bias=USE_BIAS) + self.warp_layer = nn.Conv2d(in_channels=app_fea_size[0], out_channels=app_fea_size[0], kernel_size=(1, 1), bias=USE_BIAS) + + # Greate a meshgrid, which is used for warping calculation from deltas + volumn_size, volumn_depth = 64, 16 + grid_s = torch.linspace(-1, 1, volumn_size) + grid_z = torch.linspace(-1, 1, volumn_depth) + w, v, u = torch.meshgrid(grid_z, grid_s, grid_s) + self.register_buffer('identity_grid', torch.stack([u, v, w], 0)[None]) + + def forward(self, z: torch.Tensor, e_s: torch.Tensor): + + # ALign size to e_s + z_emb = self.extend_layer(z).view(z.size(0), -1, 4, 4) + warp_emb = self.warp_layer((z_emb + e_s) * 0.5) + + batch_size, c, h, w = e_s.shape + condition = warp_emb.view(-1, c, h * w).clone() + + z = self.conv1(warp_emb) + z = self.reshap3d(z) + + z = self.upsample(z) + z = self.resblock1(z, condition) + + z = self.upsample(z) + z = self.resblock2(z, condition) + + z = self.upsample2(z) + z = self.resblock3(z, condition) + + z = self.upsample2(z) + z = self.resblock4(z, condition) + + z = self.gn(z) + z = self.relu(z) + z = self.conv2(z) + deltas = F.tanh(z) + + warp = (self.identity_grid + deltas).permute(0, 2, 3, 4, 1) + + return warp + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/pipeline_unittest.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/pipeline_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..baad79cb6cdc9ec9beef591913ac4ed234fd5ba9 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/pipeline_unittest.py @@ -0,0 +1,93 @@ +import unittest +import torch +import torchvision +import os +from face_encoder import FaceEncoder +from motion_encoder import MotionEncoder +from flow_estimator import FlowEstimator +from face_generator import Generator + + +class TestPipeline(unittest.TestCase): + def setUp(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.img_size = 512 + self.latent_dim = 128 # this value is not reported in the MagaPorait paper + self.use_gt_rotation = True + self.hopenet_checkpoint_path = "/mnt/weka/real_time_data/model/hopenet_robust_alpha1.pkl" + self.normalize_output=True + # self.resize_motion_encoder_input = False + + ## Epp Part + self.face_encoder = FaceEncoder(self.latent_dim, self.normalize_output).to(self.device) + + ## Emnt part + self.motion_encoder = MotionEncoder(latent_dim=self.latent_dim, size=self.img_size, normalize_output=self.normalize_output, use_gt_rotation=self.use_gt_rotation, hopenet_checkpoint_path=self.hopenet_checkpoint_path).to(self.device) + + # flow estimator + self.flow_estimator = FlowEstimator(latent_dim=self.latent_dim).to(self.device) + + # face generator + self.generator = Generator().to(self.device) + + def test_face_encoder(self): + """Test if the output shape is correct for a given input shape""" + batch_size = 2 + channels = 3 + + # Create a random input tensor + input_tensor = torch.randn(batch_size, channels, self.img_size, self.img_size).to(self.device) + + # Get output from the model + with torch.no_grad(): + [feature_volume, global_descriptor]= self.face_encoder(input_tensor) + print('Feature_volume shape:', feature_volume.shape) + print('Global_descriptor shape:', global_descriptor.shape) + + def test_motion_encoder(self): + """Test if the output 
shape is correct for a given input shape""" + batch_size = 2 + channels = 3 + + # Create a random input tensor + input_tensor = torch.randn(batch_size, channels, self.img_size, self.img_size).to(self.device) + + # Get output from the model + with torch.no_grad(): + motion_code, rigid_pose = self.motion_encoder(input_tensor) + + print('Motion code shape:', motion_code.shape) + for k, v in rigid_pose.items(): + print(f'Rigid pose output: {k}', v.shape) + + def test_flow_estimator(self): + """Test if the output shape is correct for a given input shape""" + batch_size = 2 + src_dict = {} + src_dict["feature_volume"] = torch.randn(batch_size, 96, 16, 64, 64).to(self.device) + src_dict["global_descriptor"] = torch.randn(batch_size, self.latent_dim).to(self.device) + src_dict["expression_code"] = torch.randn(batch_size, self.latent_dim).to(self.device) + src_dict["rigid_pose"] = {"rotation": torch.randn(batch_size, 3, 3).to(self.device), "translation": torch.randn(batch_size, 3).to(self.device)} + + tgt_dict = {} + tgt_dict["expression_code"] = torch.randn(batch_size, self.latent_dim).to(self.device) + tgt_dict["rigid_pose"] = {"rotation": torch.randn(batch_size, 3, 3).to(self.device), "translation": torch.randn(batch_size, 3).to(self.device)} + + canonical_feature_volume, warped_driving_feature_volume = self.flow_estimator(src_dict, tgt_dict) + + print('Canonical feature volume shape:', canonical_feature_volume.shape) + print('Warped driving feature volume shape:', warped_driving_feature_volume.shape) + + def test_generator(self): + """ + Test if the output shape is correct for a given input shape + """ + batch_size = 2 + warped_driving_feature_volume = torch.randn(batch_size, 96, 16, 64, 64).to(self.device) + predicted_image = self.generator(warped_driving_feature_volume) + print('Predicted image shape:', predicted_image.shape) + + assert predicted_image.shape == (batch_size, 3, self.img_size, self.img_size) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet18.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet18.py new file mode 100644 index 0000000000000000000000000000000000000000..8076e8ca7e854d4ee60b5771b5586abb73ca6cd4 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet18.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBasic + + +class Resnet18(nn.Module): + def __init__(self, input_dim: int, output_dim: int, add_fc: bool, dropout: float = 0.2): + super().__init__() + + num_channels_per_group = 32 + + # The following architecture follows Resnet-18. 
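+ # Unlike the VASA1 Resnet18, the head is configurable: add_fc=True selects a 1x1 conv + dropout + 4x4-pooled linear head, otherwise a plain globally pooled linear head (see final_layers below).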
+ self.layers = nn.Sequential( + # Initial layers + nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=USE_BIAS), + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + # Layer 1 + ResBasic(64, 64, 1, num_channels_per_group), + ResBasic(64, 64, 1, num_channels_per_group), + # Layer 2 + ResBasic(64, 128, 2, num_channels_per_group), + ResBasic(128, 128, 1, num_channels_per_group), + # Layer 3 + ResBasic(128, 256, 2, num_channels_per_group), + ResBasic(256, 256, 1, num_channels_per_group), + # Layer 4 + ResBasic(256, 512, 2, num_channels_per_group), + ResBasic(512, 512, 1, num_channels_per_group), + ) + + if add_fc: + self.final_layers = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels=512, + out_channels=output_dim, + kernel_size=1, + bias=USE_BIAS), + nn.Dropout(p=dropout), + nn.AdaptiveAvgPool2d(4), + nn.Flatten(), + nn.Linear(output_dim*4**2, output_dim, bias=False), + ) + else: + self.final_layers = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(512, output_dim, bias=False), + ) + + def forward(self, inp: torch.Tensor): + out = self.layers(inp) + out = self.final_layers(out) + + return out + diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet50.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..850966a1876cc09e8ed1a5b04765cba5b98886f3 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/resnet50.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.head_animation.VASA1.building_blocks import USE_BIAS, ResBottleneck + + +class Resnet50(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + + num_channels_per_group = 32 + + # The following architecture follows Resnet-50. 
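+ # Unlike the VASA1 Resnet50, this variant keeps a spatial output: a 1x1 conv to 512 channels followed by AdaptiveAvgPool2d(4) yields a (B, 512, 4, 4) descriptor rather than a pooled vector.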
+ self.layers = nn.Sequential( + # Initial layers + nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=USE_BIAS), + nn.GroupNorm(64 // num_channels_per_group, 64, affine=not USE_BIAS), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + # Layer 1 + ResBottleneck(64, 256, 1, num_channels_per_group), + ResBottleneck(256, 256, 1, num_channels_per_group), + ResBottleneck(256, 256, 1, num_channels_per_group), + # Layer 2 + ResBottleneck(256, 512, 2, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + ResBottleneck(512, 512, 1, num_channels_per_group), + # Layer 3 + ResBottleneck(512, 1024, 2, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + ResBottleneck(1024, 1024, 1, num_channels_per_group), + # Layer 4 + ResBottleneck(1024, 2048, 2, num_channels_per_group), + ResBottleneck(2048, 2048, 1, num_channels_per_group), + ResBottleneck(2048, 2048, 1, num_channels_per_group), + ) + + self.final_layers = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels=2048, + out_channels=512, + kernel_size=1, + bias=USE_BIAS), + nn.AdaptiveAvgPool2d(4), + ) + + def forward(self, inp: torch.Tensor): + out = self.layers(inp) + out = self.final_layers(out) + + return out \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/rigid_pose_encoder.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/rigid_pose_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..281fede1d11c93a9f4bf1e6cbc81350e7fb8ba15 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/rigid_pose_encoder.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.head_animation.VASA3.resnet18 import Resnet18 +from model.head_animation.VASA3.hopenet import euler_to_rotation_matrix, headposeprediction_to_euler + + +def rotation_6d_to_matrix(rotation_6d: torch.Tensor): + # Via Gram-Schmidt orthogonalization. + a1, a2 = rotation_6d[..., 0:3], rotation_6d[..., 3:6] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + + return torch.stack((b1, b2, b3), dim=-2) + + +class RigidPoseEncoder(nn.Module): + def __init__(self, use_gt_rotation, gt_rotation_callback): + super().__init__() + self.use_gt_rotation = use_gt_rotation + if use_gt_rotation: + self.gt_rotation_callback = gt_rotation_callback + self.net = Resnet18(input_dim=3, output_dim=3, add_fc=False) # 3D translation. + else: + self.net = Resnet18(input_dim=3, output_dim=6 + 3, add_fc=False) # 6D rotation and 3D translation. 
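+ # forward() receives two views of the frame: the translation regressor runs on masked_inp, while the frozen Hopenet callback (when use_gt_rotation is set) estimates rotation from the unmasked inp.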
+ + def forward(self, masked_inp, inp): + + out = self.net(masked_inp) + + if not self.use_gt_rotation: + rot_out = rotation_6d_to_matrix(out[:, 0:6]) + trans_out = F.tanh(out[:, 6:9]) + else: + rot_out = self.gt_rotation_callback(inp) + trans_out = F.tanh(out[:, 0:3]) + + return { + "rotation": rot_out, + "translation": trans_out, + } \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/util.py b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b8278564078dd61d58dffe8cdb3c7d6d8206fda8 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/VASA3/util.py @@ -0,0 +1,25 @@ +import torch + +def compute_grid_points(feature_volume: torch.Tensor): + d, h, w = feature_volume.shape[-3:] + grids = torch.meshgrid( + torch.linspace(-1, 1, d), + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(feature_volume.device).flip(-1) + +def compute_2d_grid_points(h, w, device=torch.device("cpu")): + grids = torch.meshgrid( + torch.linspace(-1, 1, h), + torch.linspace(-1, 1, w), + indexing="ij" + ) + + # NOTE: The 3D coordinates have to correspond to width, height and depth in this order. + # This is what torch.grid_sample expects. So, we flip. + return torch.stack(grids, dim=-1).to(device).flip(-1) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__init__.py b/tools/visualization_0416/utils/model_0506/model/head_animation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..574085510c07a03b5970e2b261ef836150bcd0e3 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de3d6b5a52d9b5821df0c37b43f9b140b5a9d423 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fb873879903c6a44a9c86ddeaeccaeef8c28709 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..63a4ca6c16919a8aa6567b60aa1e911b559cb66d Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/__init__.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff8d2fef6ffe941ce7a80a2788be79325bdab466 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de0d192ca8cfcaa946bcb9ebc82a2b36798379e Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-313.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1dfeb075fe2838f9148d365bae5b50016eabdc2 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator.cpython-313.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c57d425f51ee8692c7814f229b63e1916ee23b8c Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3784adc247afb779de9d31c95e1d3ef7ecb4fad Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/head_animation/__pycache__/head_animator_emop.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator.py new file mode 100644 index 0000000000000000000000000000000000000000..ec9909ca489b4dd446c756ca10c8f87c7561a871 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator.py @@ -0,0 +1,215 @@ +import torch +from torch import nn +from diffusers import DDIMScheduler +from omegaconf import OmegaConf +from einops import rearrange # NOTE: added; _step below uses rearrange but the original file never imported it + +from model.head_animation.head_animator import HeadAnimatorModule +from utils import instantiate +import time + +class AudioHeadAnimatorModule(HeadAnimatorModule): + def __init__(self, config): + super().__init__(config) + self._get_scheduler() + + def configure_model(self): + super().configure_model() + # self.motion_generator = instantiate(self.config.model.motion_generator, True) + # if
self.config.model.motion_gen_ckpt is not None: + # checkpoint = torch.load(self.config.model.motion_gen_ckpt) + # # print(checkpoint.keys()) + # self.motion_generator.load_state_dict(checkpoint, strict=False) + # self.motion_generator.to(dtype=eval(self.config.model.dtype)) + + def _get_scheduler(self): + pass + # self.motion_gen_cfg = OmegaConf.load(self.config.model.motion_generator.class_cfg) + # ddim_sched_kwargs = OmegaConf.to_container(self.motion_gen_cfg.noise_scheduler_kwargs) + # if self.motion_gen_cfg.enable_zero_snr: + # ddim_sched_kwargs.update( + # rescale_betas_zero_snr=True, + # timestep_spacing="trailing", + # # prediction_type="v_prediction", + # ) + # self.train_noise_scheduler = DDIMScheduler( + # **ddim_sched_kwargs + # ) + + def forward(self, source_img, masked_source_img, masked_past_frames, audio_self, audio_other, video_length, style_ref_frames=None): + timings = {} + device = source_img.device + if self.using_hybrid_mask: + t_start = time.time() + style_ref_frames = style_ref_frames.squeeze(0) + # print(f"masked_source_img shape: {masked_source_img.shape}, style_ref_frames shape: {style_ref_frames.shape}") + src_latent, _ = self.motion_encoder(masked_source_img[0:1]) + src_latent = src_latent.repeat(video_length, 1) + ref_latent, _ = self.motion_encoder(style_ref_frames) + # print(f"src_latent shape: {src_latent.shape}, ref_latent shape: {ref_latent.shape}") + style_motion = ref_latent.unsqueeze(0) + timings['motion_encoder'] = time.time() - t_start + print(f"Motion encoder time: {timings['motion_encoder']:.4f}s") + + t_start = time.time() + # tgt_motion_latent = self.motion_generate(masked_past_frames, audio_self, audio_other, video_length) # project target image to reference latent space + speaker_id = torch.zeros(1,1).to(device).long() + audio_self = audio_self.unsqueeze(0) + motion_latent_in = src_latent.unsqueeze(0) + # print(f"motion_latent_in shape: {motion_latent_in.shape}, audio_self shape: {audio_self.shape}, speaker_id shape: {speaker_id.shape}") + tgt_motion_latent = self.motion_generator.inference(audio_self, speaker_id, masked_motion=motion_latent_in, + noise_scheduler=self.train_noise_scheduler, style_motion=style_motion) + # print(f"tgt_motion_latent shape: {tgt_motion_latent.shape}", src_latent.shape) + timings['motion_generate'] = time.time() - t_start + print(f"Motion generate time: {timings['motion_generate']:.4f}s") + tgt_motion_latent = tgt_motion_latent.view(-1, tgt_motion_latent.size(-1)) + + + + t_start = time.time() + tgt_latent = self.flow_estimator(src_latent, tgt_motion_latent) # navigate source to target in reference latent space + timings['flow_estimator'] = time.time() - t_start + print(f"Flow estimator time: {timings['flow_estimator']:.4f}s") + + t_start = time.time() + face_feat = self.face_encoder(source_img) + timings['face_encoder'] = time.time() - t_start + print(f"Face encoder time: {timings['face_encoder']:.4f}s") + + t_start = time.time() + recon_imgs = self.face_generator(tgt_latent, face_feat) + timings['face_generator'] = time.time() - t_start + print(f"Face generator time: {timings['face_generator']:.4f}s") + else: + t_start = time.time() + tgt_latent = self.motion_generate(masked_past_frames, audio_self, audio_other, video_length) # project target image to reference latent space + timings['motion_generate'] = time.time() - t_start + print(f"Motion generate time: {timings['motion_generate']:.4f}s") + + t_start = time.time() + src_latent = self.motion_encoder(source_img) # project source image to reference latent space + 
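# ---------------------------------------------------------------------------
# [Editor's note, not part of the diff] forward() above repeats the
# `t_start = time.time()` / `timings[...] = time.time() - t_start` bookkeeping for
# every submodule. The same pattern can be expressed once with a small context
# manager; `timed` below is a hypothetical helper, not part of the repository.
# Because CUDA kernels launch asynchronously, wall-clock timings are only
# meaningful around a torch.cuda.synchronize().

import time
from contextlib import contextmanager

import torch

@contextmanager
def timed(timings: dict, key: str, sync_cuda: bool = True):
    # synchronize before and after so the interval covers the actual GPU work
    if sync_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    try:
        yield
    finally:
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        timings[key] = time.time() - start

# usage, mirroring the blocks above:
#   with timed(timings, "motion_encoder"):
#       src_latent, _ = self.motion_encoder(masked_source_img[0:1])
# ---------------------------------------------------------------------------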
timings['motion_encoder'] = time.time() - t_start + print(f"Motion encoder time: {timings['motion_encoder']:.4f}s") + + t_start = time.time() + tgt_latent = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + timings['flow_estimator'] = time.time() - t_start + print(f"Flow estimator time: {timings['flow_estimator']:.4f}s") + + t_start = time.time() + face_feat = self.face_encoder(source_img) + timings['face_encoder'] = time.time() - t_start + print(f"Face encoder time: {timings['face_encoder']:.4f}s") + + t_start = time.time() + recon_imgs = self.face_generator(tgt_latent, face_feat) + timings['face_generator'] = time.time() - t_start + print(f"Face generator time: {timings['face_generator']:.4f}s") + + total_time = sum(timings.values()) + print("\nModel Component Timings:") + for component, t in timings.items(): + print(f"{component}: {t:.4f}s ({(t/total_time)*100:.1f}%)") + print(f"Total time: {total_time:.4f}s") + + return recon_imgs + + def _step(self, batch): + + optimizer_g, optimizer_d = self.optimizers() + + ## train generator + self.toggle_optimizer(optimizer_g) + masked_target_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + masked_past_frames = batch['pixel_values_past_frames'] + masked_target_vid = torch.cat([masked_target_vid, masked_past_frames], dim=1) + masked_ref_img = batch['pixel_values_ref_img'] + + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + past_frames = batch['pixel_values_past_frames_original'] + target_vid_original = torch.cat([target_vid_original, past_frames], dim=1) + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + masked_past_frames = rearrange(masked_past_frames, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) + ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w") + target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w") + + audio_self = batch['target_wav_fea'].to(self.device) + # TODO(wei): use audio other + audio_other = torch.zeros_like(audio_self) + # get reconstructed image + predicted_img = self.forward(ref_img_original, masked_ref_img, masked_past_frames, audio_self, audio_other) + + if self.l_w_face > 0: + eye_mouth_mask_vid = batch['eye_mouth_mask_vid'] + eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames'] + face_mask = torch.cat([eye_mouth_mask_vid, eye_mouth_mask_past_frames], dim=1) + face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w") + + loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask) + + else: + loss_dict = self.compute_loss(target_vid_original, predicted_img) + + if self.l_w_gan > 0: + # adversarial loss + pred_label = self.discriminator(predicted_img).reshape(-1) + g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label) + + loss_dict['loss'] += g_loss + loss_dict['g_loss'] = g_loss + + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + optimizer_g.zero_grad() + self.untoggle_optimizer(optimizer_g) + + # import pdb; pdb.set_trace() + + ## train discriminator + self.toggle_optimizer(optimizer_d) + + real_img_pred = self.discriminator(target_vid_original) + recon_img_pred = 
self.discriminator(predicted_img.detach()) + + d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred) + + self.manual_backward(d_loss) + optimizer_d.step() + optimizer_d.zero_grad() + self.untoggle_optimizer(optimizer_d) + + self.log("d_loss", d_loss, prog_bar=True) + + else: + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + optimizer_g.zero_grad() + self.untoggle_optimizer(optimizer_g) + + for k, v in loss_dict.items(): + self.log(k, v, prog_bar=True) + + + if False: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + (self.motion_encoder.convs[0][0].weight - checkpoint['motion_encoder.convs.0.0.weight']).sum() + + # check vgg16 weight + from torchvision import models + vgg_model = models.vgg19(pretrained=True).cuda() + vgg_params = [] + for p in vgg_model.parameters(): + vgg_params.append(p) + + (self.criterion_vgg.vgg.slice1[0].weight - vgg_params[0]).mean() + (self.criterion_vgg.vgg.slice2[0].weight - vgg_params[2]).mean() + import pdb; pdb.set_trace() + + + return loss_dict \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator_test.py b/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..446538d48dd41fee5abc75ed36db56919301efb1 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/audio_head_animator_test.py @@ -0,0 +1,227 @@ +import unittest +import torch +import torch.nn as nn +from model.head_animation.audio_head_animator import AudioHeadAnimatorModule +from omegaconf import OmegaConf +import sys +from unittest.mock import MagicMock, patch + +# Mock classes +class MockImagePyramide(nn.Module): + def __init__(self, scales, num_channels): + super().__init__() + self.scales = scales + self.num_channels = num_channels + self.downs = nn.ModuleDict({str(scale).replace('.', '-'): nn.Identity() for scale in scales}) + + def forward(self, x): + return {'prediction_1.0': x} + + def cuda(self): + return self + +class MockVgg19(nn.Module): + def __init__(self, requires_grad=False): + super().__init__() + self.slice1 = nn.Identity() + self.mean = nn.Parameter(torch.zeros(1, 3, 1, 1)) + self.std = nn.Parameter(torch.ones(1, 3, 1, 1)) + + def forward(self, x): + return [torch.randn_like(x) for _ in range(4)] + + def cuda(self): + return self + + def eval(self): + return self + +class MockVGGLoss(nn.Module): + def __init__(self): + super().__init__() + self.scales = [1, 0.5, 0.25] + self.pyramid = MockImagePyramide(self.scales, 3) + self.mask_scales = [1, 0.5, 0.25, 0.125, 0.0625, 0.0625/2] + self.mask_pyramid = MockImagePyramide(self.mask_scales, 1) + self.vgg = MockVgg19() + self.weights = (10, 10, 10, 10) + + def forward(self, img_recon, img_real, facial_mask=None): + return torch.tensor(0.0), {} + + def cuda(self): + return self + +class MockMotionGenerator(nn.Module): + def __init__(self, size=512, latent_dim=512): + super().__init__() + self.size = size + self.latent_dim = latent_dim + + def forward(self, hidden_states, encoder_hidden_states, audio_feature, face_mask, timestep): + batch_size = hidden_states.shape[0] + return type('MockOutput', (), {'sample': torch.randn(batch_size, self.latent_dim, 1, self.size//8, self.size//8)})() + + def cuda(self): + return self + + def load_from_checkpoint(self, checkpoint_path, config=None): + pass + + def to(self, dtype=None): + return self + +# Create a mock loss module +mock_loss_module = 
MagicMock() +mock_loss_module.VGGLoss = MockVGGLoss +mock_loss_module.ImagePyramide = MockImagePyramide +mock_loss_module.Vgg19 = MockVgg19 +mock_loss_module.AntiAliasInterpolation2d = nn.Identity +sys.modules['model.head_animation.LIA.loss'] = mock_loss_module + +# Create a mock motion generator module +mock_motion_generator_module = MagicMock() +mock_motion_generator_module.MotionGenerator = MockMotionGenerator +sys.modules['model.head_animation.LIA.motion_generator'] = mock_motion_generator_module + +# Mock torch.cuda to prevent CUDA initialization +class MockCuda: + @staticmethod + def is_available(): + return False + + @staticmethod + def _lazy_init(): + pass + +torch.cuda = MockCuda + +# Mock torch.nn.Module.cuda +original_cuda = nn.Module.cuda +def mock_cuda(self, device=None): + return self +nn.Module.cuda = mock_cuda + +class TestAudioHeadAnimatorModule(unittest.TestCase): + def setUp(self): + # Create a minimal config for testing + self.config = OmegaConf.create({ + 'model': { + 'dtype': 'torch.float32', + 'using_hybrid_mask': True, + 'pretrained_ckpt': None, + 'motion_generator': { + 'module_name': 'model.head_animation.LIA.motion_generator', + 'class_name': 'MotionGenerator', + 'size': 512, + 'latent_dim': 512, + 'pretrained_ckpt': 'path/to/mock/checkpoint' + }, + 'motion_encoder': { + 'module_name': 'model.head_animation.LIA.motion_encoder', + 'class_name': 'MotionEncoder', + 'latent_dim': 512, + 'size': 512 + }, + 'flow_estimator': { + 'module_name': 'model.head_animation.LIA.flow_estimator', + 'class_name': 'FlowEstimator', + 'latent_dim': 512, + 'motion_space': 64 + }, + 'face_encoder': { + 'module_name': 'model.head_animation.LIA.face_encoder', + 'class_name': 'FaceEncoder', + 'output_channels': 512 + }, + 'face_generator': { + 'module_name': 'model.head_animation.LIA.face_generator', + 'class_name': 'FaceGenerator', + 'size': 512, + 'latent_dim': 512 + }, + 'discriminator': { + 'module_name': 'model.head_animation.LIA.discriminator', + 'class_name': 'Discriminator', + 'size': 512 + } + }, + 'inference': { + 'num_inference_steps': 10 + }, + 'data': { + 'train_width': 512, + 'train_height': 512 + }, + 'loss': { + 'l_w_recon': 1.0, + 'l_w_vgg': 1e-3, + 'l_w_gan': 1e-5, + 'l_w_face': 1.0 + }, + 'optimizer': { + 'lr': 2e-4, + 'adam_beta1': 0.9, + 'adam_beta2': 0.999, + 'adam_epsilon': 1e-8, + 'weight_decay': 0 + } + }) + + # Create module with mocked components + self.module = AudioHeadAnimatorModule(self.config) + + # Mock the required components + self.module.motion_generator = MockMotionGenerator() + self.module.motion_encoder = nn.Identity() + self.module.flow_estimator = nn.Identity() + self.module.face_encoder = nn.Identity() + self.module.face_generator = nn.Identity() + self.module.discriminator = nn.Identity() + self.module.vae = type('MockVAE', (), { + 'encode': lambda x: type('MockDist', (), {'latent_dist': type('MockSample', (), {'sample': lambda: torch.randn(1, 4, 1, 32, 32)})}), + 'config': type('MockConfig', (), {'scaling_factor': 0.18215, 'spatial_compression_ratio': 8}), + 'dtype': torch.float32 + })() + + def test_initialization(self): + self.assertIsNotNone(self.module) + self.assertIsInstance(self.module.config, OmegaConf) + self.assertTrue(self.module.using_hybrid_mask) + + def test_forward(self): + # Create mock inputs + source_img = torch.randn(1, 3, 256, 256) + masked_source_img = torch.randn(1, 3, 256, 256) + audio_self = torch.randn(1, 16000) # Mock audio input + audio_other = torch.randn(1, 16000) # Mock audio input + + # Test forward pass + output = 
self.module.forward(source_img, masked_source_img, torch.randn(1, 3, 256, 256), audio_self, audio_other, video_length=1) # NOTE: the original call omitted the masked_past_frames and video_length arguments that forward() requires; placeholder values added + self.assertIsInstance(output, torch.Tensor) + + def test_motion_generator_method(self): + # Create mock audio inputs + audio_self = torch.randn(1, 16000) + audio_other = torch.randn(1, 16000) + masked_past_frames = torch.randn(1, 3, 256, 256) + + # Mock the noise scheduler + self.module.train_noise_scheduler = type('MockScheduler', (), { + 'set_timesteps': lambda num_steps, device: None, + 'timesteps': torch.tensor([1000, 900, 800]), + 'order': 1, + 'step': lambda noise_pred, t, latent: type('MockStep', (), {'prev_sample': torch.randn(1, 4, 1, 32, 32)})() + })() + + # Test motion generation + latent = self.module.motion_generate(masked_past_frames, audio_self, audio_other) # NOTE: renamed from generate_motion, assuming motion_generate (the name forward() actually calls) was intended + self.assertIsInstance(latent, torch.Tensor) + self.assertEqual(len(latent.shape), 5) # Should be B, C, F, H, W format + + def tearDown(self): + # Restore original cuda method + nn.Module.cuda = original_cuda + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd765b4fd3a2b5cd43fbeefd7f637b3b0b82e8f --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator.py @@ -0,0 +1,571 @@ +import numpy as np +import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +import sys +from pathlib import Path +from einops import rearrange +import torch.nn.functional as F +import math +import lpips +from skimage.metrics import structural_similarity +from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, SequentialLR +from time import time + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.lightning.base_modules import BaseModule +from utils import instantiate + + +class HeadAnimatorModule(BaseModule): + def __init__(self, config): + super().__init__(config) + + self.validation_step_outputs = [] + self.output_dir = config.model.get("output_dir", "outputs") + + self.config = config + self.using_hybrid_mask = config.model.get("using_hybrid_mask", True) + self.using_seg = config.model.get("using_seg", False) + + print(f'Using Hybrid Mask: {self.using_hybrid_mask}') + print(f'Using Segmentation: {self.using_seg}') + print(f'Results will be saved to: {self.output_dir}') + + self.criterion_recon = nn.L1Loss() + self.criterion_masked_face_l1 = nn.L1Loss(reduction='none') + + if self.config.model.get('sdm_loss', None) is not None and self.config.loss.get('l_w_sdm', 0) > 0: + self.criterion_sdm = instantiate(self.config.model.sdm_loss) + + if self.config.model.get('vgg_loss', None) is not None: + self.criterion_vgg = instantiate(self.config.model.vgg_loss) + + if self.config.get('loss', None) is not None: + self.l_w_recon = self.config.loss.get("l_w_recon", 0) + self.l_w_vgg = self.config.loss.get("l_w_vgg", 0) + self.l_w_face = self.config.loss.get("l_w_face", 0) + self.l_w_gan = self.config.loss.get("l_w_gan", 0) + self.l_w_face_l1 = self.config.loss.get("l_w_face_l1", 0) + self.l_w_gaze = self.config.loss.get("l_w_gaze", 0) + self.l_w_foreground = self.config.loss.get("l_w_foreground", 0) + self.l_w_local = self.config.loss.get("l_w_local", 0) + self.l_w_sdm = self.config.loss.get("l_w_sdm", 0) + self.l_w_ref_consistency = 
self.config.loss.get("l_w_ref_consistency", 0) + self.add_gan_step = self.config.loss.get("add_gan_step", 0) + else: + self.l_w_recon = 1 + self.l_w_vgg = 0 + self.l_w_face = 0 + self.l_w_gan = 0 + self.l_w_face_l1 = 0 + self.l_w_gaze = 0 + self.l_w_foreground = 0 + self.l_w_local = 0 + self.l_w_sdm = 0 + self.l_w_ref_consistency = 0 + self.add_gan_step = 0 + self.step_cnt = 0 + + self.face_parsing_en = self.l_w_foreground > 0 or self.l_w_local > 0 + + # support GAN training & normal training + self.automatic_optimization = False + + if 'VASA' in self.config.model.motion_encoder.module_name: + self.model_name = 'VASA' + + if 'LIA' in self.config.model.motion_encoder.module_name: + self.model_name = 'LIA' + + print(f'Using {self.model_name} for Head Animation') + + def configure_model(self): + config = self.config + self.motion_encoder = instantiate(config.model.motion_encoder) + self.flow_estimator = instantiate(config.model.flow_estimator) + self.face_generator = instantiate(config.model.face_generator) + self.face_encoder = instantiate(config.model.face_encoder) + + if self.config.get('loss', None) is not None: + if config.loss.l_w_gan > 0: + self.discriminator = instantiate(config.model.discriminator) + + if self.config.loss.get('l_w_gaze', None) is not None and config.loss.l_w_gaze > 0: + self.gaze_estimator = instantiate(config.model.gaze_estimator) + + if self.config.loss.get('l_w_foreground', None) is not None and config.loss.l_w_foreground > 0 or \ + self.config.loss.get('l_w_local', None) is not None and config.loss.l_w_local > 0 or \ + self.config.model.get('using_seg', None) is not None and config.model.using_seg: + self.face_parser = instantiate(config.model.face_parser) + + if self.config.model.get('pretrained_ckpt', None) is not None: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + ckpt = {} + for k, v in checkpoint.items(): + if 'motion_encoder' in k: + ckpt[k.replace('motion_encoder.', '')] = v + self.motion_encoder.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'flow_estimator' in k: + ckpt[k.replace('flow_estimator.', '')] = v + self.flow_estimator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_generator' in k: + ckpt[k.replace('face_generator.', '')] = v + self.face_generator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_encoder' in k: + ckpt[k.replace('face_encoder.', '')] = v + self.face_encoder.load_state_dict(ckpt, strict=True) + + def motion_encode(self, source_img): + latent_code, pyramid_feat = self.motion_encoder(source_img) + return latent_code, pyramid_feat + + def forward(self, source_img, target_img, masked_source_img, masked_target_img, batch_idx=None): + + face_feat = self.face_encoder(source_img) # get source appearance feature + + source_motion_img = masked_source_img if self.using_hybrid_mask else source_img + tgt_motion_img = masked_target_img if self.using_hybrid_mask else target_img + src_latent, _ = self.motion_encoder(source_motion_img) # project target image to reference latent space + tgt_latent, _ = self.motion_encoder(tgt_motion_img) # project source image to reference latent space + + tgt_latent_from_src = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + recon_img = self.face_generator(tgt_latent_from_src, face_feat) + + + out_dict = {} + out_dict['recon_img'] = recon_img + out_dict['tgt_latent'] = tgt_latent + 
out_dict['src_latent'] = src_latent + out_dict['face_mask'] = None + + if self.l_w_ref_consistency > 0: + tgt_fea = self.face_encoder(target_img) + out_dict['tgt_fea'] = tgt_fea.detach() # detach to avoid optimizing the face_encoder + out_dict['src_fea'] = face_feat.detach() # detach to avoid optimizing the face_encoder + + return out_dict + + def compute_base_loss(self, img_target, img_target_recon, face_mask=None, tgt_parsing_map_dict=None): + + l1_loss = self.l_w_recon * self.criterion_recon(img_target_recon, img_target) + + # Perceptual Loss + if self.l_w_vgg > 0: + vgg_loss, vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target) + vgg_loss = self.l_w_vgg * vgg_loss.mean() + else: + vgg_loss = torch.zeros(1).to(self.device) + + # Facial Expression Perceptual Loss + if face_mask is not None and self.l_w_face > 0: + face_loss, face_vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target, face_mask) + face_loss = self.l_w_face * face_loss.mean() + else: + face_loss = torch.zeros(1).to(self.device) + + if face_mask is not None and self.l_w_face_l1 > 0: + face_l1_loss = self.criterion_masked_face_l1(img_target_recon*face_mask, img_target*face_mask) + face_l1_loss = face_l1_loss.view(face_mask.size(0), -1).sum(-1) / face_mask.view(face_mask.size(0), -1).sum(-1) + face_l1_loss = self.l_w_face_l1 * face_l1_loss.mean() + else: + face_l1_loss = torch.zeros(1).to(self.device) + + gaze_loss = torch.zeros(1).to(self.device) + + if self.face_parsing_en: + assert tgt_parsing_map_dict is not None + face_mask = tgt_parsing_map_dict['face_mask'] + face_body = tgt_parsing_map_dict['face_body'] + cloth_mask = tgt_parsing_map_dict['cloth_mask'] + mouth = tgt_parsing_map_dict['mouth'] + eye = tgt_parsing_map_dict['eye'] + ear = tgt_parsing_map_dict['ear'] + + if self.l_w_foreground > 0: + human_mask = face_body + cloth_mask + img_target_human = img_target * human_mask + foreground_loss, _ = self.criterion_vgg(img_target_recon, img_target_human) + foreground_loss = self.l_w_foreground * foreground_loss.mean() + else: + foreground_loss = torch.zeros(1).to(self.device) + + if self.l_w_local > 0: + eye_mouth_ear_mask = eye + mouth + ear + img_target_local = img_target * eye_mouth_ear_mask + img_target_recon_local = img_target_recon * eye_mouth_ear_mask + + local_loss, _ = self.criterion_vgg(img_target_recon_local, img_target_local) + local_loss = self.l_w_local * local_loss.mean() + else: + local_loss = torch.zeros(1).to(self.device) + else: + foreground_loss = torch.zeros(1).to(self.device) + local_loss = torch.zeros(1).to(self.device) + + return vgg_loss, l1_loss, face_loss, face_l1_loss, gaze_loss, foreground_loss, local_loss + + def compute_loss(self, img_target, out_dict, tgt_parsing_map_dict=None): + vgg_loss, l1_loss, face_loss, face_l1_loss, gaze_loss, foreground_loss, local_loss = self.compute_base_loss(img_target, out_dict['recon_img'], out_dict['face_mask'], tgt_parsing_map_dict) + + loss = vgg_loss + l1_loss + face_loss + face_l1_loss + gaze_loss + foreground_loss + local_loss + loss_dict = {'loss': loss, 'l1_loss': l1_loss, 'face_l1_loss': face_l1_loss, 'vgg_loss': vgg_loss, + 'gaze_loss': gaze_loss, 'face_loss': face_loss, 'foreground_loss': foreground_loss, 'local_loss': local_loss,} + + if self.l_w_sdm > 0: + sdm_loss = self.l_w_sdm * self.criterion_sdm(out_dict['src_latent'], out_dict['tgt_latent']) + loss_dict['loss'] += sdm_loss + loss_dict['sdm_loss'] = sdm_loss + + if self.l_w_ref_consistency > 0: + ref_consistency_loss = self.l_w_ref_consistency * self.ref_consistency_loss(out_dict['src_latent'], 
out_dict['tgt_latent'], out_dict['src_fea'], out_dict['tgt_fea']) + loss_dict['loss'] += ref_consistency_loss + loss_dict['ref_consistency_loss'] = ref_consistency_loss + + return loss_dict + + def ref_consistency_loss(self, src_latent, tgt_latent, src_fea, tgt_fea): + ref_img_from_src = self.face_generator(src_latent, src_fea) + ref_img_from_tgt = self.face_generator(tgt_latent, tgt_fea) + + return F.l1_loss(ref_img_from_src, ref_img_from_tgt) + + def g_nonsaturating_loss(self, fake_pred): + return F.softplus(-fake_pred).mean() + + def d_nonsaturating_loss(self, fake_pred, real_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def prepare_datapair(self, batch): + # when not zero_to_one, all of the below is in [-1, 1] + masked_target_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + masked_past_frames = batch['pixel_values_past_frames'] + masked_target_vid = torch.cat([masked_past_frames, masked_target_vid], dim=1) + masked_ref_img = batch['pixel_values_ref_img'] + + # when not zero_to_one, all of the below is in [-1, 1] + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + past_frames = batch['pixel_values_past_frames_original'] + target_vid_original = torch.cat([past_frames, target_vid_original], dim=1) + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) + ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w") + target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original.to(self.device) + target_vid_original = target_vid_original.to(self.device) + masked_ref_img = masked_ref_img.to(self.device) + masked_target_vid = masked_target_vid.to(self.device) + + return ref_img_original, target_vid_original, masked_ref_img, masked_target_vid + + def _step(self, batch, batch_idx): + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch) + + if self.using_seg or self.face_parsing_en: + # get human parsing maps + tgt_parsing_map_dict = self.face_parser.forward(target_vid_original) + + if self.using_seg: + src_parsing_map_dict = self.face_parser.forward(ref_img_original) + src_face_body = src_parsing_map_dict['face_body'] + src_cloth_mask = src_parsing_map_dict['cloth_mask'] + src_human_mask = src_face_body + src_cloth_mask + ref_img_original = ref_img_original * src_human_mask + + tgt_face_body = tgt_parsing_map_dict['face_body'] + tgt_cloth_mask = tgt_parsing_map_dict['cloth_mask'] + tgt_human_mask = tgt_face_body + tgt_cloth_mask + target_vid_original = target_vid_original * tgt_human_mask + else: + tgt_parsing_map_dict = None + + out_dict = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + loss_dict = self.compute_loss(target_vid_original, out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + # NOTE: is_grad_step was referenced below without ever being defined (presumably a gradient-accumulation flag); assume every step is a gradient step + is_grad_step = True + + if self.l_w_gan > 0: + optimizer_g, optimizer_d = self.optimizers() + + predicted_img = out_dict['recon_img'] # NOTE: forward() returns a dict; the original used predicted_img here without defining it + self.discriminator.requires_grad_(False) + pred_label = self.discriminator(predicted_img).reshape(-1) + g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label) + + if 
self.step_cnt >= self.add_gan_step: + loss_dict['loss'] += g_loss + loss_dict['g_loss'] = g_loss + + if is_grad_step: + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + + ## train discriminator + self.discriminator.requires_grad_(True) + real_img_pred = self.discriminator(target_vid_original) + out_dict_d = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid) # NOTE: forward() returns a dict; the original detached it directly + recon_img_pred = self.discriminator(out_dict_d['recon_img'].detach()) + # import pdb; pdb.set_trace() + + d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred) + + real_probs = torch.sigmoid(real_img_pred) + fake_probs = torch.sigmoid(recon_img_pred) + correct_real = (real_probs >= 0.5).float() + correct_fake = (fake_probs < 0.5).float() + total_correct = correct_real.sum() + correct_fake.sum() + total_samples = real_probs.numel() + fake_probs.numel() + accuracy = total_correct / total_samples + real_acc = correct_real.sum() / total_samples + fake_acc = correct_fake.sum() / total_samples + + optimizer_d.zero_grad() + self.manual_backward(d_loss) + optimizer_d.step() + + loss_dict['d_loss'] = d_loss + # loss_dict['d_acc'] = accuracy + loss_dict['d_real_acc'] = real_acc + loss_dict['d_fake_acc'] = fake_acc + self.step_cnt += 1 + else: + optimizer_g = self.optimizers() + lr_scheduler = self.lr_schedulers() + self.set_module_eval_train_state(True) + self.toggle_optimizer(optimizer_g) + # get reconstructed image + out_dict = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid) # NOTE: forward() returns a dict, not an image; renamed from predicted_img + + if self.l_w_face > 0 or self.l_w_face_l1 > 0: + eye_mouth_mask_vid = batch['eye_mouth_mask_vid'] + eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames'] + face_mask = torch.cat([eye_mouth_mask_vid, eye_mouth_mask_past_frames], dim=1) + face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w") + out_dict['face_mask'] = face_mask # compute_loss reads the mask from out_dict + + loss_dict = self.compute_loss(target_vid_original, out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + else: + loss_dict = self.compute_loss(target_vid_original, out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + if is_grad_step: + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + if is_grad_step: + optimizer_g.step() + lr_scheduler.step() + self.untoggle_optimizer(optimizer_g) + + lr_g = optimizer_g.param_groups[0]['lr'] + self.log("learning_rate_g", lr_g) + + for k, v in loss_dict.items(): + if v > 0: + self.log(k, v, prog_bar=True) + + return loss_dict + + def training_step(self, batch, batch_idx): + + loss_dict = self._step(batch, batch_idx) + + # log current learning rate + if self.l_w_gan > 0: + optimizer_g, optimizer_d = self.optimizers() + else: + optimizer_g = self.optimizers() + current_lr = optimizer_g.param_groups[0]['lr'] + self.log('lr', current_lr, on_step=True, on_epoch=False, prog_bar=True) + + return loss_dict['loss'] + + def validation_step(self, batch, batch_idx): + if self.trainer.global_step > 1: + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch) + + if self.using_seg or self.face_parsing_en: + # get human parsing maps + tgt_parsing_map_dict = self.face_parser.forward(target_vid_original) + + if self.using_seg: + src_parsing_map_dict = self.face_parser.forward(ref_img_original) + src_face_body = src_parsing_map_dict['face_body'] + src_cloth_mask = src_parsing_map_dict['cloth_mask'] + src_human_mask = src_face_body + src_cloth_mask + ref_img_original = ref_img_original * src_human_mask + + tgt_face_body = tgt_parsing_map_dict['face_body'] + tgt_cloth_mask = 
tgt_parsing_map_dict['cloth_mask'] + tgt_human_mask = tgt_face_body + tgt_cloth_mask + target_vid_original = target_vid_original * tgt_human_mask + else: + tgt_parsing_map_dict = None + + # get reconstructed image + with torch.no_grad(): + out_dict = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + loss_dict = self.compute_loss(target_vid_original, out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + if target_vid_original.min() < 0: + predicted_img = (out_dict['recon_img'] + 1) / 2 + target_vid_original = (target_vid_original + 1) / 2 + + loss_dict['l1_loss'] = F.l1_loss(predicted_img, target_vid_original).mean() + predicted_img = (predicted_img * 255).permute(0, 2, 3, 1).cpu().numpy() + target_vid_original = (target_vid_original * 255).permute(0, 2, 3, 1).cpu().numpy() + + psnr_list = [] + ssim_list = [] + for tmp_i in range(len(predicted_img)): + psnr = lpips.psnr(predicted_img[tmp_i], target_vid_original[tmp_i], peak=255.) + ssim = structural_similarity(predicted_img[tmp_i], target_vid_original[tmp_i], data_range=255, channel_axis=2) # NOTE: dropped multichannel=True, which newer scikit-image removed in favor of channel_axis + psnr_list.append(psnr) + ssim_list.append(ssim) + avg_psnr = np.mean(psnr_list) + avg_ssim = np.mean(ssim_list) + + loss_dict['val_psnr'] = avg_psnr + loss_dict['val_ssim'] = avg_ssim + + self.validation_step_outputs.append(loss_dict) + + return loss_dict + + def on_validation_epoch_end(self): + if not hasattr(self, 'validation_step_outputs') or len(self.validation_step_outputs) == 0: + return + + # get all metrics + outputs = self.validation_step_outputs + avg_recon_loss = torch.stack([x['l1_loss'] for x in outputs]).mean() + avg_foreground_loss = torch.stack([x['foreground_loss'] for x in outputs]).mean() + avg_local_loss = torch.stack([x['local_loss'] for x in outputs]).mean() + + avg_psnr = np.mean([x['val_psnr'] for x in outputs]) + avg_ssim = np.mean([x['val_ssim'] for x in outputs]) + + # log metrics + self.log('val_recon_loss', avg_recon_loss, prog_bar=True) + self.log('val_psnr', avg_psnr, prog_bar=True) + self.log('val_ssim', avg_ssim, prog_bar=True) + + if self.face_parsing_en: + self.log('val_foreground_loss', avg_foreground_loss, prog_bar=True) + self.log('val_local_loss', avg_local_loss, prog_bar=True) + + if self.global_rank == 0: + log_file = f"{self.output_dir}/validation_metrics.txt" + current_epoch = self.current_epoch + global_step = self.global_step + log_content = ( + f"Epoch: {current_epoch}, " + f"Step: {global_step}, " + f"Recon Loss: {avg_recon_loss.item():.4f}, " + f"PSNR: {avg_psnr:.4f}, " + f"SSIM: {avg_ssim:.4f}" + ) + if self.face_parsing_en: + log_content += ( + f", Foreground Loss: {avg_foreground_loss.item():.4f}, " + f"Local Loss: {avg_local_loss.item():.4f}" + ) + + with open(log_file, "a") as f: + f.write("*" * 50 + "\n") + f.write(log_content + "\n") + + # clear cache for next epoch + self.validation_step_outputs.clear() + + def configure_optimizers(self): + params_to_update = list(self.motion_encoder.parameters()) + list(self.flow_estimator.parameters()) + \ + list(self.face_encoder.parameters()) + list(self.face_generator.parameters()) + params_to_update = [p for p in params_to_update if p.requires_grad] + params_name_to_update = [name for name, p in self.named_parameters() if p.requires_grad] + + optimizer = torch.optim.AdamW( + params_to_update, + lr=self.config.optimizer.lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + 
eps=self.config.optimizer.adam_epsilon, + ) + if (self.config.get("lr_scheduler", None) is not None) and (self.config.lr_scheduler.type == "cos_anneal"): + lr_scheduler = CosineAnnealingLR(optimizer, + T_max=self.config.lr_scheduler.T_max, + eta_min=self.config.lr_scheduler.eta_min) + else: + lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0) + + if self.l_w_gan > 0: + if self.model_name == 'LIA': + d_reg_ratio = self.config.optimizer.d_reg_every / (self.config.optimizer.d_reg_every + 1) + optimizer_dis = torch.optim.AdamW( + self.discriminator.parameters(), + lr=self.config.optimizer.discriminator_lr * d_reg_ratio, + weight_decay=self.config.optimizer.weight_decay, + betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), + eps=self.config.optimizer.adam_epsilon, + ) + else: + optimizer_dis = torch.optim.AdamW( + self.discriminator.parameters(), + lr=self.config.optimizer.discriminator_lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + return [optimizer, optimizer_dis], [] + else: + # import pdb; pdb.set_trace() + return [optimizer], [lr_scheduler] + +if __name__ == "__main__": + from model.head_animation.LIA.motion_encoder import MotionEncoder + from model.head_animation.LIA.flow_estimator import FlowEstimator + from model.head_animation.LIA.face_encoder import FaceEncoder + from model.head_animation.LIA.face_generator import FaceGenerator + from torchsummaryX import summary + + IMAGE_SIZE = 512 + latent_dim = 512 + + encoder = MotionEncoder(latent_dim=latent_dim, size=IMAGE_SIZE) + # summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + + motion_space=20 + flow_estimator = FlowEstimator(latent_dim=latent_dim, motion_space=motion_space) + # summary(flow_estimator, torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + tgt_latent = flow_estimator(torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + + face_encoder = FaceEncoder(output_channels=latent_dim) + # summary(face_encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + feat = face_encoder(torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + # for fea in feat: print(fea.shape) + + face_generator = FaceGenerator(IMAGE_SIZE, latent_dim, channel_multiplier=1) + face_generator(tgt_latent, feat) + + + \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_3d.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..1b31d3912f85107a42e7ee3945f8314055e84f43 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_3d.py @@ -0,0 +1,159 @@ +import torch +from torch import nn +import sys +from pathlib import Path +from einops import rearrange +import torch.nn.functional as F + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.head_animation.head_animator import HeadAnimatorModule +from model.head_animation.VASA1.hopenet import headposeprediction_to_euler, euler_to_rotation_matrix +from utils import instantiate + +class HeadAnimator3DModule(HeadAnimatorModule): + def __init__(self, config): + super().__init__(config) + + def motion_encode(self, masked_img, img): + latent_code, rigid_pose = self.motion_encoder(masked_img, img) + return latent_code, rigid_pose + + def forward(self, source_img, target_img, masked_source_img, masked_target_img, batch_idx=None): + [feature_volume, 
global_descriptor] = self.face_encoder(source_img) + + if self.using_hybrid_mask: + tgt_latent, tgt_rigid_pose = self.motion_encoder(masked_target_img, target_img) # project target image to latent space + src_latent, src_rigid_pose = self.motion_encoder(masked_source_img, source_img) # project source image to latent space + else: + tgt_latent, tgt_rigid_pose = self.motion_encoder(target_img, target_img) # project target image to latent space + src_latent, src_rigid_pose = self.motion_encoder(source_img, source_img) # project source image to latent space + + src_dict = {} + src_dict["feature_volume"] = feature_volume + src_dict["global_descriptor"] = global_descriptor + src_dict["expression_code"] = src_latent[:, :self.config.model.motion_encoder.latent_dim] + src_dict["rigid_pose"] = src_rigid_pose + + tgt_dict = {} + tgt_dict["expression_code"] = tgt_latent[:, :self.config.model.motion_encoder.latent_dim] + tgt_dict["rigid_pose"] = tgt_rigid_pose + + canonical_feature_volume, warped_driving_feature_volume, src_delta_grid, tgt_delta_grid = self.flow_estimator(src_dict, tgt_dict, batch_idx) + + recon_img = self.face_generator(warped_driving_feature_volume) + + # if batch_idx == 20: + # recon_src_img = self.face_generator(feature_volume) + # recon_can_img = self.face_generator(canonical_feature_volume) + + # recon_src_img = 255 * (recon_src_img.permute(0, 2, 3, 1).cpu().numpy() + 1) / 2 + # recon_img_vis = 255 * (recon_img.permute(0, 2, 3, 1).detach().cpu().numpy() + 1) / 2 + # recon_img_can = 255 * (recon_can_img.permute(0, 2, 3, 1).detach().cpu().numpy() + 1) / 2 + # cv2.imwrite("recon_src_img.png", recon_src_img[0][:,:,::-1]) + # cv2.imwrite("recon_tgt_img.png", recon_img_vis[0][:,:,::-1]) + # cv2.imwrite("recon_can_img.png", recon_img_can[0][:,:,::-1]) + # import pdb; pdb.set_trace() + + [feature_volume_tgt, global_descriptor_tgt] = self.face_encoder(target_img) + tgt_dict["feature_volume"] = feature_volume_tgt + tgt_dict["global_descriptor"] = global_descriptor_tgt + canonical_feature_volume_tgt, warped_driving_feature_volume_tgt, _, _ = self.flow_estimator(tgt_dict, tgt_dict, batch_idx) + feature_volume = {} + feature_volume["canonical_feature_volume"] = canonical_feature_volume + feature_volume["warped_driving_feature_volume"] = warped_driving_feature_volume + feature_volume["canonical_feature_volume_tgt"] = canonical_feature_volume_tgt + feature_volume["warped_driving_feature_volume_tgt"] = warped_driving_feature_volume_tgt + feature_volume["feature_volume_tgt"] = feature_volume_tgt + feature_volume["src_delta_grid"] = src_delta_grid + feature_volume["tgt_delta_grid"] = tgt_delta_grid + + return recon_img, tgt_rigid_pose, feature_volume + + def compute_headpose_loss(self, tgt_image, tgt_rigid_pose): + if tgt_image.min() <= -0.5: + tgt_image_org = (tgt_image + 1) / 2 #[-1, 1] -> [0, 1] + else: + tgt_image_org = tgt_image # NOTE: added; the original left tgt_image_org undefined for inputs already in [0, 1] + + with torch.no_grad(): + gt_rotation_matrix = euler_to_rotation_matrix(*headposeprediction_to_euler(*self.motion_encoder.hopenet(self.motion_encoder.hopenet_test_transformations(tgt_image_org)))) + + headpose_loss = torch.acos( + # Clip is needed to avoid NaN values.
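# [Editor's note, not part of the diff] The expression below is the geodesic distance
# between the predicted and HopeNet-estimated rotations on SO(3): for rotation
# matrices R1, R2, tr(R1 @ R2.T) = 1 + 2*cos(theta), hence
# theta = arccos((tr(R1 @ R2.T) - 1) / 2). torch.linalg.diagonal(...).sum(-1) computes
# the batched trace, and the clip keeps acos's argument strictly inside (-1, 1),
# where its gradient -1/sqrt(1 - x^2) stays finite.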
+ torch.clip( + # (trace(m1, m2.T) - 1) / 2 + input=torch.linalg.diagonal(gt_rotation_matrix @ tgt_rigid_pose["rotation"].transpose(1, 2)).sum(-1) * 0.5 - 0.5, + min=-0.999999, + max=0.999999, + ) + ).mean() * self.config.loss.l_w_headpose + + return headpose_loss + + def compute_feat_loss(self, feature_volume): + canonical_feature_volume = feature_volume["canonical_feature_volume"] + warped_driving_feature_volume = feature_volume["warped_driving_feature_volume"] + + canonical_feature_volume_tgt = feature_volume["canonical_feature_volume_tgt"] + warped_driving_feature_volume_tgt = feature_volume["warped_driving_feature_volume_tgt"] + feature_volume_tgt = feature_volume["feature_volume_tgt"].detach() + + feat_loss = F.mse_loss(canonical_feature_volume, canonical_feature_volume_tgt) + F.mse_loss(warped_driving_feature_volume, feature_volume_tgt) + return feat_loss + + def compute_driving_feat_loss(self, feature_volume): + warped_driving_feature_volume = feature_volume["warped_driving_feature_volume"] + feature_volume_tgt = feature_volume["feature_volume_tgt"].detach() + + feat_loss = F.mse_loss(warped_driving_feature_volume, feature_volume_tgt) + return feat_loss + + def compute_non_rigid_warping_loss(self, feature_volume): + # import pdb; pdb.set_trace() + batch_size, channel, depth, height, width = feature_volume["src_delta_grid"].shape + src_delta_grid = feature_volume["src_delta_grid"].reshape(batch_size, -1) + tgt_delta_grid = feature_volume["tgt_delta_grid"].reshape(batch_size, -1) + delta_grid_penalty = torch.mean(torch.abs(src_delta_grid).sum(dim=1)) + torch.mean(torch.abs(tgt_delta_grid).sum(dim=1)) + return delta_grid_penalty + + + def compute_loss(self, img_target, img_target_recon, face_mask=None, tgt_rigid_pose=None, feature_volume=None): + + vgg_loss, l1_loss, face_loss, face_l1_loss = self.compute_base_loss(img_target, img_target_recon, face_mask) + + loss = vgg_loss + l1_loss + face_loss + face_l1_loss + loss_dict = {'loss': loss, 'l1_loss': l1_loss, 'face_l1_loss': face_l1_loss, 'vgg_loss': vgg_loss, 'face_loss': face_loss} + + if not self.config.model.motion_encoder.use_gt_rotation: + # TODO: check this loss works or not + assert self.config.loss.l_w_headpose > 0 + headpose_loss = self.config.loss.l_w_headpose * self.compute_headpose_loss(img_target, tgt_rigid_pose) + loss_dict['loss'] += headpose_loss + loss_dict['headpose_loss'] = headpose_loss + + if feature_volume is not None and self.config.loss.l_w_feat > 0: + feat_loss = self.config.loss.l_w_feat * self.compute_feat_loss(feature_volume) + loss_dict['loss'] += feat_loss + loss_dict['feat_loss'] = feat_loss + + if feature_volume is not None and self.config.loss.l_w_dri_feat > 0: + dri_feat_loss = self.config.loss.l_w_dri_feat * self.compute_driving_feat_loss(feature_volume) + loss_dict['loss'] += dri_feat_loss + loss_dict['dri_feat_loss'] = dri_feat_loss + + if feature_volume is not None and self.config.loss.l_w_warping_penalty > 0: + warping_penalty_loss = self.config.loss.l_w_warping_penalty * self.compute_non_rigid_warping_loss(feature_volume) + loss_dict['loss'] += warping_penalty_loss + loss_dict['warping_penalty_loss'] = warping_penalty_loss + + return loss_dict + +if __name__ == "__main__": + from model.head_animation.VASA1.motion_encoder import MotionEncoder + from torchsummaryX import summary + + IMAGE_SIZE = 512 + latent_dim = 512 + + encoder = MotionEncoder(latent_dim=latent_dim, size=IMAGE_SIZE) + summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + + \ No newline at end of file diff --git 
a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_LIA_origin.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_LIA_origin.py new file mode 100644 index 0000000000000000000000000000000000000000..15c61e6320d89053d79de845c68c2c3e951f0281 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_LIA_origin.py @@ -0,0 +1,531 @@ +import numpy as np +import torch +from torch import nn +import sys +from pathlib import Path +from einops import rearrange +import torch.nn.functional as F +import math + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.lightning.base_modules import BaseModule +from utils import instantiate +from model.head_animation.VASA3.building_blocks import * +from model.head_animation.VASA3.nonrigid_pose_encoder import AdaptiveGroupNorm +from model.modnets.modnet import MODNet + +class HeadAnimatorModule(BaseModule): + def __init__(self, config): + super().__init__(config) + self.config = config + self.using_hybrid_mask = config.model.get("using_hybrid_mask", True) + print(f'Using Hybrid Mask: {self.using_hybrid_mask}') + if not self.using_hybrid_mask: + self.face_encoder = nn.Identity() + + self.criterion_recon = nn.L1Loss() + self.criterion_masked_face_l1 = nn.L1Loss(reduction='none') + + self.l_w_recon = config.loss.l_w_recon + self.l_w_vgg = config.loss.l_w_vgg + self.l_w_face = config.loss.get("l_w_face", 0) + self.l_w_gan = config.loss.get("l_w_gan", 0) + self.l_w_face_l1 = config.loss.get("l_w_face_l1", 0) + + # support GAN training & normal training + self.automatic_optimization = False + + if 'VASA' in self.config.model.motion_encoder.module_name: + self.model_name = 'VASA' + + if 'LIA' in self.config.model.motion_encoder.module_name: + self.model_name = 'LIA' + + print(f'Using {self.model_name} for Head Animation') + self.init_model() + + def init_model(self): + config = self.config + self.motion_encoder = instantiate(config.model.motion_encoder) + self.flow_estimator = instantiate(config.model.flow_estimator) + self.face_generator = instantiate(config.model.face_generator) + self.face_encoder = instantiate(config.model.face_encoder) + if not self.config.model.get("using_hybrid_mask", True): + self.face_encoder = nn.Identity() + + self.use_modnets = False + if config.model.get("modnets", None) is not None: + self.use_modnets = True + self.modnet = MODNet(backbone_pretrained=False) + self.modnet.eval() + modnet_state_dict = torch.load(config.model.modnets.pretrained_weights) + modnet_state_ckpt = {} + for k, v in modnet_state_dict.items(): + modnet_state_ckpt[k.replace("module.", "")] = v + self.modnet.load_state_dict(modnet_state_ckpt) + for name, param in self.modnet.named_parameters(): + param.requires_grad = False + + if config.loss.l_w_vgg > 0 or config.loss.l_w_face > 0: + # self.criterion_vgg = VGGLoss() + self.criterion_vgg = instantiate(config.model.vgg_loss) + for name, param in self.criterion_vgg.named_parameters(): + param.requires_grad = False + self.criterion_vgg.eval() + + if config.loss.l_w_gan > 0: + self.discriminator = instantiate(config.model.discriminator) + + if self.config.model.get('pretrained_ckpt', None) is not None: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + ckpt = {} + for k, v in checkpoint.items(): + if 'motion_encoder' in k: + ckpt[k.replace('motion_encoder.', '')] = v + self.motion_encoder.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v 
in checkpoint.items(): + if 'flow_estimator' in k: + ckpt[k.replace('flow_estimator.', '')] = v + self.flow_estimator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_generator' in k: + ckpt[k.replace('face_generator.', '')] = v + self.face_generator.load_state_dict(ckpt, strict=True) + if self.face_generator.freeze: + for param in self.face_generator.parameters(): + param.requires_grad = False + + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_encoder' in k: + ckpt[k.replace('face_encoder.', '')] = v + self.face_encoder.load_state_dict(ckpt, strict=True) + if self.face_encoder.freeze: + for name, param in self.face_encoder.named_parameters(): + # Note: we only pretrain the VolumetricFieldEncoder + if 'global_descriptor_encoder' not in name: + param.requires_grad = False + + + def configure_model(self): + pass + # def setup(self, stage=None): + # if stage == "fit" or stage is None: + # print('Model is initializing weights...') + # # self.initialize_weights() + + # def initialize_weights(self): + # for m in self.modules(): + # if isinstance(m, (AdaptiveGroupNorm)): + # pass + + # # -------------------- Initialize convolutional layers (WSConv2d/WSConv3d) -------------------- + # if isinstance(m, (WSConv2d, WSConv3d)): + # # Check if the weights need to be updated + # if m.weight.requires_grad: + # # Adapt the initialization of WSConv to the scaling + # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + # m.weight.data.mul_(math.sqrt(2)) # Compensate for the scaling of the weights + # # Initialize the bias (if it exists and needs to be updated) + # if hasattr(m, 'bias') and m.bias is not None and m.bias.requires_grad: + # nn.init.constant_(m.bias, 0) + + # # -------------------- Initialize ordinary convolutional layers (non-WSConv) -------------------- + # elif isinstance(m, (nn.Conv2d, nn.Conv3d)): + # if m.weight.requires_grad: + # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + # if m.bias is not None and m.bias.requires_grad: + # nn.init.constant_(m.bias, 0) + + # # -------------------- Handle the last layer of the residual block's GroupNorm -------------------- + # # Rule: Initialize the weight of the last GroupNorm in the main path to 0 (if the parameter needs to be updated) + # elif isinstance(m, (ResBlock2d, ResBlock3d, ResBasic, ResBottleneck)): + # # Find the last GroupNorm in the main path + # last_group_norm = None + # for layer in reversed(m.layers): + # if isinstance(layer, nn.GroupNorm): + # last_group_norm = layer + # break + # # Initialize the weight of the last GroupNorm to 0 (only if the parameter needs to be updated) + # if last_group_norm is not None and last_group_norm.weight.requires_grad: + # nn.init.constant_(last_group_norm.weight, 0) + + # # Initialize the convolutional layer of the skip connection (if it exists and needs to be updated) + # if not isinstance(m.skip_layer, (nn.Identity, type(None))): + # if isinstance(m.skip_layer, (nn.Conv2d, nn.Conv3d)): + # if m.skip_layer.weight.requires_grad: + # nn.init.kaiming_normal_(m.skip_layer.weight, mode='fan_in', nonlinearity='relu') + # elif isinstance(m.skip_layer, nn.Sequential): + # # Handle the convolutional layer in the skip connection (e.g. 
in ResBasic) + # for subm in m.skip_layer: + # if isinstance(subm, (nn.Conv2d, nn.Conv3d)): + # if subm.weight.requires_grad: + # nn.init.kaiming_normal_(subm.weight, mode='fan_in', nonlinearity='relu') + + # # -------------------- Initialize GroupNorm layers -------------------- + # elif isinstance(m, nn.GroupNorm): + # # Only initialize the parameters that need to be updated + # if m.weight is not None and m.weight.requires_grad: + # nn.init.constant_(m.weight, 1.0) + # if m.bias is not None and m.bias.requires_grad: + # nn.init.constant_(m.bias, 0) + + # # -------------------- Initialize linear layers -------------------- + # elif isinstance(m, nn.Linear): + # if m.weight.requires_grad: + # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + # if m.bias is not None and m.bias.requires_grad: + # nn.init.constant_(m.bias, 0) + + # # -------------------- Handle Spectral Normalization parameters -------------------- + # # If using spectral normalization, initialize the original weights instead of the parameterized weights + # if hasattr(m, 'parametrizations') and 'weight' in m.parametrizations: + # parametrization = m.parametrizations.weight[0] + # if hasattr(parametrization, 'original'): + # original_weight = parametrization.original + # if original_weight.requires_grad: + # # Initialize the original weights (e.g. using Kaiming) + # nn.init.kaiming_normal_(original_weight, mode='fan_in', nonlinearity='relu') + + def motion_encode(self, source_img): + latent_code, pyramid_feat = self.motion_encoder(source_img) + return latent_code, pyramid_feat + + def forward(self, source_img, target_img, masked_source_img, masked_target_img, batch_idx=None): + if self.using_hybrid_mask: + tgt_latent, _ = self.motion_encoder(masked_target_img) # project target image to reference latent space + src_latent, _ = self.motion_encoder(masked_source_img) # project source image to reference latent space + + tgt_latent = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + + face_feat = self.face_encoder(source_img) + recon_img = self.face_generator(tgt_latent, face_feat) + else: + tgt_latent, _ = self.motion_encoder(target_img) # project target image to reference latent space + src_latent, face_feat = self.motion_encoder(source_img) # project source image to reference latent space + + tgt_latent = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + recon_img = self.face_generator(tgt_latent, face_feat) + # import pdb; pdb.set_trace() + + return recon_img, None, None + + def compute_base_loss(self, img_target, img_target_recon, face_mask=None): + + l1_loss = self.l_w_recon * self.criterion_recon(img_target_recon, img_target) + + # Perceptual Loss + if self.l_w_vgg > 0: + # img_target_recon = F.interpolate(img_target_recon, size=(256, 256), mode='bilinear', align_corners=False) + # img_target = F.interpolate(img_target, size=(256, 256), mode='bilinear', align_corners=False) + vgg_loss, vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target) + vgg_loss = self.l_w_vgg * vgg_loss.mean() + else: + vgg_loss = torch.zeros(1).to(self.device) + + # Facial Expression Perceptual Loss + if face_mask is not None and self.l_w_face > 0: + face_loss, face_vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target, face_mask) + face_loss = self.l_w_face * face_loss.mean() + else: + face_loss = torch.zeros(1).to(self.device) + + if face_mask is not None and self.l_w_face_l1 > 0: + face_l1_loss = 
self.criterion_masked_face_l1(img_target_recon*face_mask, img_target*face_mask) + face_l1_loss = face_l1_loss.view(face_mask.size(0), -1).sum(-1) / face_mask.view(face_mask.size(0), -1).sum(-1) + face_l1_loss = self.l_w_face_l1 * face_l1_loss.mean() + else: + face_l1_loss = torch.zeros(1).to(self.device) + + return vgg_loss, l1_loss, face_loss, face_l1_loss + + def compute_loss(self, img_target, img_target_recon, face_mask=None, tgt_rigid_pose=None, feature_volume=None): + vgg_loss, l1_loss, face_loss, face_l1_loss = self.compute_base_loss(img_target, img_target_recon, face_mask) + + loss = vgg_loss + l1_loss + face_loss + face_l1_loss + loss_dict = {'loss': loss, 'l1_loss': l1_loss, 'face_l1_loss': face_l1_loss, 'vgg_loss': vgg_loss, 'face_loss': face_loss} + return loss_dict + + def g_nonsaturating_loss(self, fake_pred): + return F.softplus(-fake_pred).mean() + + def d_nonsaturating_loss(self, fake_pred, real_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def prepare_datapair(self, batch): + # when not zero_to_one, all of below is [-1, 1] + masked_target_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + masked_past_frames = batch['pixel_values_past_frames'] + masked_target_vid = torch.cat([masked_past_frames, masked_target_vid], dim=1) + masked_ref_img = batch['pixel_values_ref_img'] # b c h w + + # when not zero_to_one, all of below is [-1, 1] + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + past_frames = batch['pixel_values_past_frames_original'] + target_vid_original = torch.cat([past_frames, target_vid_original], dim=1) + + # import pdb; pdb.set_trace() + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) + ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w") + target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original.to(self.device) + target_vid_original = target_vid_original.to(self.device) + masked_ref_img = masked_ref_img.to(self.device) + masked_target_vid = masked_target_vid.to(self.device) + + if self.use_modnets: + with torch.no_grad(): + _, _, target_vid_original_mask = self.modnet((target_vid_original + 1.) / 2., True) + # target_vid_original_mask is (b 1 h w), range is (0, 1). + # To make it a binary mask, set values >= 0.5 to 1 and values < 0.5 to 0. + + # visual ----------------------------------- + # import imageio + # visual_list = [] + # for ref_img_original_i, target_vid_original_i, masked_ref_img_i, \ + # masked_target_vid_i, target_vid_original_mask_i in zip(ref_img_original, target_vid_original, \ + # masked_ref_img, masked_target_vid, target_vid_original_mask): + # ref_img_original_i = (((ref_img_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # target_vid_original_i = (((target_vid_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # masked_ref_img_i = (((masked_ref_img_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # masked_target_vid_i = (((masked_target_vid_i.cpu().numpy() + 1.) / 2.)
* 255.).transpose(1, 2, 0).astype("uint8") + # target_vid_original_mask_i[target_vid_original_mask_i < 0.5] = 0 + # target_vid_original_mask_i = (target_vid_original_mask_i.repeat(3, 1, 1).cpu().numpy().transpose(1, 2, 0) * 255.).astype("uint8") + # visuals = np.concatenate([ref_img_original_i, target_vid_original_i, masked_ref_img_i, masked_target_vid_i, target_vid_original_mask_i], axis=1) + # visual_list.append(visuals) + # import os + # os.makedirs("visual_train_data", exist_ok=True) + # imageio.mimwrite(f"./visual_train_data/{self.trainer.global_step}_{self.trainer.global_rank}.mp4", visual_list, fps=8) + # video_path = batch["video_path"][0] + # print(f"{video_path=}") + + # import imageio + # visual_list = [] + # for ref_img_original_i, target_vid_original_i, masked_ref_img_i, \ + # masked_target_vid_i in zip(ref_img_original, target_vid_original, \ + # masked_ref_img, masked_target_vid): + # ref_img_original_i = (((ref_img_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # target_vid_original_i = (((target_vid_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # masked_ref_img_i = (((masked_ref_img_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # masked_target_vid_i = (((masked_target_vid_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8") + # visuals = np.concatenate([ref_img_original_i, target_vid_original_i, masked_ref_img_i, masked_target_vid_i], axis=1) + # visual_list.append(visuals) + # import os + # video_base = os.path.basename(video_path) + # os.makedirs("GAN_DEBUG_VISUAL", exist_ok=True) + # imageio.mimwrite(f"./GAN_DEBUG_VISUAL/{self.trainer.global_step}_{self.trainer.global_rank}_{video_base}", visual_list, fps=self.config.data.train_fps) + # visual ----------------------------------- + + + return ref_img_original, target_vid_original, masked_ref_img, masked_target_vid + + def set_module_eval_train_state(self, state_is_g=True): + if self.l_w_gan > 0: + self.discriminator.train() + self.motion_encoder.train() + self.flow_estimator.train() + self.face_generator.train() + self.face_encoder.train() + # if state_is_g: + # self.discriminator.eval() + # self.motion_encoder.train() + # self.flow_estimator.train() + # self.face_generator.train() + # self.face_encoder.train() + # else: + # self.discriminator.train() + # self.motion_encoder.eval() + # self.flow_estimator.eval() + # self.face_generator.eval() + # self.face_encoder.eval() + + def _step(self, batch, batch_idx): + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch) + + if self.l_w_gan > 0: # using GAN training + optimizer_g, optimizer_d = self.optimizers() + ## train generator + # toggle_optimizer enables gradients for this optimizer's parameters only + self.set_module_eval_train_state(True) + self.toggle_optimizer(optimizer_g) + # get reconstructed image + predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + + if self.l_w_face > 0 or self.l_w_face_l1 > 0: + eye_mouth_mask_vid = batch['eye_mouth_mask_vid'] + eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames'] + face_mask = torch.cat([eye_mouth_mask_vid, eye_mouth_mask_past_frames], dim=1) + face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w") + + loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask, tgt_rigid_pose, feature_volume) + + else: + loss_dict =
self.compute_loss(target_vid_original, predicted_img, tgt_rigid_pose=tgt_rigid_pose, feature_volume=feature_volume) + + # adversarial loss + pred_label = self.discriminator(predicted_img).reshape(-1) + g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label) + + loss_dict['loss'] += g_loss + loss_dict['g_loss'] = g_loss + + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + self.untoggle_optimizer(optimizer_g) + + # import pdb; pdb.set_trace() + + ## train discriminator + self.set_module_eval_train_state(False) + self.toggle_optimizer(optimizer_d) + + real_img_pred = self.discriminator(target_vid_original) + predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + recon_img_pred = self.discriminator(predicted_img.detach()) + + d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred) + + optimizer_d.zero_grad() + self.manual_backward(d_loss) + optimizer_d.step() + self.untoggle_optimizer(optimizer_d) + + self.log("d_loss", d_loss, prog_bar=True) + + else: + optimizer_g = self.optimizers() + self.set_module_eval_train_state(True) + self.toggle_optimizer(optimizer_g) + # get reconstructed image + predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + + if self.l_w_face > 0 or self.l_w_face_l1 > 0: + eye_mouth_mask_vid = batch['eye_mouth_mask_vid'] + eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames'] + face_mask = torch.cat([eye_mouth_mask_vid, eye_mouth_mask_past_frames], dim=1) + face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w") + + loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask, tgt_rigid_pose, feature_volume) + + else: + loss_dict = self.compute_loss(target_vid_original, predicted_img, tgt_rigid_pose=tgt_rigid_pose, feature_volume=feature_volume) + + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + self.untoggle_optimizer(optimizer_g) + + for k, v in loss_dict.items(): + self.log(k, v, prog_bar=True) + + + if False: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + (self.motion_encoder.convs[0][0].weight - checkpoint['motion_encoder.convs.0.0.weight']).sum() + + # check vgg16 weight + from torchvision import models + vgg_model = models.vgg19(pretrained=True).cuda() + vgg_params = [] + for p in vgg_model.parameters(): + vgg_params.append(p) + + (self.criterion_vgg.vgg.slice1[0].weight - vgg_params[0]).mean() + (self.criterion_vgg.vgg.slice2[0].weight - vgg_params[2]).mean() + import pdb; pdb.set_trace() + + return loss_dict + + def training_step(self, batch, batch_idx): + loss_dict = self._step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + if self.trainer.global_step > 5: + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch) + + # get reconstructed image + with torch.no_grad(): + predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + loss_dict = self.compute_loss(target_vid_original, predicted_img, tgt_rigid_pose=tgt_rigid_pose) + + self.log('val_recon_loss', loss_dict['l1_loss'], prog_bar=True) + + return loss_dict['l1_loss'] + + def configure_optimizers(self): + params_to_update = list(self.motion_encoder.parameters()) + 
list(self.flow_estimator.parameters()) + \ + list(self.face_encoder.parameters()) + list(self.face_generator.parameters()) + params_to_update = [p for p in params_to_update if p.requires_grad] + + g_reg_ratio = self.config.optimizer.g_reg_every / (self.config.optimizer.g_reg_every + 1) + d_reg_ratio = self.config.optimizer.d_reg_every / (self.config.optimizer.d_reg_every + 1) + optimizer = torch.optim.Adam( + params_to_update, + lr=self.config.optimizer.lr * g_reg_ratio, + betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), + ) + + if self.l_w_gan > 0: + optimizer_dis = torch.optim.Adam( + self.discriminator.parameters(), + lr=self.config.optimizer.lr * d_reg_ratio, + betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), + ) + # return [ + # {"optimizer": optimizer}, + # {"optimizer": optimizer_dis, "do_not_count_global_step": True}, + # ] + return [optimizer, optimizer_dis], [] + else: + # import pdb; pdb.set_trace() + return [optimizer], [] + + +if __name__ == "__main__": + from model.head_animation.LIA.motion_encoder import MotionEncoder + from model.head_animation.LIA.flow_estimator import FlowEstimator + from model.head_animation.LIA.face_encoder import FaceEncoder + from model.head_animation.LIA.face_generator import FaceGenerator + from torchsummaryX import summary + + IMAGE_SIZE = 512 + latent_dim = 512 + + encoder = MotionEncoder(latent_dim=latent_dim, size=IMAGE_SIZE) + # summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + + motion_space=20 + flow_estimator = FlowEstimator(latent_dim=latent_dim, motion_space=motion_space) + # summary(flow_estimator, torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + tgt_latent = flow_estimator(torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + + face_encoder = FaceEncoder(output_channels=latent_dim) + # summary(face_encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + feat = face_encoder(torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + # for fea in feat: print(fea.shape) + + face_generator = FaceGenerator(IMAGE_SIZE, latent_dim, channel_multiplier=1) + face_generator(tgt_latent, feat) + + + \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_emop.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_emop.py new file mode 100644 index 0000000000000000000000000000000000000000..ac52bf02f46cfe0cb1bd4db998665a0171b4867f --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_emop.py @@ -0,0 +1,936 @@ +import torch +from torch import nn +import sys +from pathlib import Path +from einops import rearrange +import torch.nn.functional as F +import math +import lpips +import numpy as np +from skimage.metrics import structural_similarity +from time import time +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, ConstantLR, SequentialLR +from utils import instantiate +import torchvision +import matplotlib.pyplot as plt +from torchvision.ops import roi_align +import random + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.lightning.base_modules import BaseModule +from model.head_animation.EMOP.head_pose_regressor import HeadPoseRegressor + +class HeadAnimatorModule(BaseModule): + def __init__(self, config): + super().__init__(config) + + self.validation_step_outputs = [] + self.output_dir = config.model.get("output_dir", "outputs") + + self.config = config + self.using_hybrid_mask = config.model.get("using_hybrid_mask", True) + self.using_seg = config.model.get("using_seg", False) 
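+ # A minimal sketch of the `config.model` keys read in this __init__ (the
+ # defaults mirror the .get(...) fallbacks used here; the concrete values
+ # are illustrative, not taken from a shipped config):
+ #   model:
+ #     using_hybrid_mask: true    # drive motion from masked frames
+ #     using_seg: false           # mask out the background via face parsing
+ #     face_part_start_iter: 0    # step at which part-based losses kick in
+ #     log_grad_freq: 1e8         # how often to dump gradient stats
+ #     max_grad_norm: 5           # gradient clipping threshold
+ #     add_gan_step: 0            # step at which the GAN loss is added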
+ self.face_part_start_iter = config.model.get("face_part_start_iter", 0) + self.log_grad_freq = config.model.get("log_grad_freq", 1e8) + self.max_grad_norm = config.model.get("max_grad_norm", 5) + self.add_gan_step = config.model.get("add_gan_step", 0) + + print(f'Using Hybrid Mask: {self.using_hybrid_mask}') + print(f'Using Segmentation: {self.using_seg}') + print(f'Adding GAN loss from step {self.add_gan_step} onwards!') + print(f'Training log will be saved to: {self.output_dir}') + + self.criterion_recon = nn.L1Loss() + self.criterion_masked_face_l1 = nn.L1Loss(reduction='none') + + if self.config.model.get('sdm_loss', None) is not None and self.config.loss.get('l_w_sdm', 0) > 0: + self.criterion_sdm = instantiate(self.config.model.sdm_loss) + + if self.config.model.get('vgg_loss', None) is not None: + self.criterion_vgg = instantiate(self.config.model.vgg_loss) + + if self.config.get('loss', None) is not None: + self.l_w_recon = self.config.loss.get("l_w_recon", 0) + self.l_w_vgg = self.config.loss.get("l_w_vgg", 0) + self.l_w_face_vgg = self.config.loss.get("l_w_face_vgg", 0) + self.l_w_gan = self.config.loss.get("l_w_gan", 0) + self.l_w_face_l1 = self.config.loss.get("l_w_face_l1", 0) + self.l_w_gaze = self.config.loss.get("l_w_gaze", 0) + self.l_w_foreground = self.config.loss.get("l_w_foreground", 0) + self.l_w_local = self.config.loss.get("l_w_local", 0) + self.l_w_sdm = self.config.loss.get("l_w_sdm", 0) + self.l_w_ref_consistency = self.config.loss.get("l_w_ref_consistency", 0) + self.l_w_facial = self.config.loss.get("l_w_facial", 0) + self.l_w_id = self.config.loss.get("l_w_id", 0) + self.l_w_fea_match = self.config.loss.get("l_w_fea_match", 0) + else: + self.l_w_recon = 1 + self.l_w_vgg = 0 + self.l_w_face = 0 + self.l_w_gan = 0 + self.l_w_face_vgg = 0 + self.l_w_face_l1 = 0 + self.l_w_gaze = 0 + self.l_w_foreground = 0 + self.l_w_local = 0 + self.l_w_sdm = 0 + self.l_w_ref_consistency = 0 + self.l_w_facial = 0 + self.l_w_id = 0 + self.l_w_fea_match = 0 + + self.step_cnt = 0 + self.time = 0 + self.init_time = None + + self.face_parsing_en = self.l_w_foreground > 0 or self.l_w_local > 0 + + # support GAN training & normal training + self.automatic_optimization = False + + if 'VASA' in self.config.model.motion_encoder.module_name: + self.model_name = 'VASA' + + if 'LIA' in self.config.model.motion_encoder.module_name: + self.model_name = 'LIA' + + if 'EMOP' in self.config.model.motion_encoder.module_name: + self.model_name = 'EMOPortrait' + + print(f'Using {self.model_name} for Head Animation') + + def configure_model(self): + config = self.config + self.motion_encoder = instantiate(config.model.motion_encoder) + self.flow_estimator = instantiate(config.model.flow_estimator) + self.face_generator = instantiate(config.model.face_generator) + self.face_encoder = instantiate(config.model.face_encoder) + + if not self.motion_encoder.use_mask_image: + self.motion_encoder.rigid_pose_encoder.eval() ## only put the rigid_pose_encoder in eval mode (freezing its params is commented out below) + # for param in self.motion_encoder.rigid_pose_encoder.parameters(): + # param.requires_grad = False + print('Use original image to predict head pose') + else: + self.head_pose_regrssor = HeadPoseRegressor(config.model_config.head_pose_regressor_path) + self.head_pose_regrssor.eval() + for param in self.head_pose_regrssor.parameters(): + param.requires_grad = False + + if self.config.get('loss', None) is not None: + if config.loss.l_w_gan > 0: + self.discriminator = instantiate(config.model.discriminator) + + if self.config.loss.get('l_w_facial', None) is not None
and config.loss.l_w_facial > 0: + self.facial_discriminator = instantiate(config.model.facial_component_loss) + + if self.config.loss.get('l_w_id', None) is not None and config.loss.l_w_id > 0: + self.criterion_vggface = instantiate(config.model.vggface) + self.criterion_vggface.eval() + for param in self.criterion_vggface.parameters(): + param.requires_grad = False + + if self.config.loss.get('l_w_gaze', None) is not None and config.loss.l_w_gaze > 0: + self.gaze_estimator = instantiate(config.model.gaze_estimator) + self.gaze_estimator.eval() + for param in self.gaze_estimator.parameters(): + param.requires_grad = False + + if self.config.loss.get('l_w_foreground', None) is not None and config.loss.l_w_foreground > 0 or \ + self.config.loss.get('l_w_local', None) is not None and config.loss.l_w_local > 0 or \ + self.config.model.get('using_seg', None) is not None and config.model.using_seg: + self.face_parser = instantiate(config.model.face_parser) + + if self.config.model.get('pretrained_ckpt', None) is not None: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + ckpt = {} + for k, v in checkpoint.items(): + if 'motion_encoder' in k: + ckpt[k.replace('motion_encoder.', '')] = v + self.motion_encoder.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'flow_estimator' in k: + ckpt[k.replace('flow_estimator.', '')] = v + self.flow_estimator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_generator' in k: + ckpt[k.replace('face_generator.', '')] = v + self.face_generator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_encoder' in k: + ckpt[k.replace('face_encoder.', '')] = v + self.face_encoder.load_state_dict(ckpt, strict=True) + + def load_state_dict(self, state_dict: dict, strict: bool = True): + filtered_state_dict = { + k: v for k, v in state_dict.items() + if not k.startswith('motion_encoder.rigid_pose_encoder.') and not k.startswith('head_pose_regrssor.') and not k.startswith('criterion_vggface.') and not k.startswith('face_parser.') + } + super().load_state_dict(filtered_state_dict, strict=False) + + ### overwrite state_dict function to reduce checkpoint size + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + keys_to_remove = [k for k in state_dict.keys() if k.startswith('head_pose_regrssor.')] + keys_to_remove += [k for k in state_dict.keys() if k.startswith('criterion_vggface.')] + keys_to_remove += [k for k in state_dict.keys() if k.startswith('face_parser.')] + keys_to_remove += [k for k in state_dict.keys() if k.startswith('gaze_estimator.')] + + for k in keys_to_remove: + del state_dict[k] + + return state_dict + + def log_gradient_stats(self, loss_dict=None): + grad_stats = {} + + for name, param in self.named_parameters(): + if param.grad is not None: + grad = param.grad + grad_stats[name] = { + "max": grad.abs().max().item(), + "mean": grad.abs().mean().item(), + "std": grad.std().item() + } + + max_layer = max(grad_stats, key=lambda k: grad_stats[k]["max"]) + + log_file = f"{self.output_dir}/grad_info.txt" + current_epoch = self.current_epoch + global_step = self.global_step + + if loss_dict is not None: l1_loss = f"L1 Loss: {loss_dict['l1']:.6f}" + else: l1_loss = "" + + log_content = ( + f"Epoch: {current_epoch}, " + f"Step: {global_step}, " + f"{l1_loss}, " + f"Max Gradient Layer: {max_layer}, Value: 
{grad_stats[max_layer]['max']:.6f}" + ) + + if self.global_rank == 0: + with open(log_file, "a") as f: + f.write("*" * 50 + "\n") + f.write(log_content + "\n") + + return grad_stats + + + def motion_encode(self, source_img): + latent_code, pyramid_feat = self.motion_encoder(source_img) + return latent_code, pyramid_feat + + def forward(self, source_img, target_img, masked_source_img, masked_target_img, batch_idx=None): + data_dict = {'source_img': source_img, 'target_img': target_img, 'masked_source_img': masked_source_img, 'masked_target_img': masked_target_img} + + data_dict['latent_volume'], data_dict['idt_embed'] = self.face_encoder(data_dict['source_img']) + + if self.l_w_ref_consistency > 0: + data_dict['ref_matching'] = self.l_w_ref_consistency + data_dict['latent_volume_target'], _ = self.face_encoder(data_dict['target_img']) + + data_dict = self.motion_encoder(data_dict) + data_dict = self.flow_estimator(data_dict) + + aligned_target_volume = data_dict['target_volume'] + target_warp_embed_dict = data_dict['target_warp_embed_dict'] + data_dict = self.face_generator(data_dict, target_warp_embed_dict, aligned_target_volume) + + recon_img = data_dict['pred_target_img'] + + out_dict = {} + out_dict['recon_img'] = recon_img + out_dict['target_img'] = target_img + out_dict['tgt_latent'] = data_dict['target_exp_embed'] + out_dict['src_latent'] = data_dict['source_exp_embed'] + out_dict['face_mask'] = None + out_dict['source_theta'] = data_dict['source_theta'] + out_dict['target_theta'] = data_dict['target_theta'] + out_dict['target_scale'] = data_dict['target_scale'] + out_dict['target_rotation'] = data_dict['target_rotation'] + out_dict['target_translation'] = data_dict['target_translation'] + out_dict['align_warp'] = data_dict['align_warp'] + + if self.l_w_ref_consistency > 0: + out_dict['canonical_volume_from_tgt'] = data_dict['canonical_volume_from_tgt'] + out_dict['canonical_volume_from_src'] = data_dict['canonical_volume_from_src'] + + return out_dict + + def compute_base_loss(self, out_dict, tgt_parsing_map_dict=None): + + img_target, img_target_recon, face_mask = out_dict['target_img'], out_dict['recon_img'], out_dict['face_mask'] + + l1_loss = self.l_w_recon * self.criterion_recon(img_target_recon, img_target) + + # Full-image Perceptual Loss + if self.l_w_vgg > 0: + # vgg_loss, vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target) + vgg_loss = self.criterion_vgg(img_target_recon, img_target) + vgg_loss = self.l_w_vgg * vgg_loss.mean() + else: + vgg_loss = torch.zeros(1).to(self.device) + + # Facial Expression Perceptual Loss + if self.l_w_face_vgg > 0: + face_vgg_loss, face_vgg_loss_dict = self.criterion_vgg(out_dict['pred_target_img_face_align'], out_dict['target_img_align_orig']) + face_vgg_loss = self.l_w_face_vgg * face_vgg_loss.mean() + else: + face_vgg_loss = torch.zeros(1).to(self.device) + + if face_mask is not None and self.l_w_face_l1 > 0: + face_l1_loss = self.criterion_masked_face_l1(img_target_recon*face_mask, img_target*face_mask) + face_l1_loss = face_l1_loss.view(face_mask.size(0), -1).sum(-1) / (face_mask.view(face_mask.size(0), -1).sum(-1) + 1e-6) + face_l1_loss = self.l_w_face_l1 * face_l1_loss.mean() + else: + face_l1_loss = torch.zeros(1).to(self.device) + + if self.l_w_gaze > 0: + gaze_loss = self.l_w_gaze * self.gaze_estimator(out_dict['recon_img'], out_dict['target_img']) + else: + gaze_loss = torch.zeros(1).to(self.device) + + if self.face_parsing_en and self.step_cnt >= self.face_part_start_iter: + assert tgt_parsing_map_dict is not None +
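+ # The parser output below is consumed as a dict of per-region binary masks,
+ # each broadcastable against the image batch (roughly (B, 1, H, W)). The keys
+ # touched here -- inferred from this function, not an exhaustive API:
+ #   tgt_parsing_map_dict = {'face_mask': ..., 'face_body': ..., 'cloth_mask': ...,
+ #                           'mouth': ..., 'eye': ..., 'ear': ...}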
face_mask = tgt_parsing_map_dict['face_mask'] + face_body = tgt_parsing_map_dict['face_body'] + cloth_mask = tgt_parsing_map_dict['cloth_mask'] + mouth = tgt_parsing_map_dict['mouth'] + eye = tgt_parsing_map_dict['eye'] + ear = tgt_parsing_map_dict['ear'] + human_mask = (face_body + cloth_mask).float() + eye_mouth_ear_mask = (eye + mouth + ear).float() + + pred_tgt_parsing_map_dict = self.face_parser.forward(img_target_recon) + pred_mouth = pred_tgt_parsing_map_dict['mouth'] + pred_eye = pred_tgt_parsing_map_dict['eye'] + pred_ear = pred_tgt_parsing_map_dict['ear'] + pred_face_body = tgt_parsing_map_dict['face_body'] + pred_cloth_mask = tgt_parsing_map_dict['cloth_mask'] + pred_human_mask = (pred_face_body + pred_cloth_mask).float() + pred_eye_mouth_ear_mask = (pred_eye + pred_mouth + pred_ear).float().detach() + + pred_eye_mouth_ear_mask = eye_mouth_ear_mask + + gt_mask = eye_mouth_ear_mask.reshape(eye_mouth_ear_mask.size(0), -1).sum(-1) + pred_mask = pred_eye_mouth_ear_mask.reshape(eye_mouth_ear_mask.size(0), -1).sum(-1) + if (gt_mask == 0).any() or (pred_mask == 0).any(): + mask_flag = ((gt_mask * pred_mask) > 0).float() + mask_flag = mask_flag.detach() + else: + mask_flag = torch.ones(eye_mouth_ear_mask.size(0)).float().to(self.device) + + if self.l_w_foreground > 0: + img_target_human = img_target * human_mask + img_target_recon_human = img_target_recon * pred_human_mask + + foreground_loss, _ = self.criterion_vgg(img_target_recon_human, img_target_human, human_mask) + foreground_loss = self.l_w_foreground * foreground_loss.mean() + else: + foreground_loss = torch.zeros(1).to(self.device) + + if self.l_w_local > 0: + use_mask = self.config.model.get('use_mask', False) + # if use_mask: import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + img_target_local = img_target * eye_mouth_ear_mask + img_target_recon_local = img_target_recon * pred_eye_mouth_ear_mask + + if 'rois_face' in out_dict and True: + img_target_local = roi_align(img_target_local, boxes=out_dict['rois_face'], output_size=img_target_local.size(-1)) + img_target_recon_local = roi_align(img_target_recon_local, boxes=out_dict['rois_face'], output_size=img_target_local.size(-1)) + + if self.trainer.global_step > 1 and 0: + import cv2 + pred_img_align = (img_target_local[-8].permute(1,2,0).detach().cpu().numpy() * 255).astype('uint8') + gt_img_align = (img_target_recon_local[-8].permute(1,2,0).detach().cpu().numpy() * 255).astype('uint8') + align_img = np.concatenate([gt_img_align, pred_img_align], axis=1) + + # align_img = ( data_dict['target_img'][-4].permute(1,2,0).detach().cpu().numpy() * 255).astype('uint8') + cv2.imwrite(f'{self.output_dir}/emop_src_img_align_training_{self.step_cnt}.png', align_img[:,:,::-1]) + import pdb; pdb.set_trace() + + # local_loss, _ = self.criterion_vgg(img_target_recon_local, img_target_local, eye_mouth_ear_mask, use_mask=use_mask) + # local_loss = self.l_w_local * (local_loss * mask_flag).mean() + + local_loss = self.l_w_local * self.criterion_vgg(img_target_recon_local, img_target_local) + + # local_loss = self.l_w_local * self.criterion_recon(img_target_recon_local, img_target_local) + else: + local_loss = torch.zeros(1).to(self.device) + else: + foreground_loss = torch.zeros(1).to(self.device) + local_loss = torch.zeros(1).to(self.device) + + return vgg_loss, l1_loss, face_vgg_loss, face_l1_loss, gaze_loss, foreground_loss, local_loss + + def compute_head_pose(self, out_dict): + target_img = out_dict['target_img'] + target_theta, target_scale, target_rotation, target_translation = 
out_dict['target_theta'], out_dict['target_scale'], out_dict['target_rotation'], out_dict['target_translation'] + + with torch.no_grad(): + if not self.motion_encoder.zero_to_one: + img_gt = (target_img + 1) / 2 + else: + img_gt = target_img + gt_theta, gt_scale, gt_rotation, gt_translation = self.head_pose_regrssor.forward(img_gt, return_srt=True) + gt_theta = gt_theta.detach() + + pred_pose = torch.cat([target_scale, target_rotation, target_translation], dim=-1) + gt_pose = torch.cat([gt_scale, gt_rotation, gt_translation], dim=-1) + headpose_loss = self.criterion_recon(pred_pose, gt_pose) + return headpose_loss + + def compute_loss(self, out_dict, tgt_parsing_map_dict=None): + + if self.l_w_id > 0 or self.l_w_face_vgg > 0: + if True: + t = out_dict['target_img'].shape[0] + pred_align_warp = out_dict['align_warp'].float() # B x 128 x 128 x 2 + inputs_orig_face_aligned = F.grid_sample(torch.cat([out_dict['recon_img'], out_dict['target_img']]).float(), pred_align_warp) + out_dict['pred_target_img_face_align'], out_dict['target_img_align_orig'] = inputs_orig_face_aligned.split([t, t], dim=0) + + face_size = 512 + face_bbox = out_dict['facial_info']['face_bbox'] + rois_face = [] + for b in range(face_bbox.size(0)): # loop for batch size + img_inds = face_bbox.new_full((1, 1), b).to(face_bbox.device) + rois = torch.cat([img_inds, face_bbox[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_face.append(rois) + rois_face = torch.cat(rois_face, 0).float() + out_dict['rois_face'] = rois_face + else: + face_size = 512 + face_bbox = out_dict['facial_info']['face_bbox'] + rois_face = [] + for b in range(face_bbox.size(0)): # loop for batch size + img_inds = face_bbox.new_full((1, 1), b).to(face_bbox.device) + rois = torch.cat([img_inds, face_bbox[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_face.append(rois) + rois_face = torch.cat(rois_face, 0).float() + out_dict['rois_face'] = rois_face + + out_dict['target_img_align_orig'] = roi_align(out_dict['target_img'], boxes=rois_face, output_size=face_size) + out_dict['pred_target_img_face_align'] = roi_align(out_dict['recon_img'], boxes=rois_face, output_size=face_size) + + + vgg_loss, l1_loss, face_vgg_loss, face_l1_loss, gaze_loss, foreground_loss, local_loss = self.compute_base_loss(out_dict, tgt_parsing_map_dict) + + loss = vgg_loss + l1_loss + face_vgg_loss + face_l1_loss + gaze_loss + foreground_loss + local_loss + loss_dict = {'loss': loss, 'l1': l1_loss, 'face_l1': face_l1_loss, 'vgg': vgg_loss, + 'gaze': gaze_loss, 'face_vgg': face_vgg_loss, 'foreground': foreground_loss, 'local': local_loss,} + + if self.l_w_sdm > 0: + sdm_loss = self.l_w_sdm * self.criterion_sdm(out_dict['src_latent'], out_dict['tgt_latent']) + loss_dict['loss'] += sdm_loss + loss_dict['sdm'] = sdm_loss + + if self.l_w_id > 0: + if not self.motion_encoder.zero_to_one: + gt_img = (out_dict['target_img_align_orig'] + 1) / 2 + recon_img = (out_dict['pred_target_img_face_align'] + 1) / 2 + else: + gt_img = out_dict['target_img_align_orig'] + recon_img = out_dict['pred_target_img_face_align'] + + id_loss = self.l_w_id * self.criterion_vggface(recon_img, gt_img.detach()) + loss_dict['loss'] += id_loss + loss_dict['id'] = id_loss + + if self.l_w_ref_consistency > 0: + canonical_volume_from_tgt = out_dict['canonical_volume_from_tgt'] + canonical_volume_from_src = out_dict['canonical_volume_from_src'].detach() + + ref_match_loss = self.l_w_ref_consistency * self.criterion_recon(canonical_volume_from_tgt, canonical_volume_from_src) + loss_dict['loss'] += ref_match_loss + loss_dict['ref_match'] 
= ref_match_loss + + # compute head pose loss + if self.motion_encoder.use_mask_image: + assert self.config.loss.get("l_w_headpose", 0) > 0 + headpose_loss = self.config.loss.l_w_headpose * self.compute_head_pose(out_dict) + loss_dict['loss'] += headpose_loss + loss_dict['headpose'] = headpose_loss + + return loss_dict + + def g_nonsaturating_loss(self, fake_pred): + return F.softplus(-fake_pred).mean() + + def d_nonsaturating_loss(self, fake_pred, real_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def prepare_datapair(self, batch): + masked_target_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + masked_past_frames = batch['pixel_values_past_frames'] + masked_target_vid = torch.cat([masked_past_frames, masked_target_vid], dim=1) + masked_ref_img = batch['pixel_values_ref_img'] + + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + past_frames = batch['pixel_values_past_frames_original'] + target_vid_original = torch.cat([past_frames, target_vid_original], dim=1) + + ref_img_original = ref_img_original.to(self.device) + target_vid_original = target_vid_original.to(self.device) + masked_ref_img = masked_ref_img.to(self.device) + masked_target_vid = masked_target_vid.to(self.device) + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) + ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w") + target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w") + + l_eye_bbox = torch.cat([batch['past_l_eye_bbox'], batch['vid_l_eye_bbox']], dim=1) + r_eye_bbox = torch.cat([batch['past_r_eye_bbox'], batch['vid_r_eye_bbox']], dim=1) + mouth_bbox = torch.cat([batch['past_mouth_bbox'], batch['vid_mouth_bbox']], dim=1) + face_bbox = torch.cat([batch['past_face_bbox'], batch['vid_face_bbox']], dim=1) + + l_eye_bbox = rearrange(l_eye_bbox, "b t c -> (b t) c") + r_eye_bbox = rearrange(r_eye_bbox, "b t c -> (b t) c") + mouth_bbox = rearrange(mouth_bbox, "b t c -> (b t) c") + face_bbox = rearrange(face_bbox, "b t c -> (b t) c") + facial_info = {'l_eye_bbox': l_eye_bbox, 'r_eye_bbox': r_eye_bbox, 'mouth_bbox': mouth_bbox, 'face_bbox': face_bbox} + + return ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, facial_info + + def _step(self, batch, batch_idx): + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, facial_info = self.prepare_datapair(batch) + + if (self.using_seg or self.face_parsing_en) and self.step_cnt >= self.face_part_start_iter: + # get human parsing maps + tgt_parsing_map_dict = self.face_parser.forward(target_vid_original) + + if self.using_seg: + src_parsing_map_dict = self.face_parser.forward(ref_img_original) + src_face_body = src_parsing_map_dict['face_body'] + src_cloth_mask = src_parsing_map_dict['cloth_mask'] + src_human_mask = src_face_body + src_cloth_mask + ref_img_original = ref_img_original * src_human_mask + + tgt_face_body = tgt_parsing_map_dict['face_body'] + tgt_cloth_mask = tgt_parsing_map_dict['cloth_mask'] + tgt_human_mask = tgt_face_body + tgt_cloth_mask + target_vid_original =
target_vid_original * tgt_human_mask + else: + tgt_parsing_map_dict = None + + out_dict = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + out_dict['facial_info'] = facial_info + loss_dict = self.compute_loss(out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + optimizers = self.optimizers() + + if isinstance(optimizers, (list, tuple)): + optimizer_g, optimizer_d = optimizers[0], optimizers[1] + + ################# train generator ################# + self.toggle_optimizer(optimizer_g) + + GAN_STYLE = False + ## step:1 compute global gan loss + self.discriminator.eval() + # print('1:', self.discriminator.training) + self.discriminator.requires_grad_(False) + # for p in self.discriminator.parameters(): + # p.requires_grad = False + + if GAN_STYLE: + pred_label = self.discriminator(out_dict['recon_img']).reshape(-1) + g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label) + global_loss_dict = self.discriminator.gan_forward(out_dict['recon_img']) + g_loss = global_loss_dict['global_gan_loss'] + feature_matching_loss = global_loss_dict['style_loss'] + else: + # Without grad as it is ground truth + with torch.no_grad(): + _, real_feats_gen = self.discriminator(out_dict['target_img']) + # With grad as it is predict + fake_score_gen, fake_feats_gen = self.discriminator(out_dict['recon_img']) + g_loss = self.l_w_gan * self.discriminator.compute_loss(fake_score_gen, mode='gen') + feature_matching_loss = self.l_w_fea_match * self.discriminator.feature_matching_loss(real_feats_gen, fake_feats_gen) + + # if self.step_cnt >= self.add_gan_step: + # loss_dict['loss'] += g_loss + # loss_dict['loss'] += feature_matching_loss + # else: + # loss_dict['loss'] += g_loss * self.step_cnt / self.add_gan_step + # loss_dict['loss'] += feature_matching_loss * self.step_cnt / self.add_gan_step + + loss_dict['gan'] = g_loss + loss_dict['feat_m'] = feature_matching_loss + + ## step:2 compute facial gan loss + if self.l_w_facial > 0: + self.facial_discriminator.eval() + self.facial_discriminator.requires_grad_(False) + facial_dict = self.facial_discriminator.get_facial_component(out_dict['recon_img'], target_vid_original.detach(), facial_info) + facial_loss_dict = self.facial_discriminator.gan_forward(facial_dict) + facial_loss = facial_loss_dict['facial_gan_loss'] + facial_loss_dict['style_loss'] + + # if self.step_cnt >= self.add_gan_step: + # loss_dict['loss'] += facial_loss + # else: + # loss_dict['loss'] += facial_loss * self.step_cnt / self.add_gan_step + # loss_dict['facial_gan'] = facial_loss + + loss_dict['loss'] += facial_loss + + loss_dict['face_gan'] = facial_loss_dict['facial_gan_loss'] + loss_dict['face_sty'] = facial_loss_dict['style_loss'] + + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + + # torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) + optimizer_g.step() + if self.step_cnt % self.log_grad_freq == 0: + self.log_gradient_stats(loss_dict) + + self.untoggle_optimizer(optimizer_g) + + GAN_DISABLE = True + if not GAN_DISABLE: + ################# train global discriminator ################# + self.discriminator.train() + # print('2:', self.discriminator.training) + self.discriminator.requires_grad_(True) + self.toggle_optimizer(optimizer_d) + + if GAN_STYLE: + dis_loss_dict = self.discriminator.dis_forward(target_vid_original, out_dict['recon_img'].detach()) + d_loss = dis_loss_dict['d_loss'] + else: + real_score_dis, _ = self.discriminator(out_dict['target_img']) + fake_score_dis, _ = 
self.discriminator(out_dict['recon_img'].detach()) + real_score, fake_score = real_score_dis[0][0], fake_score_dis[0][0] + + d_loss = self.discriminator.compute_loss(fake_scores=fake_score, real_scores=real_score, mode='dis') + loss_dict['d_loss'] = d_loss + + if True: + # import pdb; pdb.set_trace() + real_probs = torch.sigmoid(real_score) + fake_probs = torch.sigmoid(fake_score) + correct_real = (real_probs >= 0.5).float() + correct_fake = (fake_probs < 0.5).float() + total_correct = correct_real.sum() + correct_fake.sum() + total_samples = real_probs.numel() + fake_probs.numel() + accuracy = total_correct / total_samples + real_acc = correct_real.sum() / real_probs.numel() + fake_acc = correct_fake.sum() / fake_probs.numel() + # loss_dict['d_acc'] = accuracy + loss_dict['real_acc'] = real_acc + loss_dict['fake_acc'] = fake_acc + + + if not GAN_DISABLE: + optimizer_d.zero_grad() + self.manual_backward(d_loss) + optimizer_d.step() + self.untoggle_optimizer(optimizer_d) + + ################# train facial discriminator ################# + if self.l_w_facial > 0: + optimizer_d_facial = optimizers[2] + + if not GAN_DISABLE: + self.facial_discriminator.train() + self.facial_discriminator.requires_grad_(True) + self.toggle_optimizer(optimizer_d_facial) + + facial_dis_loss_dict = self.facial_discriminator.dis_forward(facial_dict) + loss_dict['facial_d_loss'] = facial_dis_loss_dict['d_loss'] + + if not GAN_DISABLE: + optimizer_d_facial.zero_grad() + self.manual_backward(facial_dis_loss_dict['d_loss']) + optimizer_d_facial.step() + self.untoggle_optimizer(optimizer_d_facial) + + if True: + real_probs = torch.sigmoid(facial_dis_loss_dict['real_d_pred_mouth']) + fake_probs = torch.sigmoid(facial_dis_loss_dict['fake_d_pred_mouth']) + correct_real = (real_probs >= 0.5).float() + correct_fake = (fake_probs < 0.5).float() + total_correct = correct_real.sum() + correct_fake.sum() + total_samples = real_probs.numel() + fake_probs.numel() + accuracy = total_correct / total_samples + real_acc = correct_real.sum() / real_probs.numel() + fake_acc = correct_fake.sum() / fake_probs.numel() + # loss_dict['d_acc'] = accuracy + loss_dict['mouth_real_acc'] = real_acc + loss_dict['mouth_fake_acc'] = fake_acc + + else: + optimizer_g = optimizers + self.toggle_optimizer(optimizer_g) + optimizer_g.zero_grad() + self.manual_backward(loss_dict['loss']) + optimizer_g.step() + self.untoggle_optimizer(optimizer_g) + + for k, v in loss_dict.items(): + if v > 0: + self.log(k, v, prog_bar=True) + + if self.global_rank == 0 and self.global_step % 100 == 0: + log_file = f"{self.output_dir}/train_log.txt" + current_epoch = self.current_epoch + global_step = self.global_step + log_content = f"Epoch: {current_epoch}, Step: {global_step}, " + for k, v in loss_dict.items(): + log_content += f"{k}: {v.item():.4f}, " + + with open(log_file, "a") as f: + f.write("*" * 50 + "\n") + f.write(log_content + "\n") + + return loss_dict + + + + def training_step(self, batch, batch_idx): + if self.init_time is None: + self.init_time = time() + + loss_dict = self._step(batch, batch_idx) + + # log current learning rate + optimizers = self.optimizers() + if isinstance(optimizers, (list, tuple)): + optimizer_g = optimizers[0] + else: + optimizer_g = optimizers + current_lr = optimizer_g.param_groups[0]['lr'] + self.log('lr', current_lr, on_step=True, on_epoch=False, prog_bar=True) + + # Step the scheduler after every training step + if self.config.get('scheduler', None) is not None: + if current_lr > self.config.scheduler.min_lr: + if self.l_w_gan > 
0: + schedulers_list = self.lr_schedulers() + scheduler_g, scheduler_d = schedulers_list[0], schedulers_list[1] + scheduler_g.step() + scheduler_d.step() + + if self.l_w_facial > 0: + scheduler_facial = schedulers_list[2] + scheduler_facial.step() + + # print(optimizers[1].param_groups[0]['lr'], optimizers[2].param_groups[0]['lr']) + # import pdb; pdb.set_trace() + else: + scheduler = self.lr_schedulers() + scheduler.step() + + + self.step_cnt += 1 + self.time = time() - self.init_time + self.log('avg_time', self.time / self.step_cnt, on_step=True, on_epoch=False, prog_bar=True) + + return loss_dict['loss'] + + def validation_step(self, batch, batch_idx): + if self.trainer.global_step > 1: + # get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, facial_info = self.prepare_datapair(batch) + + if self.using_seg or self.face_parsing_en: + # get human parsing maps + tgt_parsing_map_dict = self.face_parser.forward(target_vid_original) + + if self.using_seg: + src_parsing_map_dict = self.face_parser.forward(ref_img_original) + src_face_body = src_parsing_map_dict['face_body'] + src_cloth_mask = src_parsing_map_dict['cloth_mask'] + src_human_mask = src_face_body + src_cloth_mask + ref_img_original = ref_img_original * src_human_mask + + tgt_face_body = tgt_parsing_map_dict['face_body'] + tgt_cloth_mask = tgt_parsing_map_dict['cloth_mask'] + tgt_human_mask = tgt_face_body + tgt_cloth_mask + target_vid_original = target_vid_original * tgt_human_mask + else: + tgt_parsing_map_dict = None + + # get reconstructed image + with torch.no_grad(): + out_dict = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + out_dict['facial_info'] = facial_info + loss_dict = self.compute_loss(out_dict, tgt_parsing_map_dict=tgt_parsing_map_dict) + + if target_vid_original.min() < 0: + predicted_img = (out_dict['recon_img'] + 1) / 2 + target_vid_original = (target_vid_original + 1) / 2 + else: + predicted_img = out_dict['recon_img'] + + loss_dict['l1_loss'] = F.l1_loss(predicted_img, target_vid_original).mean() + predicted_img = (predicted_img * 255).permute(0, 2, 3, 1).cpu().numpy() + target_vid_original = (target_vid_original * 255).permute(0, 2, 3, 1).cpu().numpy() + + psnr_list = [] + ssim_list = [] + for tmp_i in range(len(predicted_img)): + psnr = lpips.psnr(predicted_img[tmp_i], target_vid_original[tmp_i], peak=255.) 
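+ # lpips.psnr is the small numpy helper bundled with the lpips package; as a
+ # sketch it reduces to 10 * log10(peak**2 / mean((a - b)**2)), with peak=255
+ # for images already scaled to the uint8 range, as here.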
+ ssim = structural_similarity(predicted_img[tmp_i], target_vid_original[tmp_i], data_range=255, multichannel=True, channel_axis=2) + psnr_list.append(psnr) + ssim_list.append(ssim) + + avg_psnr = np.mean(psnr_list) + avg_ssim = np.mean(ssim_list) + + loss_dict['val_psnr'] = avg_psnr + loss_dict['val_ssim'] = avg_ssim + + self.validation_step_outputs.append(loss_dict) + + return loss_dict + + def on_validation_epoch_end(self): + if not hasattr(self, 'validation_step_outputs') or len(self.validation_step_outputs) == 0: + return + + # get all metrics + outputs = self.validation_step_outputs + avg_recon_loss = torch.stack([x['l1'] for x in outputs]).mean() + avg_foreground_loss = torch.stack([x['foreground'] for x in outputs]).mean() + avg_local_loss = torch.stack([x['local'] for x in outputs]).mean() + + avg_psnr = np.mean([x['val_psnr'] for x in outputs]) + avg_ssim = np.mean([x['val_ssim'] for x in outputs]) + + # log metrics + self.log('val_recon', avg_recon_loss, prog_bar=True) + self.log('val_psnr', avg_psnr, prog_bar=True) + self.log('val_ssim', avg_ssim, prog_bar=True) + + if self.face_parsing_en: + if self.l_w_foreground > 0: + avg_foreground_loss /= self.l_w_foreground + self.log('val_foreground_loss', avg_foreground_loss, prog_bar=True) + + if self.l_w_local > 0: + avg_local_loss /= self.l_w_local + self.log('val_local_loss', avg_local_loss, prog_bar=True) + + if self.global_rank == 0: + log_file = f"{self.output_dir}/validation_metrics.txt" + current_epoch = self.current_epoch + global_step = self.global_step + log_content = ( + f"Epoch: {current_epoch}, " + f"Step: {global_step}, " + f"Recon Loss: {avg_recon_loss.item():.4f}, " + f"PSNR: {avg_psnr:.4f}, " + f"SSIM: {avg_ssim:.4f}" + ) + if self.face_parsing_en: + log_content += ( + f", Foreground Loss: {avg_foreground_loss.item():.4f}, " + f"Local Loss: {avg_local_loss.item():.4f}" + ) + + with open(log_file, "a") as f: + f.write("*" * 50 + "\n") + f.write(log_content + "\n") + + # clear cache for next epoch + self.validation_step_outputs.clear() + + def configure_optimizers(self): + params_to_update = list(self.motion_encoder.parameters()) + list(self.flow_estimator.parameters()) + \ + list(self.face_encoder.parameters()) + list(self.face_generator.parameters()) + params_to_update = [p for p in params_to_update if p.requires_grad] + params_name_to_update = [name for name, p in self.named_parameters() if p.requires_grad] + + optimizer = torch.optim.AdamW( + params_to_update, + lr=self.config.optimizer.lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + optimizer_list = [optimizer] + + # Scheduler with warm-up and cosine annealing + if self.config.get('scheduler', None) is not None: + total_steps = self.config.scheduler.total_steps + + # Cosine scheduler + cos_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=self.config.scheduler.min_lr) + scheduler = { + 'scheduler': cos_scheduler, + 'interval': 'step', + 'frequency': 1, + } + scheduler_cfg = [scheduler] + else: + scheduler_cfg = [] + + + if self.l_w_gan > 0: + optimizer_dis = torch.optim.AdamW( + self.discriminator.parameters(), + lr=self.config.optimizer.discriminator_lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + # Scheduler with warm-up and cosine annealing + if self.config.get('scheduler', None) 
is not None: + # Cosine scheduler + dis_cos_scheduler = CosineAnnealingLR(optimizer_dis, T_max=total_steps, eta_min=self.config.scheduler.min_lr) + dis_scheduler = { + 'scheduler': dis_cos_scheduler, + 'interval': 'step', + 'frequency': 1, + } + + optimizer_list += [optimizer_dis] + scheduler_cfg += [dis_scheduler] + + if self.l_w_facial > 0: + optimizer_facial = torch.optim.AdamW( + self.facial_discriminator.parameters(), + lr=self.config.optimizer.discriminator_lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + # Scheduler with warm-up and cosine annealing + if self.config.get('scheduler', None) is not None: + # Cosine scheduler + facial_cos_scheduler = CosineAnnealingLR(optimizer_facial, T_max=total_steps, eta_min=self.config.scheduler.min_lr) + facial_scheduler = { + 'scheduler': facial_cos_scheduler, + 'interval': 'step', + 'frequency': 1, + } + + optimizer_list += [optimizer_facial] + scheduler_cfg += [facial_scheduler] + + return optimizer_list, scheduler_cfg + + + \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_fine.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_fine.py new file mode 100644 index 0000000000000000000000000000000000000000..0b322e8970d4980b06db95e264f48110f60015d2 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_fine.py @@ -0,0 +1,478 @@ +import numpy as np +import torch +from torch import nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LambdaLR +import sys +from pathlib import Path +from einops import rearrange +import torch.nn.functional as F +import math +import cv2 + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from model.lightning.base_modules import BaseModule +from utils import instantiate +from losses.face_parsing_loss.face_parser import FaceParser + +class HeadAnimatorModule(BaseModule): + def __init__(self, config): + super().__init__(config) + self.automatic_optimization = False + self.config = config + self.using_hybrid_mask = config.model.get("using_hybrid_mask", True) + print(f'Using Hybrid Mask: {self.using_hybrid_mask}') + if not self.using_hybrid_mask: + self.face_encoder = nn.Identity() + + self.criterion_recon = nn.L1Loss() + self.criterion_masked_face_l1 = nn.L1Loss(reduction='none') + + self.l_w_recon = config.loss.l_w_recon + self.l_w_vgg = config.loss.l_w_vgg + self.l_w_face = config.loss.get("l_w_face", 0) + self.l_w_gan = config.loss.get("l_w_gan", 0) + self.l_w_face_l1 = config.loss.get("l_w_face_l1", 0) + self.l_w_gaze = config.loss.get("l_w_gaze", 0) + + # support GAN training & normal training + self.automatic_optimization = False + + if 'VASA' in self.config.model.motion_encoder.module_name: + self.model_name = 'VASA' + + if 'LIA' in self.config.model.motion_encoder.module_name: + self.model_name = 'LIA' + + print(f'Using {self.model_name} for Head Animation') + self.init_model() + + def init_model(self): + config = self.config + self.motion_encoder = instantiate(config.model.motion_encoder) + self.flow_estimator = instantiate(config.model.flow_estimator) + self.face_generator = instantiate(config.model.face_generator) + self.face_encoder = instantiate(config.model.face_encoder) + if not self.config.model.get("using_hybrid_mask", True): + self.face_encoder = nn.Identity() + + if
config.loss.get("l_mask_parse", False): + self.face_parse_model = FaceParser() + self.face_parse_model.eval() + for name, param in self.face_parse_model.named_parameters(): + param.requires_grad = False + if config.loss.l_w_vgg > 0 or config.loss.l_w_face > 0: + # self.criterion_vgg = VGGLoss() + self.criterion_vgg = instantiate(config.model.vgg_loss) + for name, param in self.criterion_vgg.named_parameters(): + param.requires_grad = False + self.criterion_vgg.eval() + + if config.loss.l_w_gan > 0: + self.discriminator = instantiate(config.model.discriminator) + + if self.config.model.get('pretrained_ckpt', None) is not None: + checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"] + ckpt = {} + for k, v in checkpoint.items(): + if 'motion_encoder' in k: + ckpt[k.replace('motion_encoder.', '')] = v + self.motion_encoder.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'flow_estimator' in k: + ckpt[k.replace('flow_estimator.', '')] = v + self.flow_estimator.load_state_dict(ckpt, strict=True) + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_generator' in k: + ckpt[k.replace('face_generator.', '')] = v + self.face_generator.load_state_dict(ckpt, strict=True) + if self.face_generator.freeze: + for param in self.face_generator.parameters(): + param.requires_grad = False + + + ckpt = {} + for k, v in checkpoint.items(): + if 'face_encoder' in k: + ckpt[k.replace('face_encoder.', '')] = v + self.face_encoder.load_state_dict(ckpt, strict=True) + + if self.l_w_gaze > 0: + self.criterion_gaze = instantiate(config.model.gaze_loss) + for name, param in self.criterion_gaze.named_parameters(): + param.requires_grad = False + self.criterion_gaze.eval() + + def configure_model(self): + pass + + def motion_encode(self, source_img): + latent_code, pyramid_feat = self.motion_encoder(source_img) + return latent_code, pyramid_feat + + def forward(self, source_img, target_img, masked_source_img, masked_target_img): + if self.using_hybrid_mask: + tgt_latent, _ = self.motion_encoder(masked_target_img) # project target image to reference latent space + src_latent, _ = self.motion_encoder(masked_source_img) # project source image to reference latent space + + tgt_latent = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + + face_feat = self.face_encoder(source_img) + recon_img = self.face_generator(tgt_latent, face_feat) + else: + tgt_latent, _ = self.motion_encoder(target_img) # project target image to reference latent space + src_latent, face_feat = self.motion_encoder(source_img) # project source image to reference latent space + + tgt_latent = self.flow_estimator(src_latent, tgt_latent) # navigate source to target in reference latent space + recon_img = self.face_generator(tgt_latent, face_feat) + # import pdb; pdb.set_trace() + out_dict = {} + out_dict['recon_img'] = recon_img + return out_dict + + @torch.no_grad() + def get_face_parse(self, img): + # import pdb; pdb.set_trace() + parsing_map_dict = self.face_parse_model(img) + return parsing_map_dict["mouth"] | parsing_map_dict["eye"] + + def compute_loss(self, img_target, img_target_recon, face_mask=None, face_keypoints=None): + + l1_loss = self.l_w_recon * self.criterion_recon(img_target_recon, img_target) + + # Perceptual Loss + if self.l_w_vgg > 0: + vgg_loss, vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target) + vgg_loss = self.l_w_vgg * vgg_loss.mean() + else: + vgg_loss = torch.zeros(1).to(self.device) + + # Facial Expression
Perceptual Loss + if face_mask is not None and self.l_w_face > 0: + face_loss, face_vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target, face_mask) + face_loss = self.l_w_face * face_loss.mean() + else: + face_loss = torch.zeros(1).to(self.device) + + if face_mask is not None and self.l_w_face_l1 > 0: + face_l1_loss = self.criterion_masked_face_l1(img_target_recon*face_mask, img_target*face_mask) + face_l1_loss = face_l1_loss.view(face_mask.size(0), -1).sum(-1) / (face_mask.view(face_mask.size(0), -1).sum(-1) + 1e-6) + face_l1_loss = self.l_w_face_l1 * face_l1_loss.mean() + else: + face_l1_loss = torch.zeros(1).to(self.device) + + # Gaze Loss + if self.l_w_gaze > 0: + gaze_loss = self.criterion_gaze(img_target_recon, img_target, face_keypoints) + gaze_loss = self.l_w_gaze * gaze_loss.to(self.device) + print(f"[DEBUG] Gaze Loss Value: {gaze_loss.item()}") + print(f"[DEBUG] Gaze Loss requires_grad: {gaze_loss.requires_grad}") + if isinstance(self.criterion_gaze, torch.nn.Module): + if hasattr(self.criterion_gaze, "_last_input_feature"): + grad_tensor = self.criterion_gaze._last_input_feature + if grad_tensor is not None: + print("[DEBUG] _last_input_feature.grad_fn:", grad_tensor.grad_fn) + print("[DEBUG] _last_input_feature.requires_grad:", grad_tensor.requires_grad) + grad_tensor.retain_grad() + else: + gaze_loss = torch.zeros(1).to(self.device) + + loss = vgg_loss + l1_loss + face_loss + face_l1_loss + gaze_loss + loss_dict = { + 'loss': loss, + 'l1_loss': l1_loss, + 'face_l1_loss': face_l1_loss, + 'vgg_loss': vgg_loss, + 'face_loss': face_loss, + 'gaze_loss': gaze_loss + } + return loss_dict + + def g_nonsaturating_loss(self, fake_pred): + return F.softplus(-fake_pred).mean() + + def d_nonsaturating_loss(self, fake_pred, real_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def prepare_datapair(self, batch): + pixel_values_vid_key = "eyeball_pixel_values_vid" if self.config.data.get("eyeball_enable", False) else "pixel_values_vid" + pixel_values_past_frames_key = "eyeball_past_frames" if self.config.data.get("eyeball_enable", False) else "pixel_values_past_frames" + pixel_values_ref_img_key = "eyeball_ref_img" if self.config.data.get("eyeball_enable", False) else "pixel_values_ref_img" + # when not zero_to_one, all of below is [-1, 1] + masked_target_vid = batch[pixel_values_vid_key] # this is a video batch: [B, T, C, H, W] + masked_past_frames = batch[pixel_values_past_frames_key] + masked_target_vid = torch.cat([masked_past_frames, masked_target_vid], dim=1) + masked_ref_img = batch[pixel_values_ref_img_key] # b c h w + + # when not zero_to_one, all of below is [-1, 1] + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + past_frames = batch['pixel_values_past_frames_original'] + target_vid_original = torch.cat([past_frames, target_vid_original], dim=1) + + + # get face keypoints + face_keypoints = batch['keypoints'] + if face_keypoints is not None: + face_keypoints = face_keypoints.to(self.device) + face_keypoints = rearrange(face_keypoints, "b t v c -> (b t) v c") + + + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) +
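+ # Shape walk-through (illustrative numbers): with B=2 references and T=8
+ # target frames, [B, C, H, W] -> [B, 1, C, H, W] -> repeat -> [B, T, C, H, W];
+ # the rearrange below then flattens to [(B*T)=16, C, H, W], so every target
+ # frame is paired with a copy of its reference image.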
+        ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w")
+        target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w")
+
+        ref_img_original = ref_img_original.to(self.device)
+        target_vid_original = target_vid_original.to(self.device)
+        masked_ref_img = masked_ref_img.to(self.device)
+        masked_target_vid = masked_target_vid.to(self.device)
+
+        return ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, face_keypoints
+
+    def set_module_eval_train_state(self, state_is_g=True):
+        # NOTE: state_is_g is currently unused; every submodule is simply kept in train mode.
+        if self.l_w_gan > 0:
+            self.discriminator.train()
+        self.motion_encoder.train()
+        self.flow_estimator.train()
+        self.face_generator.train()
+        self.face_encoder.train()
+
+    def _step(self, batch, batch_idx):
+        # get source-target image pair
+        ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, face_keypoints = self.prepare_datapair(batch)
+
+        # normalize face keypoints from pixel coordinates (512 x 512) to [-1, 1]
+        if face_keypoints is not None:
+            face_keypoints = face_keypoints / 512.0 * 2 - 1
+
+        accumulate_grad_batches = self.config.data.get("accumulate_grad_batches", 1)
+        is_grad_step = ((batch_idx + 1) % accumulate_grad_batches == 0)
+        if self.l_w_gan > 0:  # using GAN training
+            optimizer_g, optimizer_d = self.optimizers()
+            lr_scheduler, dis_lr_scheduler = self.lr_schedulers()
+            ## train generator
+            # toggling the optimizer enables gradients only for the generator parameters
+            self.set_module_eval_train_state(True)
+            self.toggle_optimizer(optimizer_g)
+            # get reconstructed image
+            predicted_img = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid)['recon_img']
+
+            if self.l_w_face > 0 or self.l_w_face_l1 > 0:
+                eye_mouth_mask_vid_key = "eye_mouth_mask_vid" if self.config.data.get("l_mask_scale", True) else "eye_mouth_mask_no_scale_vid"
+                eye_mouth_mask_past_frames_key = "eye_mouth_mask_past_frames" if self.config.data.get("l_mask_scale", True) else "eye_mouth_mask_no_scale_past_frames"
+                eye_mouth_mask_vid = batch[eye_mouth_mask_vid_key]
+                eye_mouth_mask_past_frames = batch[eye_mouth_mask_past_frames_key]
+                # keep past-context frames first so the mask stays aligned with target_vid_original
+                face_mask = torch.cat([eye_mouth_mask_past_frames, eye_mouth_mask_vid], dim=1)
+                face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w")
+                if self.config.loss.get("l_mask_parse", False):
+                    face_mask = self.get_face_parse(target_vid_original)
+                loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask, face_keypoints)
+            else:
+                loss_dict = self.compute_loss(target_vid_original, predicted_img, face_keypoints=face_keypoints)
+
+            # adversarial loss
+            pred_label = self.discriminator(predicted_img).reshape(-1)
+            g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label)
+
+            loss_dict['loss'] += g_loss
+            loss_dict['g_loss'] = g_loss
+
+            self.manual_backward(loss_dict['loss'])
+            if is_grad_step:
+                optimizer_g.step()
+                lr_scheduler.step()
+                # clear gradients only after stepping so they can accumulate
+                # across accumulate_grad_batches mini-batches
+                optimizer_g.zero_grad()
+            self.untoggle_optimizer(optimizer_g)
+
+            ## train discriminator
+            self.set_module_eval_train_state(False)
+            self.toggle_optimizer(optimizer_d)
+
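+            # Discriminator phase (sketch): the reconstruction is re-generated and
+            # detached below, so only the discriminator receives gradients here:
+            #   d_loss = softplus(-D(real)).mean() + softplus(D(fake.detach())).mean()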
+            real_img_pred = self.discriminator(target_vid_original)
+            predicted_img = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid)['recon_img']
+            recon_img_pred = self.discriminator(predicted_img.detach())
+
+            d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred)
+
+            self.manual_backward(d_loss)
+            if is_grad_step:
+                optimizer_d.step()
+                dis_lr_scheduler.step()
+                optimizer_d.zero_grad()
+            self.untoggle_optimizer(optimizer_d)
+
+            self.log("d_loss", d_loss, prog_bar=True)
+
+            lr_g = optimizer_g.param_groups[0]['lr']
+            lr_d = optimizer_d.param_groups[0]['lr']
+            self.log("learning_rate_g", lr_g)
+            self.log("learning_rate_d", lr_d)
+
+        else:
+            optimizer_g = self.optimizers()
+            lr_scheduler = self.lr_schedulers()
+            self.set_module_eval_train_state(True)
+            self.toggle_optimizer(optimizer_g)
+            # get reconstructed image
+            predicted_img = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid)['recon_img']
+
+            if self.l_w_face > 0 or self.l_w_face_l1 > 0:
+                eye_mouth_mask_vid = batch['eye_mouth_mask_vid']
+                eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames']
+                # keep past-context frames first so the mask stays aligned with target_vid_original
+                face_mask = torch.cat([eye_mouth_mask_past_frames, eye_mouth_mask_vid], dim=1)
+                face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w")
+
+                loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask, face_keypoints=face_keypoints)
+            else:
+                loss_dict = self.compute_loss(target_vid_original, predicted_img, face_keypoints=face_keypoints)
+
+            self.manual_backward(loss_dict['loss'])
+            if is_grad_step:
+                optimizer_g.step()
+                lr_scheduler.step()
+                optimizer_g.zero_grad()
+            self.untoggle_optimizer(optimizer_g)
+
+            lr_g = optimizer_g.param_groups[0]['lr']
+            self.log("learning_rate_g", lr_g)
+
+        for k, v in loss_dict.items():
+            if k in ['loss', 'l1_loss', 'gaze_loss', 'face_l1_loss', 'vgg_loss', 'g_loss']:
+                self.log(k, v, prog_bar=True)
+            else:
+                self.log(k, v)
+
+        return loss_dict
+
+    def training_step(self, batch, batch_idx):
+        self._step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        return
+
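+    # Note on the discriminator optimizer in configure_optimizers below: lr and
+    # betas are rescaled by d_reg_ratio = d_reg_every / (d_reg_every + 1), a
+    # correction that appears to be borrowed from StyleGAN2's lazy-regularization
+    # schedule (sketch: with d_reg_every = 16, d_reg_ratio = 16 / 17 ~ 0.94, so
+    # the effective settings become lr * 0.94 and betas (0**0.94, 0.99**0.94)).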
+    def configure_optimizers(self):
+        params_to_update = list(self.motion_encoder.parameters()) + list(self.flow_estimator.parameters()) + \
+            list(self.face_encoder.parameters()) + list(self.face_generator.parameters())
+        params_to_update = [p for p in params_to_update if p.requires_grad]
+        params_name_to_update = [name for name, p in self.named_parameters() if p.requires_grad]
+
+        g_reg_every, d_reg_every = 4, 16
+        g_reg_ratio = g_reg_every / (g_reg_every + 1)
+        d_reg_ratio = d_reg_every / (d_reg_every + 1)
+        optimizer = torch.optim.AdamW(
+            params_to_update,
+            lr=self.config.optimizer.lr,
+            weight_decay=self.config.optimizer.weight_decay,
+            betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2),
+            eps=self.config.optimizer.adam_epsilon,
+        )
+        if (self.config.get("lr_scheduler", None) is not None) and (self.config.lr_scheduler.type == "cos_anneal"):
+            lr_scheduler = CosineAnnealingLR(optimizer,
+                                             T_max=self.config.lr_scheduler.T_max,
+                                             eta_min=self.config.lr_scheduler.eta_min)
+        else:
+            lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
+
+        if self.l_w_gan > 0:
+            optimizer_dis = torch.optim.AdamW(
+                self.discriminator.parameters(),
+                lr=self.config.optimizer.discriminator_lr * d_reg_ratio,
+                weight_decay=self.config.optimizer.weight_decay,
+                betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
+                eps=self.config.optimizer.adam_epsilon,
+            )
+            if (self.config.get("dis_lr_scheduler", None) is not None) and (self.config.dis_lr_scheduler.type == "cos_anneal"):
+                dis_lr_scheduler = CosineAnnealingLR(optimizer_dis,
+                                                     T_max=self.config.dis_lr_scheduler.T_max,
+                                                     eta_min=self.config.dis_lr_scheduler.eta_min)
+            else:
+                dis_lr_scheduler = LambdaLR(optimizer_dis, lr_lambda=lambda step: 1.0)
+
+            return [optimizer, optimizer_dis], [lr_scheduler, dis_lr_scheduler]
+        else:
+            return [optimizer], [lr_scheduler]
+
+
+if __name__ == "__main__":
+    from model.head_animation.LIA.motion_encoder import MotionEncoder
+    from model.head_animation.LIA.flow_estimator import FlowEstimator
+    from model.head_animation.LIA.face_encoder import FaceEncoder
+    from model.head_animation.LIA.face_generator import FaceGenerator
+    from torchsummaryX import summary
+
+    IMAGE_SIZE = 512
+    latent_dim = 512
+
+    encoder = MotionEncoder(latent_dim=latent_dim, size=IMAGE_SIZE)
+    # summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE))
+
+    motion_space = 20
+    flow_estimator = FlowEstimator(latent_dim=latent_dim, motion_space=motion_space)
+    # summary(flow_estimator, torch.zeros(1, latent_dim), torch.zeros(1, latent_dim))
+    tgt_latent = flow_estimator(torch.zeros(1, latent_dim), torch.zeros(1, latent_dim))
+
+    face_encoder = FaceEncoder(output_channels=latent_dim)
+    # summary(face_encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE))
+    feat = face_encoder(torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE))
+    # for fea in feat: print(fea.shape)
+
+    face_generator = FaceGenerator(IMAGE_SIZE, latent_dim, channel_multiplier=1)
+    face_generator(tgt_latent, feat)

diff --git a/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_visual.py b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a34aa8f15e0768a38fdb058ca56f60ac1ba910
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/head_animation/head_animator_visual.py
@@ -0,0 +1,490 @@
+import numpy as np
+import torch
+from torch import nn
+import sys
+from pathlib import Path
+from einops import rearrange
+import torch.nn.functional as F
+import math
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from model.lightning.base_modules import BaseModule
+from utils import instantiate
+from model.head_animation.VASA3.building_blocks import *
+from model.head_animation.VASA3.nonrigid_pose_encoder import AdaptiveGroupNorm
+from model.modnets.modnet import MODNet
+
+class HeadAnimatorModule(BaseModule):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.using_hybrid_mask = config.model.get("using_hybrid_mask", True)
+        print(f'Using Hybrid Mask: {self.using_hybrid_mask}')
+        if not self.using_hybrid_mask:
+            self.face_encoder = nn.Identity()
+
+        self.criterion_recon = nn.L1Loss()
+        self.criterion_masked_face_l1 = nn.L1Loss(reduction='none')
+
+        self.l_w_recon = config.loss.l_w_recon
+        self.l_w_vgg = config.loss.l_w_vgg
+        self.l_w_face = config.loss.get("l_w_face", 0)
+        self.l_w_gan = config.loss.get("l_w_gan", 0)
+        self.l_w_face_l1 = config.loss.get("l_w_face_l1", 0)
+
+        # support GAN training & normal training
+        self.automatic_optimization = False
+
+        if 'VASA' in self.config.model.motion_encoder.module_name:
+            self.model_name = 'VASA'
+
+        if 'LIA' in self.config.model.motion_encoder.module_name:
+            self.model_name = 'LIA'
+
+        print(f'Using {self.model_name} for Head Animation')
+
+    def configure_model(self):
+        config = self.config
+        self.motion_encoder = instantiate(config.model.motion_encoder)
+        self.flow_estimator = instantiate(config.model.flow_estimator)
+        self.face_generator = instantiate(config.model.face_generator)
+        self.face_encoder = instantiate(config.model.face_encoder)
+
+        self.use_modnets = False
+        if config.model.get("modnets", None) is not None:
+            self.use_modnets = True
+            self.modnet = MODNet(backbone_pretrained=False)
+            self.modnet.eval()
+            modnet_state_dict = torch.load(config.model.modnets.pretrained_weights)
+            modnet_state_ckpt = {}
+            for k, v in modnet_state_dict.items():
+                modnet_state_ckpt[k.replace("module.", "")] = v
+            self.modnet.load_state_dict(modnet_state_ckpt)
+            for name, param in self.modnet.named_parameters():
+                param.requires_grad = False
+
+        if config.loss.l_w_vgg > 0 or config.loss.l_w_face > 0:
+            # self.criterion_vgg = VGGLoss()
+            self.criterion_vgg = instantiate(config.model.vgg_loss)
+            for name, param in self.criterion_vgg.named_parameters():
+                param.requires_grad = False
+            self.criterion_vgg.eval()
+
+        if config.loss.l_w_gan > 0:
+            self.discriminator = instantiate(config.model.discriminator)
+
+        if self.config.model.get('pretrained_ckpt', None) is not None:
+            checkpoint = torch.load(self.config.model.pretrained_ckpt)["state_dict"]
+            ckpt = {}
+            for k, v in checkpoint.items():
+                if 'motion_encoder' in k:
+                    ckpt[k.replace('motion_encoder.', '')] = v
+            self.motion_encoder.load_state_dict(ckpt, strict=True)
+
+            ckpt = {}
+            for k, v in checkpoint.items():
+                if 'flow_estimator' in k:
+                    ckpt[k.replace('flow_estimator.', '')] = v
+            self.flow_estimator.load_state_dict(ckpt, strict=True)
+
+            ckpt = {}
+            for k, v in checkpoint.items():
+                if 'face_generator' in k:
+                    ckpt[k.replace('face_generator.', '')] = v
+            self.face_generator.load_state_dict(ckpt, strict=True)
+            if self.face_generator.freeze:
+                for param in self.face_generator.parameters():
+                    param.requires_grad = False
+
+            ckpt = {}
+            for k, v in checkpoint.items():
+                if 'face_encoder' in k:
+                    ckpt[k.replace('face_encoder.', '')] = v
+            self.face_encoder.load_state_dict(ckpt, strict=True)
+            if self.face_encoder.freeze:
+                for name, param in self.face_encoder.named_parameters():
+                    # Note: we only pretrain the VolumetricFieldEncoder
+                    if 'global_descriptor_encoder' not in name:
+                        param.requires_grad = False
+
+    def setup(self, stage=None):
+        if stage == "fit" or stage is None:
+            print('Model is initializing weights...')
+            self.initialize_weights()
+
+    def 
initialize_weights(self): + for m in self.modules(): + if isinstance(m, (AdaptiveGroupNorm)): + pass + + # -------------------- Initialize convolutional layers (WSConv2d/WSConv3d) -------------------- + if isinstance(m, (WSConv2d, WSConv3d)): + # Check if the weights need to be updated + if m.weight.requires_grad: + # Adapt the initialization of WSConv to the scaling + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + m.weight.data.mul_(math.sqrt(2)) # Compensate for the scaling of the weights + # Initialize the bias (if it exists and needs to be updated) + if hasattr(m, 'bias') and m.bias is not None and m.bias.requires_grad: + nn.init.constant_(m.bias, 0) + + # -------------------- Initialize ordinary convolutional layers (non-WSConv) -------------------- + elif isinstance(m, (nn.Conv2d, nn.Conv3d)): + if m.weight.requires_grad: + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + if m.bias is not None and m.bias.requires_grad: + nn.init.constant_(m.bias, 0) + + # -------------------- Handle the last layer of the residual block's GroupNorm -------------------- + # Rule: Initialize the weight of the last GroupNorm in the main path to 0 (if the parameter needs to be updated) + elif isinstance(m, (ResBlock2d, ResBlock3d, ResBasic, ResBottleneck)): + # Find the last GroupNorm in the main path + last_group_norm = None + for layer in reversed(m.layers): + if isinstance(layer, nn.GroupNorm): + last_group_norm = layer + break + # Initialize the weight of the last GroupNorm to 0 (only if the parameter needs to be updated) + if last_group_norm is not None and last_group_norm.weight.requires_grad: + nn.init.constant_(last_group_norm.weight, 0) + + # Initialize the convolutional layer of the skip connection (if it exists and needs to be updated) + if not isinstance(m.skip_layer, (nn.Identity, type(None))): + if isinstance(m.skip_layer, (nn.Conv2d, nn.Conv3d)): + if m.skip_layer.weight.requires_grad: + nn.init.kaiming_normal_(m.skip_layer.weight, mode='fan_in', nonlinearity='relu') + elif isinstance(m.skip_layer, nn.Sequential): + # Handle the convolutional layer in the skip connection (e.g. in ResBasic) + for subm in m.skip_layer: + if isinstance(subm, (nn.Conv2d, nn.Conv3d)): + if subm.weight.requires_grad: + nn.init.kaiming_normal_(subm.weight, mode='fan_in', nonlinearity='relu') + + # -------------------- Initialize GroupNorm layers -------------------- + elif isinstance(m, nn.GroupNorm): + # Only initialize the parameters that need to be updated + if m.weight is not None and m.weight.requires_grad: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None and m.bias.requires_grad: + nn.init.constant_(m.bias, 0) + + # -------------------- Initialize linear layers -------------------- + elif isinstance(m, nn.Linear): + if m.weight.requires_grad: + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + if m.bias is not None and m.bias.requires_grad: + nn.init.constant_(m.bias, 0) + + # -------------------- Handle Spectral Normalization parameters -------------------- + # If using spectral normalization, initialize the original weights instead of the parameterized weights + if hasattr(m, 'parametrizations') and 'weight' in m.parametrizations: + parametrization = m.parametrizations.weight[0] + if hasattr(parametrization, 'original'): + original_weight = parametrization.original + if original_weight.requires_grad: + # Initialize the original weights (e.g. 
using Kaiming)
+                    nn.init.kaiming_normal_(original_weight, mode='fan_in', nonlinearity='relu')
+
+    def motion_encode(self, source_img):
+        latent_code, pyramid_feat = self.motion_encoder(source_img)
+        return latent_code, pyramid_feat
+
+    def forward(self, source_img, target_img, masked_source_img, masked_target_img, batch_idx=None):
+        if self.using_hybrid_mask:
+            tgt_latent, _ = self.motion_encoder(masked_target_img)  # project target image to reference latent space
+            src_latent, _ = self.motion_encoder(masked_source_img)  # project source image to reference latent space
+
+            tgt_latent = self.flow_estimator(src_latent, tgt_latent)  # navigate source to target in reference latent space
+
+            face_feat = self.face_encoder(source_img)
+            recon_img = self.face_generator(tgt_latent, face_feat)
+        else:
+            tgt_latent, _ = self.motion_encoder(target_img)  # project target image to reference latent space
+            src_latent, face_feat = self.motion_encoder(source_img)  # project source image to reference latent space
+
+            tgt_latent = self.flow_estimator(src_latent, tgt_latent)  # navigate source to target in reference latent space
+            recon_img = self.face_generator(tgt_latent, face_feat)
+
+        return recon_img, None, None
+
+    def compute_base_loss(self, img_target, img_target_recon, face_mask=None):
+
+        l1_loss = self.l_w_recon * self.criterion_recon(img_target_recon, img_target)
+
+        # Perceptual Loss
+        if self.l_w_vgg > 0:
+            # img_target_recon = F.interpolate(img_target_recon, size=(256, 256), mode='bilinear', align_corners=False)
+            # img_target = F.interpolate(img_target, size=(256, 256), mode='bilinear', align_corners=False)
+            vgg_loss, vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target)
+            vgg_loss = self.l_w_vgg * vgg_loss.mean()
+        else:
+            vgg_loss = torch.zeros(1).to(self.device)
+
+        # Facial Expression Perceptual Loss
+        if face_mask is not None and self.l_w_face > 0:
+            face_loss, face_vgg_loss_dict = self.criterion_vgg(img_target_recon, img_target, face_mask)
+            face_loss = self.l_w_face * face_loss.mean()
+        else:
+            face_loss = torch.zeros(1).to(self.device)
+
+        if face_mask is not None and self.l_w_face_l1 > 0:
+            face_l1_loss = self.criterion_masked_face_l1(img_target_recon*face_mask, img_target*face_mask)
+            # the small epsilon avoids division by zero for empty masks
+            face_l1_loss = face_l1_loss.view(face_mask.size(0), -1).sum(-1) / (face_mask.view(face_mask.size(0), -1).sum(-1) + 1e-6)
+            face_l1_loss = self.l_w_face_l1 * face_l1_loss.mean()
+        else:
+            face_l1_loss = torch.zeros(1).to(self.device)
+
+        return vgg_loss, l1_loss, face_loss, face_l1_loss
+
+    def compute_loss(self, img_target, img_target_recon, face_mask=None, tgt_rigid_pose=None, feature_volume=None):
+        vgg_loss, l1_loss, face_loss, face_l1_loss = self.compute_base_loss(img_target, img_target_recon, face_mask)
+
+        loss = vgg_loss + l1_loss + face_loss + face_l1_loss
+        loss_dict = {'loss': loss, 'l1_loss': l1_loss, 'face_l1_loss': face_l1_loss, 'vgg_loss': vgg_loss, 'face_loss': face_loss}
+        return loss_dict
+
+    def g_nonsaturating_loss(self, fake_pred):
+        return F.softplus(-fake_pred).mean()
+
+    def d_nonsaturating_loss(self, fake_pred, real_pred):
+        real_loss = F.softplus(-real_pred)
+        fake_loss = F.softplus(fake_pred)
+
+        return real_loss.mean() + fake_loss.mean()
+
+    def prepare_datapair(self, batch):
+        # when not zero_to_one, all of the below is in [-1, 1]
+        masked_target_vid = batch['pixel_values_vid']  # this is a video batch: [B, T, C, H, W]
+        masked_past_frames = batch['pixel_values_past_frames']
+        masked_target_vid = torch.cat([masked_past_frames, masked_target_vid], dim=1)
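+        # Temporal layout (sketch): with P past-context frames and T target
+        # frames, the concatenation along dim=1 yields (B, P+T, C, H, W), so
+        # the model reconstructs the context frames together with the new ones.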
+        masked_ref_img = batch['pixel_values_ref_img']  # b c h w
+
+        # when not zero_to_one, all of the below is in [-1, 1]
+        ref_img_original = batch['ref_img_original']
+        target_vid_original = batch['pixel_values_vid_original']
+        past_frames = batch['pixel_values_past_frames_original']
+        target_vid_original = torch.cat([past_frames, target_vid_original], dim=1)
+
+        # construct ref-tgt pairs
+        masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1)
+        masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w")
+        masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w")
+
+        ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1)
+        ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w")
+        target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w")
+
+        ref_img_original = ref_img_original.to(self.device)
+        target_vid_original = target_vid_original.to(self.device)
+        masked_ref_img = masked_ref_img.to(self.device)
+        masked_target_vid = masked_target_vid.to(self.device)
+
+        if self.use_modnets:
+            with torch.no_grad():
+                _, _, target_vid_original_mask = self.modnet((target_vid_original + 1.) / 2., True)
+                # target_vid_original_mask is (b 1 h w), range is (0, 1).
+                # To binarize, set values >= 0.5 to 1 and values < 0.5 to 0.
+
+            # visual -----------------------------------
+            # import imageio
+            # visual_list = []
+            # for ref_img_original_i, target_vid_original_i, masked_ref_img_i, \
+            #     masked_target_vid_i, target_vid_original_mask_i in zip(ref_img_original, target_vid_original, \
+            #     masked_ref_img, masked_target_vid, target_vid_original_mask):
+            #     ref_img_original_i = (((ref_img_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            #     target_vid_original_i = (((target_vid_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            #     masked_ref_img_i = (((masked_ref_img_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            #     masked_target_vid_i = (((masked_target_vid_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            #     target_vid_original_mask_i[target_vid_original_mask_i < 0.5] = 0
+            #     target_vid_original_mask_i = (target_vid_original_mask_i.repeat(3, 1, 1).cpu().numpy().transpose(1, 2, 0) * 255.).astype("uint8")
+            #     visuals = np.concatenate([ref_img_original_i, target_vid_original_i, masked_ref_img_i, masked_target_vid_i, target_vid_original_mask_i], axis=1)
+            #     visual_list.append(visuals)
+            # import os
+            # os.makedirs("visual_train_data", exist_ok=True)
+            # imageio.mimwrite(f"./visual_train_data/{self.trainer.global_step}_{self.trainer.global_rank}.mp4", visual_list, fps=8)
+
+        video_path = batch["video_path"][0]
+        print(f"{video_path=}")
+
+        import imageio
+        visual_list = []
+        for ref_img_original_i, target_vid_original_i, masked_ref_img_i, \
+            masked_target_vid_i in zip(ref_img_original, target_vid_original, \
+            masked_ref_img, masked_target_vid):
+            ref_img_original_i = (((ref_img_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            target_vid_original_i = (((target_vid_original_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            masked_ref_img_i = (((masked_ref_img_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            masked_target_vid_i = (((masked_target_vid_i.cpu().numpy() + 1.) / 2.) * 255.).transpose(1, 2, 0).astype("uint8")
+            visuals = np.concatenate([ref_img_original_i, target_vid_original_i, masked_ref_img_i, masked_target_vid_i], axis=1)
+            visual_list.append(visuals)
+        import os
+        video_base = os.path.basename(video_path)
+        os.makedirs(self.config.data.visual_dir, exist_ok=True)
+        imageio.mimwrite(f"./{self.config.data.visual_dir}/{self.trainer.global_step}_{self.trainer.global_rank}_{video_base}", visual_list, fps=self.config.data.train_fps)
+        # visual -----------------------------------
+
+        return ref_img_original, target_vid_original, masked_ref_img, masked_target_vid
+
+    def _step(self, batch, batch_idx):
+        # get source-target image pair
+        ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch)
+        # get reconstructed image
+        predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx)
+
+        if self.l_w_face > 0 or self.l_w_face_l1 > 0:
+            eye_mouth_mask_vid = batch['eye_mouth_mask_vid']
+            eye_mouth_mask_past_frames = batch['eye_mouth_mask_past_frames']
+            # keep past-context frames first so the mask stays aligned with target_vid_original
+            face_mask = torch.cat([eye_mouth_mask_past_frames, eye_mouth_mask_vid], dim=1)
+            face_mask = rearrange(face_mask, "b t c h w -> (b t) c h w")
+
+            loss_dict = self.compute_loss(target_vid_original, predicted_img, face_mask, tgt_rigid_pose, feature_volume)
+        else:
+            loss_dict = self.compute_loss(target_vid_original, predicted_img, tgt_rigid_pose=tgt_rigid_pose, feature_volume=feature_volume)
+
+        if self.l_w_gan > 0:
+            optimizer_g, optimizer_d = self.optimizers()
+
+            ## train generator
+            # self.toggle_optimizer(optimizer_g)
+
+            # adversarial loss
+            pred_label = self.discriminator(predicted_img).reshape(-1)
+            g_loss = self.l_w_gan * self.g_nonsaturating_loss(pred_label)
+
+            loss_dict['loss'] += g_loss
+            loss_dict['g_loss'] = g_loss
+
+            optimizer_g.zero_grad()
+            self.manual_backward(loss_dict['loss'])
+            optimizer_g.step()
+            # self.untoggle_optimizer(optimizer_g)
+
+            ## train discriminator
+            # self.toggle_optimizer(optimizer_d)
+
+            real_img_pred = self.discriminator(target_vid_original)
+            recon_img_pred = self.discriminator(predicted_img.detach())
+
+            d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred)
+
+            optimizer_d.zero_grad()
+            self.manual_backward(d_loss)
+            optimizer_d.step()
+            # self.untoggle_optimizer(optimizer_d)
+
+            self.log("d_loss", d_loss, prog_bar=True)
+
+        else:
+            optimizer_g = self.optimizers()
+
+            optimizer_g.zero_grad()
+            self.manual_backward(loss_dict['loss'])
+            optimizer_g.step()
+
+        for k, v in loss_dict.items():
+            self.log(k, v, prog_bar=True)
+
+        return loss_dict
+
+    def training_step(self, batch, batch_idx):
+        self.motion_encoder.train()
+        self.flow_estimator.train()
+        self.face_generator.train()
+        self.face_encoder.train()
+        loss_dict = self._step(batch, batch_idx)
+        return loss_dict['loss']
+
+    def validation_step(self, batch, batch_idx):
+        if self.trainer.global_step > 5:
+            # 
get source-target image pair + ref_img_original, target_vid_original, masked_ref_img, masked_target_vid = self.prepare_datapair(batch) + + # get reconstructed image + with torch.no_grad(): + predicted_img, tgt_rigid_pose, feature_volume = self.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid, batch_idx) + loss_dict = self.compute_loss(target_vid_original, predicted_img, tgt_rigid_pose=tgt_rigid_pose) + + self.log('val_recon_loss', loss_dict['l1_loss'], prog_bar=True) + + return loss_dict['l1_loss'] + + def configure_optimizers(self): + params_to_update = list(self.motion_encoder.parameters()) + list(self.flow_estimator.parameters()) + \ + list(self.face_encoder.parameters()) + list(self.face_generator.parameters()) + params_to_update = [p for p in params_to_update if p.requires_grad] + params_name_to_update = [name for name, p in self.named_parameters() if p.requires_grad] + + optimizer = torch.optim.AdamW( + params_to_update, + lr=self.config.optimizer.lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + if self.l_w_gan > 0: + optimizer_dis = torch.optim.AdamW( + self.discriminator.parameters(), + lr=self.config.optimizer.discriminator_lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + return [optimizer, optimizer_dis], [] + else: + # import pdb; pdb.set_trace() + return [optimizer], [] + + +if __name__ == "__main__": + from model.head_animation.LIA.motion_encoder import MotionEncoder + from model.head_animation.LIA.flow_estimator import FlowEstimator + from model.head_animation.LIA.face_encoder import FaceEncoder + from model.head_animation.LIA.face_generator import FaceGenerator + from torchsummaryX import summary + + IMAGE_SIZE = 512 + latent_dim = 512 + + encoder = MotionEncoder(latent_dim=latent_dim, size=IMAGE_SIZE) + # summary(encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + + motion_space=20 + flow_estimator = FlowEstimator(latent_dim=latent_dim, motion_space=motion_space) + # summary(flow_estimator, torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + tgt_latent = flow_estimator(torch.zeros(1, latent_dim), torch.zeros(1, latent_dim)) + + face_encoder = FaceEncoder(output_channels=latent_dim) + # summary(face_encoder, torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + feat = face_encoder(torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)) + # for fea in feat: print(fea.shape) + + face_generator = FaceGenerator(IMAGE_SIZE, latent_dim, channel_multiplier=1) + face_generator(tgt_latent, feat) + + + \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/__init__.py b/tools/visualization_0416/utils/model_0506/model/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..822140f970515ef578a3214a202c374922c45119 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8d42320ff9b2feb3659af89e16d3ca774ce84e1 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-310.pyc b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d393cb433c446ce718401ce8fbdf4a1a1003aee Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-310.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-311.pyc b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dce6be2e2ace67ec69be99fc1305e4718b75aaa Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-311.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ccd965190260c4486bb8b84c67711bf6f576b7d Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/lightning/__pycache__/base_modules.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/base_modules.py b/tools/visualization_0416/utils/model_0506/model/lightning/base_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1038c0ed31794966e71bf36619107200b124462e --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/lightning/base_modules.py @@ -0,0 +1,21 @@ +from omegaconf import DictConfig +import torch.nn as nn +import torch.nn.functional as F +from lightning import LightningModule +from utils import instantiate + + +class BaseModule(LightningModule): + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.model = self.configure_model() + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + def configure_optimizers(self): + raise NotImplementedError \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/callback.py b/tools/visualization_0416/utils/model_0506/model/lightning/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..877afd9a56102265cba8112451ecb215472b4632 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/lightning/callback.py @@ -0,0 +1,157 @@ +from lightning import Callback +import torch +import matplotlib.pyplot as plt +import os +import numpy as np +import torchvision +from einops import rearrange + +class VisualizationCallback(Callback): + def __init__(self, save_freq=2000, output_dir="visualizations"): + self.save_freq = save_freq + self.output_dir = output_dir + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + + # def on_train_batch_end(self, trainer, 
model, outputs, batch, batch_idx): + # # Check if the current step is a multiple of save_freq + # if trainer.is_global_zero: + # global_step = trainer.global_step + # if global_step % self.save_freq == 0: + # # Perform your visualization logic here + # # Example: save a plot of the current input or output (you can replace it with any other visualization) + # self.save_visualization(trainer, model, global_step, batch) + def on_train_batch_start(self, trainer, model, batch, batch_idx): + # Check if the current step is a multiple of save_freq + if trainer.is_global_zero: + global_step = trainer.global_step + if global_step % self.save_freq == 0: + # Perform your visualization logic here + # Example: save a plot of the current input or output (you can replace it with any other visualization) + self.save_visualization(trainer, model, global_step, batch) + + def save_visualization(self, trainer, model, global_step, batch): + # Example visualization: Save a plot of a dummy tensor (replace with your actual data or outputs) + fig, ax = plt.subplots() + ax.plot([1, 2, 3], [4, 5, 6]) # Replace with actual data, such as outputs from the model + ax.set_title(f"Visualization at Step {global_step}") + + # Save the plot to a file + plt.savefig(f"{self.output_dir}/visualization_{global_step}.png") + plt.close(fig) + print(f"Saved visualization at step {global_step}") + + +class VisualizationVAECallback(VisualizationCallback): + def __init__(self, save_freq=2000, output_dir="visualizations"): + super().__init__(save_freq, output_dir) + + def save_visualization(self, trainer, model, global_step, batch): + # Example visualization: Save a plot of a dummy tensor (replace with your actual data or outputs) + model.eval() + with torch.no_grad(): + x_pred, x_gt = model(batch) + + x_pred = x_pred.cpu() + x_gt = x_gt.cpu() + + x_pred = torch.clamp(x_pred, min=0.0, max=1.0) + x_gt = torch.clamp(x_gt, min=0.0, max=1.0) + + B = x_gt.shape[0] + rows = int(np.ceil(np.sqrt(B))) + cols = int(np.ceil(B / rows)) + + gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows) + pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows) + + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + axes[0].imshow(gt_grid.permute(1, 2, 0)) + axes[0].axis('off') + # axes[0].set_title('Ground Truth') + + axes[1].imshow(pred_grid.permute(1, 2, 0)) + axes[1].axis('off') + # axes[1].set_title('Prediction') + + plt.tight_layout() + plt.show() + plt.savefig(f"{self.output_dir}/image_grid_{global_step}.png") + plt.close() + + # import pdb; pdb.set_trace() + + +class Visualization_HeadAnimator_Callback(VisualizationCallback): + def __init__(self, save_freq=2000, output_dir="visualizations"): + super().__init__(save_freq, output_dir) + + def save_visualization(self, trainer, model, global_step, batch): + # Example visualization: Save a plot of a dummy tensor (replace with your actual data or outputs) + + masked_target_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + masked_ref_img = batch['pixel_values_ref_img'] + + ref_img_original = batch['ref_img_original'] + target_vid_original = batch['pixel_values_vid_original'] + + # construct ref-tgt pairs + masked_ref_img = masked_ref_img[:,None].repeat(1, masked_target_vid.size(1), 1, 1, 1) + masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w") + masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w") + + ref_img_original = ref_img_original[:,None].repeat(1, target_vid_original.size(1), 1, 1, 1) + ref_img_original = rearrange(ref_img_original, 
"b t c h w -> (b t) c h w") + target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w") + + with torch.no_grad(): + # get reconstructed image + model_out = model.forward(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid) + x_pred = model_out['recon_img'] + x_gt = target_vid_original + + x_pred = x_pred.cpu() + x_gt = x_gt.cpu() + x_ref = ref_img_original.cpu() + + if x_gt.min() < -0.5: + x_gt = (x_gt + 1) / 2 + x_pred = (x_pred + 1) / 2 + x_ref = (x_ref + 1) / 2 + + x_pred = torch.clamp(x_pred, min=0.0, max=1.0) + x_gt = torch.clamp(x_gt, min=0.0, max=1.0) + x_ref = torch.clamp(x_ref, min=0.0, max=1.0) + + B = x_gt.shape[0] + rows = int(np.ceil(np.sqrt(B))) + cols = int(np.ceil(B / rows)) + + ref_grid = torchvision.utils.make_grid(x_ref, nrow=rows) + gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows) + pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows) + + diff = (x_pred-x_gt).abs() + diff_grid = torchvision.utils.make_grid(diff, nrow=rows) + + fig, axes = plt.subplots(1, 4, figsize=(12, 6)) + axes[0].imshow(ref_grid.permute(1, 2, 0)) + axes[0].axis('off') + + axes[1].imshow(gt_grid.permute(1, 2, 0)) + axes[1].axis('off') + + axes[2].imshow(pred_grid.permute(1, 2, 0)) + axes[2].axis('off') + + axes[3].imshow(diff_grid.permute(1, 2, 0), cmap='jet') + axes[3].axis('off') + + plt.tight_layout() + plt.show() + plt.savefig(f"{self.output_dir}/image_grid_{global_step}.png") + plt.close() + + + + \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/lightning/head_imitation.py b/tools/visualization_0416/utils/model_0506/model/lightning/head_imitation.py new file mode 100644 index 0000000000000000000000000000000000000000..0effcca7e07702f6baa0b46747bdeb94c1cd0edb --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/lightning/head_imitation.py @@ -0,0 +1,148 @@ +import torch +from torch import nn +from model.lightning.base_modules import BaseModule +from torch.utils.data import DataLoader, Dataset +from models.volumetric_avatar.img2vol_enc import LocalEncoder +from models.volumetric_avatar.warp_generator import WarpGenerator +from models.volumetric_avatar.warped_vol_dec import Decoder_stage2 + +class HeadImitationModule(BaseModule): + def __init__(self, encoder, warp_generator, decoder, config): + super().__init__(config) + self.encoder = encoder + self.warp_generator = warp_generator + self.decoder = decoder + self.config = config + self.criterion = nn.MSELoss() # TODO:loss + self.learning_rate = config.get("learning_rate", 1e-4) + + def forward(self, source_img): + latent_volume = self.encoder(source_img) + warped_volume, deltas = self.warp_generator({"orig": latent_volume}) + output_img, _, _, _ = self.decoder({}, {}, warped_volume) + return output_img + + def _step(self, batch): + source_img, target_img = batch + predicted_img = self.forward(source_img) + loss = self.criterion(predicted_img, target_img) + return loss + + def training_step(self, batch, batch_idx): + loss = self._step(batch) + self.log("train_loss", loss, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self._step(batch) + self.log("val_loss", loss, prog_bar=True) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +class CustomDataset(Dataset): + def __init__(self, source_images, target_images): + self.source_images = source_images + self.target_images = target_images + + def __len__(self): + return 
len(self.source_images)
+
+    def __getitem__(self, idx):
+        return self.source_images[idx], self.target_images[idx]
+
+
+def create_data_loaders(source_images, target_images, batch_size=16):
+    dataset = CustomDataset(source_images, target_images)
+    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+
+if __name__ == "__main__":
+    # TODO: replace this stub config with a real configuration
+    config = {
+        "learning_rate": 1e-4,
+        "batch_size": 16,
+        "num_epochs": 10,
+    }
+
+    encoder = LocalEncoder(
+        use_amp_autocast=True,
+        gen_upsampling_type="nearest",
+        gen_downsampling_type="bilinear",
+        gen_input_image_size=256,
+        gen_latent_texture_size=64,
+        gen_latent_texture_depth=8,
+        gen_latent_texture_channels=64,
+        warp_norm_grad=True,
+        gen_num_channels=32,
+        enc_channel_mult=1,
+        norm_layer_type="bn",
+        num_gpus=1,
+        gen_max_channels=256,
+        enc_block_type="res",
+        gen_activation_type="relu",
+        in_channels=3,
+    )
+    warp_generator = WarpGenerator(WarpGenerator.Config(
+        eps=1e-8,
+        num_gpus=1,
+        gen_adaptive_conv_type="conv",
+        gen_activation_type="relu",
+        gen_upsampling_type="nearest",
+        gen_downsampling_type="bilinear",
+        gen_dummy_input_size=64,
+        gen_latent_texture_depth=8,
+        gen_latent_texture_size=64,
+        gen_max_channels=256,
+        gen_num_channels=32,
+        gen_use_adaconv=False,
+        gen_adaptive_kernel=False,
+        gen_embed_size=32,
+        warp_output_size=64,
+        warp_channel_mult=1,
+        warp_block_type="res",
+        norm_layer_type="bn",
+        input_channels=64,
+    ))
+    decoder = Decoder_stage2(
+        eps=1e-8,
+        image_size=256,
+        use_amp_autocast=True,
+        gen_embed_size=32,
+        gen_adaptive_kernel=False,
+        gen_adaptive_conv_type="conv",
+        gen_latent_texture_size=64,
+        in_channels=64,
+        gen_num_channels=32,
+        dec_max_channels=256,
+        gen_use_adanorm=False,
+        gen_activation_type="relu",
+        gen_use_adaconv=False,
+        dec_channel_mult=1,
+        dec_num_blocks=4,
+        dec_up_block_type="res",
+        dec_pred_seg=False,
+        dec_seg_channel_mult=1,
+        dec_pred_conf=False,
+        dec_conf_ms_names="",
+        dec_conf_names="",
+        dec_conf_ms_scales=4,
+        dec_conf_channel_mult=1,
+        gen_downsampling_type="bilinear",
+        num_gpus=1,
+        norm_layer_type="bn",
+    )
+
+    # TODO: replace these random tensors with a real dataset
+    source_images = torch.randn(100, 3, 256, 256)
+    target_images = torch.randn(100, 3, 256, 256)
+    train_loader = create_data_loaders(source_images, target_images, batch_size=config["batch_size"])
+
+    model = HeadImitationModule(encoder, warp_generator, decoder, config)
+
+    # training
+    from lightning.pytorch import Trainer
+    trainer = Trainer(max_epochs=config["num_epochs"], devices="auto", accelerator="gpu")
+    trainer.fit(model, train_loader)
\ No newline at end of file
diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/__init__.py b/tools/visualization_0416/utils/model_0506/model/modnets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbeee5ff49bf093a575cb5e5a0100836927de3e
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/modnets/__init__.py
@@ -0,0 +1,10 @@
+from .wrapper import *
+
+
+#------------------------------------------------------------------------------
+# Replaceable Backbones
+#------------------------------------------------------------------------------
+
+SUPPORTED_BACKBONES = {
+    'mobilenetv2': MobileNetV2Backbone,
+}
diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/__init__.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..febe8be5a2f44f1dc7648f4949983875dd4303f8
Binary files 
/dev/null and b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/__init__.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/mobilenetv2.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/mobilenetv2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c025198e01ad09ac143460f78771f907d3151387 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/mobilenetv2.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/modnet.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/modnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f90b52f9709e77b93ad35cf33148d32c1c7077 Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/modnet.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/wrapper.cpython-312.pyc b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7451bd14460b61bcd3123613df6a1592b267b7ab Binary files /dev/null and b/tools/visualization_0416/utils/model_0506/model/modnets/__pycache__/wrapper.cpython-312.pyc differ diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/mobilenetv2.py b/tools/visualization_0416/utils/model_0506/model/modnets/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..709d352565799f181bfaed652c796ef065e71a0f --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/modnets/mobilenetv2.py @@ -0,0 +1,199 @@ +""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" + +import math +import json +from functools import reduce + +import torch +from torch import nn + + +#------------------------------------------------------------------------------ +# Useful functions +#------------------------------------------------------------------------------ + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
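+    # Worked example (sketch): _make_divisible(17.9, 8) first rounds to 16,
+    # but 16 < 0.9 * 17.9, so one divisor is added back and 24 is returned.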
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +#------------------------------------------------------------------------------ +# Class of Inverted Residual block +#------------------------------------------------------------------------------ + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expansion, dilation=1): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expansion) + self.use_res_connect = self.stride == 1 and inp == oup + + if expansion == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +#------------------------------------------------------------------------------ +# Class of MobileNetV2 +#------------------------------------------------------------------------------ + +class MobileNetV2(nn.Module): + def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): + super(MobileNetV2, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1 , 16, 1, 1], + [expansion, 24, 2, 2], + [expansion, 32, 3, 2], + [expansion, 64, 4, 2], + [expansion, 96, 3, 1], + [expansion, 160, 3, 2], + [expansion, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel*alpha, 8) + self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel + self.features = [conv_bn(self.in_channels, input_channel, 2)] + + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = _make_divisible(int(c*alpha), 8) + for i in range(n): + if i == 0: + self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) + else: + self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) + input_channel = output_channel + + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + if self.num_classes is not None: + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # Initialize weights + self._init_weights() + + def forward(self, x): + # Stage1 + x = self.features[0](x) + x = self.features[1](x) + # Stage2 + x = self.features[2](x) + x = self.features[3](x) + # Stage3 + x = 
self.features[4](x) + x = self.features[5](x) + x = self.features[6](x) + # Stage4 + x = self.features[7](x) + x = self.features[8](x) + x = self.features[9](x) + x = self.features[10](x) + x = self.features[11](x) + x = self.features[12](x) + x = self.features[13](x) + # Stage5 + x = self.features[14](x) + x = self.features[15](x) + x = self.features[16](x) + x = self.features[17](x) + x = self.features[18](x) + + # Classification + if self.num_classes is not None: + x = x.mean(dim=(2,3)) + x = self.classifier(x) + + # Output + return x + + def _load_pretrained_model(self, pretrained_file): + pretrain_dict = torch.load(pretrained_file, map_location='cpu') + model_dict = {} + state_dict = self.state_dict() + print("[MobileNetV2] Loading pretrained model...") + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + else: + print(k, "is ignored") + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/modnet.py b/tools/visualization_0416/utils/model_0506/model/modnets/modnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bceb550100f4b4b8ecf77cc22a36d5cd7bdf74 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/modnets/modnet.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . 
import SUPPORTED_BACKBONES + + +#------------------------------------------------------------------------------ +# MODNet Basic Modules +#------------------------------------------------------------------------------ + +class IBNorm(nn.Module): + """ Combine Instance Norm and Batch Norm into One Layer + """ + + def __init__(self, in_channels): + super(IBNorm, self).__init__() + in_channels = in_channels + self.bnorm_channels = int(in_channels / 2) + self.inorm_channels = in_channels - self.bnorm_channels + + self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) + self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) + + def forward(self, x): + bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) + in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) + + return torch.cat((bn_x, in_x), 1) + + +class Conv2dIBNormRelu(nn.Module): + """ Convolution + IBNorm + ReLu + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, + with_ibn=True, with_relu=True): + super(Conv2dIBNormRelu, self).__init__() + + layers = [ + nn.Conv2d(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + ] + + if with_ibn: + layers.append(IBNorm(out_channels)) + if with_relu: + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class SEBlock(nn.Module): + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ + + def __init__(self, in_channels, out_channels, reduction=1): + super(SEBlock, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(in_channels, int(in_channels // reduction), bias=False), + nn.ReLU(inplace=True), + nn.Linear(int(in_channels // reduction), out_channels, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + w = self.pool(x).view(b, c) + w = self.fc(w).view(b, c, 1, 1) + + return x * w.expand_as(x) + + +#------------------------------------------------------------------------------ +# MODNet Branches +#------------------------------------------------------------------------------ + +class LRBranch(nn.Module): + """ Low Resolution Branch of MODNet + """ + + def __init__(self, backbone): + super(LRBranch, self).__init__() + + enc_channels = backbone.enc_channels + + self.backbone = backbone + self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) + self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) + + def forward(self, img, inference): + enc_features = self.backbone.forward(img) + enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] + + enc32x = self.se_block(enc32x) + lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) + lr16x = self.conv_lr16x(lr16x) + lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) + lr8x = self.conv_lr8x(lr8x) + + pred_semantic = None + if not inference: + lr = self.conv_lr(lr8x) + pred_semantic = torch.sigmoid(lr) + + return pred_semantic, lr8x, [enc2x, enc4x] + + +class HRBranch(nn.Module): + """ High Resolution Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + 
super(HRBranch, self).__init__() + + self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) + self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) + + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, enc2x, enc4x, lr8x, inference): + img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) + + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) + + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) + + hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) + hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) + + pred_detail = None + if not inference: + hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) + hr = self.conv_hr(torch.cat((hr, img), dim=1)) + pred_detail = torch.sigmoid(hr) + + return pred_detail, hr2x + + +class FusionBranch(nn.Module): + """ Fusion Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(FusionBranch, self).__init__() + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) + + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, lr8x, hr2x): + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = self.conv_lr4x(lr4x) + lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) + + f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) + f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) + f = self.conv_f(torch.cat((f, img), dim=1)) + pred_matte = torch.sigmoid(f) + + return pred_matte + + +#------------------------------------------------------------------------------ +# MODNet +#------------------------------------------------------------------------------ + +class MODNet(nn.Module): + """ Architecture of MODNet + """ + + def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): + 
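# Note: with the default 'mobilenetv2' backbone, enc_channels comes out as + # [16, 24, 32, 96, 1280] (see MobileNetV2Backbone in wrapper.py); the + # LR/HR/Fusion branches below are wired to that channel layout. +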
super(MODNet, self).__init__() + + self.in_channels = in_channels + self.hr_channels = hr_channels + self.backbone_arch = backbone_arch + self.backbone_pretrained = backbone_pretrained + + self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) + + self.lr_branch = LRBranch(self.backbone) + self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) + self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + self._init_conv(m) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): + self._init_norm(m) + + if self.backbone_pretrained: + self.backbone.load_pretrained_ckpt() + + def forward(self, img, inference): + pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) + pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) + pred_matte = self.f_branch(img, lr8x, hr2x) + + return pred_semantic, pred_detail, pred_matte + + def freeze_norm(self): + norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] + for m in self.modules(): + for n in norm_types: + if isinstance(m, n): + m.eval() + continue + + def _init_conv(self, conv): + nn.init.kaiming_uniform_( + conv.weight, a=0, mode='fan_in', nonlinearity='relu') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + def _init_norm(self, norm): + if norm.weight is not None: + nn.init.constant_(norm.weight, 1) + nn.init.constant_(norm.bias, 0) diff --git a/tools/visualization_0416/utils/model_0506/model/modnets/wrapper.py b/tools/visualization_0416/utils/model_0506/model/modnets/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..72b8f17b2e497409ce6b9a0561fd04175a3f4630 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/modnets/wrapper.py @@ -0,0 +1,82 @@ +import os +from functools import reduce + +import torch +import torch.nn as nn + +from .mobilenetv2 import MobileNetV2 + + +class BaseBackbone(nn.Module): + """ Superclass of Replaceable Backbone Model for Semantic Estimation + """ + + def __init__(self, in_channels): + super(BaseBackbone, self).__init__() + self.in_channels = in_channels + + self.model = None + self.enc_channels = [] + + def forward(self, x): + raise NotImplementedError + + def load_pretrained_ckpt(self): + raise NotImplementedError + + +class MobileNetV2Backbone(BaseBackbone): + """ MobileNetV2 Backbone + """ + + def __init__(self, in_channels): + super(MobileNetV2Backbone, self).__init__(in_channels) + + self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) + self.enc_channels = [16, 24, 32, 96, 1280] + + def forward(self, x): + # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + x = self.model.features[0](x) + x = self.model.features[1](x) + enc2x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + x = self.model.features[2](x) + x = self.model.features[3](x) + enc4x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + x = self.model.features[4](x) + x = self.model.features[5](x) + x = self.model.features[6](x) + enc8x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + x = self.model.features[7](x) + x = self.model.features[8](x) + x = self.model.features[9](x) + x = self.model.features[10](x) + x = self.model.features[11](x) + x = self.model.features[12](x) + x = self.model.features[13](x) + enc16x = x + + # x = reduce(lambda x, n: self.model.features[n](x), 
list(range(14, 19)), x) + x = self.model.features[14](x) + x = self.model.features[15](x) + x = self.model.features[16](x) + x = self.model.features[17](x) + x = self.model.features[18](x) + enc32x = x + return [enc2x, enc4x, enc8x, enc16x, enc32x] + + def load_pretrained_ckpt(self): + # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch + ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' + if not os.path.exists(ckpt_path): + print('cannot find the pretrained mobilenetv2 backbone') + exit() + + ckpt = torch.load(ckpt_path) + self.model.load_state_dict(ckpt) diff --git a/tools/visualization_0416/utils/model_0506/model/vae/cnn.py b/tools/visualization_0416/utils/model_0506/model/vae/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..883b5ac782d68baf0c0fd059f9e7aaf1b066efe6 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/cnn.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import DictConfig +from typing import Any, Dict, Tuple +from utils import instantiate +import cv2 +from PIL import Image +import numpy as np + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1) + self.conv2 = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + tmp = self.relu(x) + tmp = self.conv1(tmp) + tmp = self.relu(tmp) + tmp = self.conv2(tmp) + return x + tmp + + +class Encoder2D(nn.Module): + def __init__(self, output_channels=512): + super(Encoder2D, self).__init__() + + self.block = nn.Sequential( + nn.Conv2d(3, output_channels, 4, 2, 1), # 512x512 -> 256x256 + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # 256x256 -> 128x128 + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # 128x128 -> 64x64 + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # 64x64 -> 32x32 + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # 32x32 -> 16x16 + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 3, 1, 1), # Final Convolutional layer before residuals + ResidualBlock(output_channels), # Residual block 1 + ResidualBlock(output_channels), # Residual block 2 + ) + + def forward(self, x): + x = self.block(x) + return x + + +class Decoder2D(nn.Module): + def __init__(self, input_dim=512): + super(Decoder2D, self).__init__() + + self.fea_map_size=16 + + self.block = nn.Sequential( + nn.Conv2d(input_dim, input_dim, 3, 1, 1), # Initial convolution in the decoder + ResidualBlock(input_dim), # Residual block 1 + ResidualBlock(input_dim), # Residual block 2 + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # 16x16 -> 32x32 + nn.ReLU(), + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # 32x32 -> 64x64 + nn.ReLU(), + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # 64x64 -> 128x128 + nn.ReLU(), + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # 128x128 -> 256x256 + nn.ReLU(), + nn.ConvTranspose2d(input_dim, 3, 4, 2, 1) # 256x256 -> 512x512 + ) + + def forward(self, x): + x_hat = self.block(x) + + return x_hat + + +class Encoder(Encoder2D): + def __init__(self, output_channels=512): + super().__init__(output_channels) + self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) + + def forward(self, x): + x = self.block(x) + x = self.pool(x) + return x + + +class Decoder(Decoder2D): + def __init__(self, input_dim=512): + super().__init__(input_dim) + + self.fc = nn.Linear(input_dim, 
input_dim*self.fea_map_size*self.fea_map_size) + + def forward(self, x): + x = self.fc(x.view(x.size(0), -1)) + x = x.view(x.size(0), 512, self.fea_map_size, self.fea_map_size) + x_hat = self.block(x) + + return x_hat + diff --git a/tools/visualization_0416/utils/model_0506/model/vae/quantizer.py b/tools/visualization_0416/utils/model_0506/model/vae/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..82215a4cf920c20175d5905312b21f8caf836d71 --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/quantizer.py @@ -0,0 +1,447 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + # z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... 
(TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + # z_q = z_q.permute(0, 3, 1, 2).contiguous() + + # import pdb; pdb.set_trace() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. 
actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + 
self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + # map the constructor arguments onto the attribute names used below + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.n_embed = n_embed # referenced by the remap log message + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight +
self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) \ No newline at end of file diff --git a/tools/visualization_0416/utils/model_0506/model/vae/resnet.py b/tools/visualization_0416/utils/model_0506/model/vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7f497a654819c1b01cb3edc0b3322cb45975f2da --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/resnet.py @@ -0,0 +1,447 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# build resnet for cifar10, debug use only +# from https://github.com/huyvnphan/PyTorch_CIFAR10/blob/master/cifar10_models/resnet.py + +import os +import requests +from tqdm import tqdm +import zipfile +import torch.utils.model_zoo as modelzoo +import torch.nn.functional as F +import torch +import torch.nn as nn + +__all__ = [ + "ResNet", + "resnet18", + "resnet34", + "resnet50", +] +weights_downloaded = False + + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def 
__init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + num_classes=10, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + # END + + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.reshape(x.size(0), -1) + x = self.fc(x) + + return x + +def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): + global weights_downloaded + model = ResNet(block, layers, **kwargs) + if pretrained: + if not weights_downloaded: + download_weights() + weights_downloaded = True + + script_dir = os.path.dirname(__file__) + state_dict_path = os.path.join(script_dir, "../../cifar10_models/state_dicts", arch + ".pt") + if os.path.isfile(state_dict_path): + state_dict = torch.load(state_dict_path, map_location=device) + model.load_state_dict(state_dict) + else: + raise FileNotFoundError(f"No such file or directory: '{state_dict_path}'") + return model + + +def resnet18(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs) + + +def resnet34(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs) + + +def resnet50(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs) + + +def download_weights(): + + script_dir = os.path.dirname(__file__) + state_dicts_dir = os.path.join(script_dir, "cifar10_models") + + if os.path.isdir(state_dicts_dir) and len(os.listdir(state_dicts_dir)) > 0: + print("Weights already downloaded. Skipping download.") + return + + url = "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip" + + # Streaming, so we can iterate over the response. 
+ r = requests.get(url, stream=True) + + # Total size in Mebibyte + total_size = int(r.headers.get("content-length", 0)) + block_size = 2**20 # Mebibyte + t = tqdm(total=total_size, unit="MiB", unit_scale=True) + + with open("state_dicts.zip", "wb") as f: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) + t.close() + + if total_size != 0 and t.n != total_size: + raise Exception("Error, something went wrong") + + print("Download successful. Unzipping file...") + path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip") + directory_to_extract_to = os.path.join(os.getcwd(), "cifar10_models") + + with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: + zip_ref.extractall(directory_to_extract_to) + print("Unzip file successful!") + + + # original resblock +class ResBlock2D(nn.Module): + def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15): + super(ResBlock2D, self).__init__() + padding = self._get_same_padding(kernel, dilation) + + layer_s = list() + layer_s.append(nn.Conv2d(n_c, n_c, kernel, padding=padding, dilation=dilation, bias=False)) + layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) + layer_s.append(nn.ELU(inplace=True)) + # dropout + layer_s.append(nn.Dropout(p_drop)) + # convolution + layer_s.append(nn.Conv2d(n_c, n_c, kernel, dilation=dilation, padding=padding, bias=False)) + layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) + self.layer = nn.Sequential(*layer_s) + self.final_activation = nn.ELU(inplace=True) + + def _get_same_padding(self, kernel, dilation): + return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2 + + def forward(self, x): + out = self.layer(x) + return self.final_activation(x + out) + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' +class ResNet18(nn.Module): + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + # state_dict = torch.load('/apdcephfs/share_1290939/kevinyxpang/STIT/resnet18-5c106cde.pth') + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + + diff --git 
a/tools/visualization_0416/utils/model_0506/model/vae/vit.py b/tools/visualization_0416/utils/model_0506/model/vae/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e6446f06d9243e2e80f4e841013dd06787d8fc0a --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/vit.py @@ -0,0 +1,172 @@ +import torch +from torch import nn + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +# helpers + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# classes + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), + FeedForward(dim, mlp_dim, dropout = dropout) + ])) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +class ViTEncoder(nn.Module): + def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
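+ # For example (illustrative numbers, not fixed by this diff): image_size=512 with + # patch_size=16 gives num_patches = (512 // 16) * (512 // 16) = 1024 tokens, each + # flattened to patch_dim = 3 * 16 * 16 = 768 values before projection to dim.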
+ + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = channels * patch_height * patch_width + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + self.pool = pool + self.to_latent = nn.Identity() + + def forward(self, img): + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x) + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + x = x.view(x.size(0), -1, 1, 1) + + return x + +class ViTDecoder(nn.Module): + def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + pixel_values_per_patch = patch_height * patch_width * 3 + + self.decoder_dim = dim + self.mask_token = nn.Parameter(torch.randn(dim)) + self.decoder = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim)) + self.to_pixels = nn.Linear(dim, pixel_values_per_patch) + self.dropout = nn.Dropout(emb_dropout) + + self.token2_image = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_height, p2 = patch_width, h=image_height//patch_height, w=image_width//patch_width) + + def forward(self, latent): + + batch = latent.size(0) + device = latent.device + latent = latent.view(batch, -1) + + decoder_tokens = torch.zeros(batch, self.num_patches+1, self.decoder_dim, device=device) + decoder_tokens[:, 0] = latent + + mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = self.num_patches) + decoder_tokens[:, 1:] = mask_tokens + decoder_tokens += self.pos_embedding + decoder_tokens = self.dropout(decoder_tokens) + decoded_tokens = self.decoder(decoder_tokens) + + # splice out the mask tokens and project to pixel values + mask_tokens = decoded_tokens[:, 1:] + pred_pixel_values = self.to_pixels(mask_tokens) + pred_image = self.token2_image(pred_pixel_values) + + return pred_image + + diff --git a/tools/visualization_0416/utils/model_0506/model/vae/vqvae.py b/tools/visualization_0416/utils/model_0506/model/vae/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..41090bcfbe6401e4b539bd185437ce9ffeaba9aa --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/vqvae.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.lightning.base_modules import BaseModule +from omegaconf import DictConfig +from typing import Any, Dict, Tuple +from utils import instantiate 
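+# Loss layout used by _step() below: +# total = l_w_recon * ||x_hat - x||^2 (reconstruction) +# + l_w_embedding * ||sg(z_e) - z_q||^2 (codebook / embedding) +# + l_w_commitment * ||z_e - sg(z_q)||^2 (commitment), +# where sg() is stop-gradient, implemented with .detach().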
+import cv2 +from PIL import Image +import numpy as np + + +class VQAutoEncoder(BaseModule): + """ VQ-VAE model """ + def __init__( + self, + config: DictConfig, + ) -> None: + super().__init__(config) + + self.config = config + + self.l_w_recon = config.loss.l_w_recon + self.l_w_embedding = config.loss.l_w_embedding + self.l_w_commitment = config.loss.l_w_commitment + self.mse_loss = nn.MSELoss() + + def configure_model(self): + config = self.config + self.encoder = instantiate(config.model.encoder) + + self.decoder = instantiate(config.model.decoder) + # self.quantizer = instantiate(config.model.quantizer) + + # VQ Embedding (Vector Quantization) layer + self.vq_embedding = nn.Embedding(config.model.n_embedding, config.model.latent_dim) + self.vq_embedding.weight.data.uniform_(-1.0 / config.model.latent_dim, 1.0 / config.model.latent_dim) # Random initialization + + def configure_optimizers(self) -> Dict[str, Any]: + params_to_update = [p for p in self.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW( + params_to_update, + lr=self.config.optimizer.lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + return {"optimizer": optimizer} + + def encode(self, image): + ze = self.encoder(image) + + # Vector Quantization + embedding = self.vq_embedding.weight.data + B, C, H, W = ze.shape + K, _ = embedding.shape + embedding_broadcast = embedding.reshape(1, K, C, 1, 1) + ze_broadcast = ze.reshape(B, 1, C, H, W) + distance = torch.sum((embedding_broadcast - ze_broadcast) ** 2, 2) + nearest_neighbor = torch.argmin(distance, 1) + + # Quantized features + zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + return ze, zq + + def decode(self, quantized_fea): + x_hat = self.decoder(quantized_fea) + return x_hat + + def _step(self, batch, return_loss=True): + pixel_values_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + pixel_values_vid = pixel_values_vid.view(-1, 3, pixel_values_vid.size(-2), pixel_values_vid.size(-1)) # [B, T, C, H, W] -> [B*T, C, H, W] + + # import cv2 + # cv2.imwrite('debug_img.png', 255*pixel_values_vid[-1].permute(1,2,0).cpu().numpy()[:,:,::-1]) + # import pdb; pdb.set_trace() + + # test on single image + # pixel_values_vid = Image.open('debug_img.png') + # pixel_values_vid = np.array(pixel_values_vid) / 255.0 + # pixel_values_vid = torch.from_numpy(pixel_values_vid).float().to(self.device)[None].permute(0, 3, 1, 2) + + # Encoding (encode() is a bound method; passing self explicitly would raise a TypeError) + hidden_fea, quantized_fea = self.encode(pixel_values_vid) + + # Stop gradient (straight-through estimator: the decoder sees quantized_fea, gradients flow back into hidden_fea) + decoder_input = hidden_fea + (quantized_fea - hidden_fea).detach() + + # Decoding + x_hat = self.decode(decoder_input) + + if return_loss: + # Reconstruction Loss + l_reconstruct = self.mse_loss(x_hat, pixel_values_vid) + + # Embedding Loss + l_embedding = self.mse_loss(hidden_fea.detach(), quantized_fea) + + # Commitment Loss + l_commitment = self.mse_loss(hidden_fea, quantized_fea.detach()) + + # Total Loss, with each term weighted by its configured coefficient + total_loss = self.l_w_recon * l_reconstruct + self.l_w_embedding * l_embedding + self.l_w_commitment * l_commitment + + self.log('recon_loss', l_reconstruct, on_step=True, on_epoch=True, prog_bar=True) + self.log('emb_loss', l_embedding, on_step=True, on_epoch=True, prog_bar=True) + self.log('commit_loss', l_commitment, on_step=True, on_epoch=True, prog_bar=True) + + return total_loss + else: + return x_hat, pixel_values_vid + + def training_step(self, batch): +
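"""One optimisation step: delegates to _step(), which returns the weighted sum of the reconstruction, embedding and commitment losses.""" +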
total_loss = self._step(batch) + return total_loss + + def validation_step(self, batch): + total_loss = self._step(batch) + return total_loss + + def forward(self, batch): + x_pred, x_gt = self._step(batch, return_loss=False) + + return x_pred, x_gt + + diff --git a/tools/visualization_0416/utils/model_0506/model/vae/vqvae2d.py b/tools/visualization_0416/utils/model_0506/model/vae/vqvae2d.py new file mode 100644 index 0000000000000000000000000000000000000000..958b6eb861dcfa1d9103ecbe7c3dca0fe40ebc8d --- /dev/null +++ b/tools/visualization_0416/utils/model_0506/model/vae/vqvae2d.py @@ -0,0 +1,224 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.lightning.base_modules import BaseModule +from omegaconf import DictConfig +from typing import Any, Dict, Tuple +from utils import instantiate +import cv2 +from PIL import Image +import numpy as np + + +class VQAutoEncoder(BaseModule): + """ VQ-VAE model """ + def __init__( + self, + config: DictConfig, + ) -> None: + super().__init__(config) + + self.config = config + + self.l_w_recon = config.loss.l_w_recon + self.l_w_embedding = config.loss.l_w_embedding + self.l_w_commitment = config.loss.l_w_commitment + self.mse_loss = nn.MSELoss() + + + def _get_scheduler(self) -> Any: + # this function is for diffusion model + pass + + def configure_model(self): + config = self.config + self.encoder = instantiate(config.model.encoder) + + self.decoder = instantiate(config.model.decoder) + # self.quantizer = instantiate(config.model.quantizer) + + # VQ Embedding (Vector Quantization) layer + self.vq_embedding = nn.Embedding(config.model.n_embedding, config.model.latent_dim) + self.vq_embedding.weight.data.uniform_(-1.0 / config.model.latent_dim, 1.0 / config.model.latent_dim) # Random initialization + + def configure_optimizers(self) -> Dict[str, Any]: + params_to_update = [p for p in self.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW( + params_to_update, + lr=self.config.optimizer.lr, + weight_decay=self.config.optimizer.weight_decay, + betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2), + eps=self.config.optimizer.adam_epsilon, + ) + + return {"optimizer": optimizer} + + def encode(self, image): + ze = self.encoder(image) + + # Vector Quantization + embedding = self.vq_embedding.weight.data + B, C, H, W = ze.shape + K, _ = embedding.shape + embedding_broadcast = embedding.reshape(1, K, C, 1, 1) + ze_broadcast = ze.reshape(B, 1, C, H, W) + distance = torch.sum((embedding_broadcast - ze_broadcast) ** 2, 2) + nearest_neighbor = torch.argmin(distance, 1) + + # Quantized features + zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + return ze, zq + + def decode(self, quantized_fea): + x_hat = self.decoder(quantized_fea) + return x_hat + + def _step(self, batch, return_loss=True): + pixel_values_vid = batch['pixel_values_vid'] # this is a video batch: [B, T, C, H, W] + pixel_values_vid = pixel_values_vid.view(-1, 3, pixel_values_vid.size(-2), pixel_values_vid.size(-1)) # [B, T, C, H, W] -> [B*T, C, H, W] + + # import cv2 + # cv2.imwrite('debug_img.png', 255*pixel_values_vid[-1].permute(1,2,0).cpu().numpy()[:,:,::-1]) + # import pdb; pdb.set_trace() + + # test on single image + # pixel_values_vid = Image.open('debug_img.png') + # pixel_values_vid = np.array(pixel_values_vid) / 255.0 + # pixel_values_vid = torch.from_numpy(pixel_values_vid).float().to(self.device)[None].permute(0, 3, 1, 2) + + # Encoding + hidden_fea, quantized_fea = 
self.encode(self, pixel_values_vid) + + # Stop gradient + decoder_input = hidden_fea + (quantized_fea - hidden_fea).detach() + + # Decoding + x_hat = self.decode(decoder_input) + + if return_loss: + # Reconstruction Loss + l_reconstruct = self.mse_loss(x_hat, pixel_values_vid) + + # Embedding Loss + l_embedding = self.mse_loss(hidden_fea.detach(), quantized_fea) + + # Commitment Loss + l_commitment = self.mse_loss(hidden_fea, quantized_fea.detach()) + + # Total Loss + total_loss = l_reconstruct + self.l_w_embedding * l_embedding + self.l_w_commitment * l_commitment + + self.log('recon_loss', l_reconstruct, on_step=True, on_epoch=True, prog_bar=True) + self.log('emb_loss', l_embedding, on_step=True, on_epoch=True, prog_bar=True) + self.log('commit_loss', l_commitment, on_step=True, on_epoch=True, prog_bar=True) + + return total_loss + else: + return x_hat, pixel_values_vid + + def training_step(self, batch): + total_loss = self._step(batch) + return total_loss + + def validation_step(self, batch): + total_loss = self._step(batch) + return total_loss + + def forward(self, batch): + x_pred, x_gt = self._step(batch, return_loss=False) + + return x_pred, x_gt + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): + super(ResBlock, self).__init__() + + if mid_channels is None: + mid_channels = out_channels + + layers = [ + nn.ReLU(), + nn.Conv2d(in_channels, mid_channels, + kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(mid_channels, out_channels, + kernel_size=1, stride=1, padding=0) + ] + if bn: + layers.insert(2, nn.BatchNorm2d(out_channels)) + self.convs = nn.Sequential(*layers) + + def forward(self, x): + return x + self.convs(x) + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1) + self.conv2 = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + tmp = self.relu(x) + tmp = self.conv1(tmp) + tmp = self.relu(tmp) + tmp = self.conv2(tmp) + return x + tmp + + +class Encoder(nn.Module): + def __init__(self, output_channels=512): + super(Encoder, self).__init__() + + self.block = nn.Sequential( + nn.Conv2d(3, output_channels, 4, 2, 1), # Convolutional layer + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # Another Convolutional layer + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # Convolutional layer + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # Another Convolutional layer + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 4, 2, 1), # Another Convolutional layer + nn.ReLU(), + nn.Conv2d(output_channels, output_channels, 3, 1, 1), # Final Convolutional layer before residuals + ResidualBlock(output_channels), # Residual block 1 + ResidualBlock(output_channels), # Residual block 2 + ) + + def forward(self, x): + x = self.block(x) + return x + + +class Decoder(nn.Module): + def __init__(self, input_dim=512): + super(Decoder, self).__init__() + + self.fea_map_size=16 + + self.block = nn.Sequential( + nn.Conv2d(input_dim, input_dim, 3, 1, 1), # Initial convolution in the decoder + ResidualBlock(input_dim), # Residual block 1 + ResidualBlock(input_dim), # Residual block 2 + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # Transposed convolution (upsampling) + nn.ReLU(), + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # Transposed convolution (upsampling) + nn.ReLU(), + nn.ConvTranspose2d(input_dim, input_dim, 4, 2, 1), # Transposed convolution 
diff --git a/tools/visualization_0416/utils/model_0506/model/vae/vqvae_test.py b/tools/visualization_0416/utils/model_0506/model/vae/vqvae_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d21c4575bf0e3919cc618ae4a97c4bec03f967e
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/model/vae/vqvae_test.py
@@ -0,0 +1,124 @@
+import unittest
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from model.vae.vqvae import VQAutoEncoder
+
+
+class TestVQAutoEncoder(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up test fixtures that are shared across all tests."""
+        config = {
+            'model': {
+                'encoder': {
+                    'module_name': 'model.vae.cnn',
+                    'class_name': 'Encoder2D',
+                    'output_channels': 512
+                },
+                'decoder': {
+                    'module_name': 'model.vae.cnn',
+                    'class_name': 'Decoder2D',
+                    'input_dim': 512
+                },
+                'latent_dim': 512
+            },
+            'optimizer': {
+                'lr': 1e-4,
+                'weight_decay': 0.0,
+                'adam_beta1': 0.9,
+                'adam_beta2': 0.999,
+                'adam_epsilon': 1e-8
+            },
+            'loss': {
+                'l_w_recon': 1.0,
+                'l_w_embedding': 1.0,
+                'l_w_commitment': 1.0
+            }
+        }
+        cls.config = OmegaConf.create(config)
+        seed_everything(42)
+        cls.model = VQAutoEncoder(cls.config)
+        cls.model.configure_model()
+
+    def test_model_initialization(self):
+        """Test that the model and its components are initialized correctly."""
+        self.assertIsInstance(self.model, VQAutoEncoder)
+        self.assertIsInstance(self.model.encoder, nn.Module)
+        self.assertIsInstance(self.model.decoder, nn.Module)
+        self.assertTrue(hasattr(self.model, 'quantizer'))
+
+    def test_encode_decode(self):
+        """Test the encode and decode functions of the model."""
+        batch_size = 2
+        channels = 3
+        height = 512  # use 512x512 inputs to match the model architecture
+        width = 512
+
+        # Create dummy input
+        x = torch.randn(batch_size, channels, height, width)
+
+        # Test encode
+        quant, emb_loss, info = self.model.encode(x)
+        self.assertEqual(quant.shape, (batch_size, 1, self.model.config.model.latent_dim))
+        self.assertIsInstance(emb_loss, torch.Tensor)
+        self.assertIsInstance(info, tuple)  # VectorQuantizer returns a tuple, not a dict
+
+        # Test decode
+        dec = self.model.decode(quant)
+        self.assertEqual(dec.shape, (batch_size, channels, height, width))
+
+    def test_forward(self):
+        """Test the forward pass of the model."""
+        batch_size = 2
+        channels = 3
+        height = 512
+        width = 512
+
+        # Create dummy input
+        x = torch.randn(batch_size, channels, height, width)
+
+        # Test forward pass
+        dec, emb_loss, info = self.model.forward(x)
+
+        # Check output shapes and types
+        self.assertEqual(dec.shape, (batch_size, channels, height, width))
+        self.assertIsInstance(emb_loss, torch.Tensor)
+        self.assertIsInstance(info, tuple)  # VectorQuantizer returns a tuple, not a dict
+
+    def test_training_step(self):
+        """Test the training step of the model."""
+        batch_size = 2
+        channels = 3
+        height = 512
+        width = 512
+
+        # Create dummy batch
+        batch = {
+            'pixel_values_vid': torch.randn(batch_size, channels, height, width)
+        }
+
+        # Test training step
+        loss = self.model.training_step(batch)
+        self.assertIsInstance(loss, torch.Tensor)
+        self.assertTrue(loss.requires_grad)
+
+    def test_validation_step(self):
+        """Test the validation step of the model."""
+        batch_size = 2
+        channels = 3
+        height = 512
+        width = 512
+
+        # Create dummy batch
+        batch = {
+            'pixel_values_vid': torch.randn(batch_size, channels, height, width)
+        }
+
+        # Test validation step
+        loss = self.model.validation_step(batch)
+        self.assertIsInstance(loss, torch.Tensor)
+
+
+if __name__ == '__main__':
+    unittest.main()
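As a quick sanity check on the CNN above: five stride-2 convolutions downsample by 2**5, so a 512x512 input yields a 16x16 latent map, which is why the Decoder records fea_map_size = 16 and the tests use 512x512 inputs. A hedged sketch, assuming the `Encoder` and `Decoder` classes from vqvae2d.py are importable under the same package layout the tests use:

import torch
from model.vae.vqvae2d import Encoder, Decoder  # import path assumed from this PR's layout

enc = Encoder(output_channels=512)
dec = Decoder(input_dim=512)
x = torch.randn(1, 3, 512, 512)
z = enc(x)                          # five /2 convs: 512 -> 16
assert z.shape == (1, 512, 16, 16)
assert dec(z).shape == x.shape      # five x2 transposed convs: 16 -> 512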
diff --git a/tools/visualization_0416/utils/model_0506/utils.py b/tools/visualization_0416/utils/model_0506/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7be4d12f442aed77bdd1605a928e6f3209f9e35
--- /dev/null
+++ b/tools/visualization_0416/utils/model_0506/utils.py
@@ -0,0 +1,55 @@
+from importlib import import_module
+import os
+import shutil
+from pathlib import Path
+from lightning.pytorch.utilities import rank_zero_info
+from omegaconf import DictConfig, OmegaConf
+
+
+def instantiate(config: DictConfig, instantiate_module=True):
+    """Import config.module_name, resolve config.class_name, and (optionally)
+    construct it with the remaining config entries as keyword arguments."""
+    module = import_module(config.module_name)
+    class_ = getattr(module, config.class_name)
+    if instantiate_module:
+        init_args = {k: v for k, v in config.items() if k not in ["module_name", "class_name"]}
+        return class_(**init_args)
+    else:
+        return class_
+
+
+def instantiate_motion_gen(module_name, class_name, cfg, hfstyle=False, **init_args):
+    module = import_module(module_name)
+    class_ = getattr(module, class_name)
+    if hfstyle:
+        # wrap the OmegaConf config in the class's HuggingFace-style config object
+        config_class = class_.config_class
+        cfg = config_class(config_obj=cfg)
+    return class_(cfg, **init_args)
+
+
+def save_config_and_codes(config, save_dir):
+    """Snapshot the resolved config and all .py sources for reproducibility."""
+    os.makedirs(save_dir, exist_ok=True)
+    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
+    os.makedirs(sanity_check_dir, exist_ok=True)
+    with open(os.path.join(sanity_check_dir, f'{config.exp_name}.yaml'), 'w') as f:
+        OmegaConf.save(config, f)
+    current_dir = Path.cwd()
+    for py_file in current_dir.rglob('*.py'):
+        dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
+        dest_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy(py_file, dest_path)
+
+
+def print_model_size(model):
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    rank_zero_info(f"Total parameters: {total_params:,}")
+    rank_zero_info(f"Trainable parameters: {trainable_params:,}")
+    rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
+
+
+def load_metrics(file_path):
+    metrics = {}
+    with open(file_path, "r") as f:
+        for line in f:
+            # split only on the first ": " so values may themselves contain colons
+            key, value = line.strip().split(": ", 1)
+            try:
+                metrics[key] = float(value)  # convert to float if possible
+            except ValueError:
+                metrics[key] = value  # keep as string if conversion fails
+    return metrics
\ No newline at end of file
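Finally, a hedged usage sketch of `instantiate` from utils.py above: `module_name` and `class_name` are resolved with importlib, and every remaining config key becomes a constructor keyword argument. `torch.nn.Linear` stands in here for the encoder/decoder classes configured elsewhere in this PR; the values are illustrative.

from omegaconf import OmegaConf
from utils import instantiate  # import path assumed from this PR's layout

cfg = OmegaConf.create({
    "module_name": "torch.nn",
    "class_name": "Linear",
    "in_features": 8,
    "out_features": 4,
})
layer = instantiate(cfg)  # equivalent to torch.nn.Linear(in_features=8, out_features=4)
print(type(layer))        # <class 'torch.nn.modules.linear.Linear'>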