Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: mica@tue.mpg.de | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as Functional | |
| from models.flame import FLAME | |
def kaiming_leaky_init(m):
    """Re-initialise the weight of a Linear-like module in place.

    Intended to be passed to ``nn.Module.apply``. A module is touched only
    when its class name contains the substring ``'Linear'``; every other
    module is left as it is. Only ``m.weight`` is re-initialised (Kaiming
    normal, fan-in mode, leaky-ReLU gain with slope 0.2); the bias is not
    modified.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(
        m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu'
    )
class MappingNetwork(nn.Module):
    """Fully connected network mapping a ``z_dim`` input to ``map_output_dim``.

    The stack consists of one input layer (``z_dim -> map_hidden_dim``)
    followed by ``hidden`` further linear layers, each followed by a leaky
    ReLU (slope 0.2), and a final linear ``output`` layer without activation.

    When ``hidden`` is greater than 5, a single skip connection is added:
    after the layer at index ``int(hidden / 2)`` of ``self.network`` the
    original input ``z`` is concatenated in front of the activations along
    dimension 1, and the layer that follows is built with
    ``map_hidden_dim + z_dim`` input features to accept it.
    """

    def __init__(self, z_dim, map_hidden_dim, map_output_dim, hidden=2):
        super().__init__()

        # Shallow stacks (five hidden layers or fewer) get no skip connection.
        self.skips = [int(hidden / 2)] if hidden > 5 else []

        # Layers are created in forward order: input layer first, then the
        # hidden layers; a hidden layer whose index is in ``skips`` receives
        # the concatenated [z, h] tensor and is therefore wider on input.
        layers = [nn.Linear(z_dim, map_hidden_dim)]
        for idx in range(hidden):
            if idx in self.skips:
                in_features = map_hidden_dim + z_dim
            else:
                in_features = map_hidden_dim
            layers.append(nn.Linear(in_features, map_hidden_dim))
        self.network = nn.ModuleList(layers)

        self.output = nn.Linear(map_hidden_dim, map_output_dim)

        # Only the layers in ``network`` are re-initialised; ``output`` keeps
        # the default nn.Linear initialisation, scaled down below.
        self.network.apply(kaiming_leaky_init)

        # Shrink the output weights in place, outside of autograd tracking.
        with torch.no_grad():
            self.output.weight.mul_(0.25)

    def forward(self, z):
        """Run ``z`` through the stack and return the ``output`` layer result.

        ``z`` is concatenated along dimension 1, so it needs at least two
        dimensions whenever a skip connection is active.
        """
        h = z
        for idx, layer in enumerate(self.network):
            h = Functional.leaky_relu(layer(h), negative_slope=0.2)
            if idx in self.skips:
                # Re-inject the raw input ahead of the hidden activations.
                h = torch.cat([z, h], 1)
        return self.output(h)
class Generator(nn.Module):
    """Wrap a FLAME model, optionally preceded by a MappingNetwork regressor.

    With ``regress=True`` (the default) the input to ``forward`` is first
    passed through a ``MappingNetwork`` and the result is used as FLAME's
    ``shape_params``. With ``regress=False`` no regressor is created and the
    input is handed to FLAME unchanged.

    Both sub-modules are moved to ``device`` at construction time.
    """

    def __init__(self, z_dim, map_hidden_dim, map_output_dim, hidden, model_cfg, device, regress=True):
        super().__init__()
        self.device = device
        self.cfg = model_cfg
        self.regress = regress

        # The regressor attribute exists only when regression is enabled.
        if regress:
            mapper = MappingNetwork(z_dim, map_hidden_dim, map_output_dim, hidden)
            self.regressor = mapper.to(device)

        flame = FLAME(model_cfg)
        self.generator = flame.to(device)

    def forward(self, arcface):
        """Return ``(prediction, shape)`` for the given input.

        ``shape`` is the regressor output when ``self.regress`` is set,
        otherwise ``arcface`` itself. ``prediction`` is the first of the three
        values returned by the FLAME model for ``shape_params=shape``; the
        other two are discarded.
        """
        shape = self.regressor(arcface) if self.regress else arcface
        flame_out = self.generator(shape_params=shape)
        # FLAME is expected to return exactly three values; only the first is used.
        prediction, _, _ = flame_out
        return prediction, shape