Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| from typing import Dict, Tuple | |
| import numpy as np | |
| import torch as th | |
| import torch.nn as nn | |
| import visualize.ca_body.nn.layers as la | |
| from attrdict import AttrDict | |
class FaceDecoderFrontal(nn.Module):
    """Decode a face embedding into geometry and a texture for a fixed frontal view.

    The embedding is projected to a 256-d hidden code. That code is decoded
    linearly into a flat geometry vector. It is also concatenated with an 8-d
    embedding of the stored frontal view vector and turned into a 256x4x4
    feature map, which eight transposed-convolution stages upsample into a
    3-channel texture. A learnable per-texel bias is then added, and the
    result is rescaled as ``255 * (x + 0.5)``.

    Args:
        assets: object exposing ``face_frontal_view``. It is converted to a
            float32 tensor and registered as the ``frontal_view`` buffer.
            ``forward`` adds a leading axis, expands it over the batch and feeds
            it to ``LinearWN(3, 8)``, so presumably it is a length-3 vector;
            confirm.
        n_latent: width of the input embedding.
        n_vert_out: length of the flat geometry output. ``forward`` reshapes it
            into rows of 3, so it must be divisible by 3.
        tex_out_shp: stored as ``self.tex_out_shp``; not referenced by this class.
        tex_roi: two corner points. Their element-wise difference becomes
            ``tex_roi_shp``, which sizes the learnable bias as
            ``(3, tex_roi_shp[0], tex_roi_shp[1])``.
            NOTE(review): the last texture stage is built with 1024, 1024 size
            arguments, so the bias must broadcast against that output. Confirm
            before changing ``tex_roi``.
    """

    def __init__(
        self,
        assets: AttrDict,
        n_latent: int = 256,
        n_vert_out: int = 3 * 7306,
        tex_out_shp: Tuple[int, int] = (1024, 1024),
        tex_roi: Tuple[Tuple[int, int], Tuple[int, int]] = ((0, 0), (1024, 1024)),
    ) -> None:
        super().__init__()

        def leaky() -> nn.Module:
            # Every hidden activation in this decoder is the same in-place LeakyReLU.
            return nn.LeakyReLU(0.2, inplace=True)

        self.n_latent = n_latent
        self.n_vert_out = n_vert_out
        self.tex_roi = tex_roi
        roi_extent = np.diff(np.array(tex_roi), axis=0).squeeze()
        self.tex_roi_shp: Tuple[int, int] = tuple(int(d) for d in roi_extent)
        self.tex_out_shp = tex_out_shp

        # NOTE: submodules are created in a fixed order, under fixed attribute
        # names and Sequential indices. The state_dict keys derive from those
        # names and indices. The creation order presumably also fixes how
        # seeded random initialisation is consumed.
        self.encmod = nn.Sequential(la.LinearWN(n_latent, 256), leaky())
        self.geommod = nn.Sequential(la.LinearWN(256, n_vert_out))
        self.viewmod = nn.Sequential(la.LinearWN(3, 8), leaky())
        self.texmod2 = nn.Sequential(la.LinearWN(256 + 8, 256 * 4 * 4), leaky())

        # Each entry is (in_channels, out_channels, size) for one upsampling
        # stage. "size" is passed twice to ConvTranspose2dWNUB, presumably as
        # its output height/width. Every stage uses kernel 4, stride 2,
        # padding 1, which doubles the spatial size of a standard transposed
        # convolution (4x4 -> 1024x1024 over eight stages).
        tex_stages = (
            (256, 256, 8),
            (256, 128, 16),
            (128, 128, 32),
            (128, 64, 64),
            (64, 64, 128),
            (64, 32, 256),
            (32, 8, 512),
            (8, 3, 1024),
        )
        tex_layers = []
        for stage_idx, (c_in, c_out, size) in enumerate(tex_stages):
            # Activations sit *between* stages; the final stage has none.
            if stage_idx:
                tex_layers.append(leaky())
            tex_layers.append(la.ConvTranspose2dWNUB(c_in, c_out, size, size, 4, 2, 1))
        self.texmod = nn.Sequential(*tex_layers)

        # Learnable per-texel texture bias, starting at zero.
        self.bias = nn.Parameter(th.zeros(3, self.tex_roi_shp[0], self.tex_roi_shp[1]))
        self.bias.data.zero_()
        frontal_view = th.as_tensor(assets.face_frontal_view, dtype=th.float32)
        self.register_buffer("frontal_view", frontal_view)

        def init_hidden(module: nn.Module) -> None:
            la.glorot(module, 0.2)

        # Initialise every submodule with 0.2, presumably matching the
        # LeakyReLU slope. Then re-initialise the final texture stage with
        # 1.0, since no activation follows it.
        self.apply(init_hidden)
        la.glorot(self.texmod[-1], 1.0)

    def forward(self, face_embs: th.Tensor) -> Dict[str, th.Tensor]:
        """Decode a batch of face embeddings.

        Args:
            face_embs: embeddings of shape ``(B, n_latent)``. They must be 2-D,
                because the hidden code is concatenated with the 2-D view
                features along dim 1.

        Returns:
            Dict with
            - ``"face_geom"``: the geometry output reshaped to ``(B, -1, 3)``;
            - ``"face_tex_raw"``: the texture-decoder output before the bias;
            - ``"face_tex"``: ``255 * (face_tex_raw + bias + 0.5)``.
        """
        batch = face_embs.shape[0]
        # Every sample in the batch is decoded for the same frontal view.
        view_dirs = self.frontal_view.unsqueeze(0).expand(batch, -1)

        hidden = self.encmod(face_embs)
        geom_flat = self.geommod(hidden)
        view_feat = self.viewmod(view_dirs)
        tex_seed = self.texmod2(th.cat([hidden, view_feat], dim=1))
        tex_raw = self.texmod(tex_seed.view(-1, 256, 4, 4))
        tex = 255 * (tex_raw + self.bias[None] + 0.5)

        return {
            "face_geom": geom_flat.view(geom_flat.shape[0], -1, 3),
            "face_tex_raw": tex_raw,
            "face_tex": tex,
        }