Spaces:
Runtime error
Runtime error
| # | |
| # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual | |
| # property and proprietary rights in and to this software and related documentation. | |
| # Any commercial use, reproduction, disclosure or distribution of this software and | |
| # related documentation without an express license agreement from Toyota Motor Europe NV/SA | |
| # is strictly prohibited. | |
| # | |
| from typing import Literal | |
| import tyro | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| import torch | |
| import nvdiffrast.torch as dr | |
| from vhap.util.render_uvmap import render_uvmap_vtex | |
| from vhap.model.flame import FlameHead | |
# Output locations for the generated FLAME UV masks. Both paths are relative,
# so they resolve against the current working directory.
FLAME_UV_MASK_FOLDER: str = "asset/flame/uv_masks"  # receives one <region>.png per mask region
FLAME_UV_MASK_NPZ: str = "asset/flame/uv_masks.npz"  # all masks in one archive, keyed by region name
def _render_uv_mask(glctx, verts_uv, faces_uv, vt_mask, resolution, device):
    """Rasterize one region's UV-vertex selection into a boolean UV-space mask.

    Args:
        glctx: nvdiffrast rasterizer context (OpenGL or CUDA).
        verts_uv: UV coordinates, one row per UV vertex.
        faces_uv: Integer UV face indices into ``verts_uv``.
        vt_mask: Selector (index or boolean tensor) of the UV vertices that
            belong to the region.
        resolution: ``(h, w)`` of the rendered UV map.
        device: Torch device on which the per-vertex value buffer is created.

    Returns:
        A numpy boolean array. Its shape is presumably ``(h, w)``; this
        assumes ``render_uvmap_vtex(...)[0]`` is ``(h, w, 1)``. TODO confirm.
    """
    # One scalar per UV vertex: 1 inside the region, 0 elsewhere.
    v_color = torch.zeros(verts_uv.shape[0], 1, device=device)
    v_color[vt_mask] = 1
    # The per-vertex values are indexed with the same UV face buffer that
    # defines the geometry (the original script aliased col_idx = faces_uv).
    alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, faces_uv, resolution)[0]
    # Flip along dim 0 (rows). NOTE(review): presumably this converts the UV
    # origin convention to image row order; confirm.
    alpha = alpha.flip(0)
    mask = alpha > 0.5  # to avoid overlap between hair and face
    return mask.squeeze(-1).cpu().numpy()


def main(
    use_opengl: bool = False,
    device: Literal['cuda', 'cpu'] = 'cuda',
):
    """Render per-region FLAME UV masks and save them as PNGs and one .npz.

    For every region in ``flame_model.mask.vt``, the UV vertices of that
    region are set to 1 (all others to 0). The result is rasterized into a
    2048x2048 UV map with nvdiffrast and thresholded at 0.5. Each mask is
    written to ``FLAME_UV_MASK_FOLDER/<region>.png`` as a 0/255 grayscale
    image. All masks are also saved together to ``FLAME_UV_MASK_NPZ``, keyed
    by region name.

    Args:
        use_opengl: Use nvdiffrast's OpenGL rasterizer instead of the CUDA one.
        device: Torch device for the FLAME model and its buffers.
            NOTE(review): nvdiffrast rasterizes on the GPU, so 'cpu' is
            presumably not usable end-to-end. Confirm before relying on it.
    """
    n_shape = 300
    n_expr = 100
    resolution = (2048, 2048)  # (h, w) of the rendered UV masks

    print("Initializing FLAME model")
    # Build the model once, directly on the requested device.
    flame_model = FlameHead(n_shape, n_expr, add_teeth=True).to(device)
    verts_uv = flame_model.verts_uvs.to(device)
    faces_uv = flame_model.textures_idx.int().to(device)

    # Rasterizer context
    glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()

    out_dir = Path(FLAME_UV_MASK_FOLDER)
    out_dir.mkdir(parents=True, exist_ok=True)

    masks = {}
    # NOTE(review): assumes iterating `mask.vt` yields (region_name, vt_mask)
    # pairs, as the unpacking below requires.
    for region, vt_mask in flame_model.mask.vt:
        mask = _render_uv_mask(glctx, verts_uv, faces_uv, vt_mask, resolution, device)
        masks[region] = mask
        print(f"Saving uv mask for {region}...")
        Image.fromarray((mask * 255).astype(np.uint8)).save(out_dir / f"{region}.png")

    print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}")
    np.savez_compressed(FLAME_UV_MASK_NPZ, **masks)
if __name__ == "__main__":
    # Script entry point: tyro exposes main()'s keyword arguments
    # (use_opengl, device) as command-line flags.
    tyro.cli(main)