| import torch |
| import pickle |
| import os.path as osp |
| from pytorch3d.io import load_obj |
|
|
| from .base_renderer import BaseMeshRenderer |
| from ...utils.helper import face_vertices |
| from ...utils.graphics import GS_BaseMeshRenderer |
|
|
class Renderer(BaseMeshRenderer):
    """Head-template mesh visualizer built on ``BaseMeshRenderer``.

    On construction it loads ``head_template.obj`` and
    ``FLAME_masks/FLAME_masks.pkl`` from ``assets_dir``. It registers the
    template topology, a flat gray per-face color table and the template UV
    data as module buffers, so they follow the module across devices.
    """

    def __init__(self, assets_dir, image_size=512, device='cuda', focal_length=12):
        """Load the template assets and register them as buffers.

        Args:
            assets_dir: Directory containing ``head_template.obj`` and
                ``FLAME_masks/FLAME_masks.pkl``.
            image_size: Output image size, forwarded to ``BaseMeshRenderer``.
            device: Device argument, forwarded to ``BaseMeshRenderer``.
            focal_length: Default focal length. It is forwarded to
                ``BaseMeshRenderer`` and also stored on ``self.focal_length``.
        """
        super().__init__(assets_dir, image_size, device, focal_length=focal_length)
        self.focal_length = focal_length

        obj_filename = osp.join(assets_dir, 'head_template.obj')
        _verts, faces, aux = load_obj(obj_filename)
        # `[None, ...]` adds a leading batch dimension to the UV coordinates,
        # the UV-face indices and the vertex-face indices.
        uvcoords = aux.verts_uvs[None, ...]
        uvfaces = faces.textures_idx[None, ...]
        faces = faces.verts_idx[None, ...]

        # One flat gray (180/255) RGB color per vertex. The vertex count is
        # taken as the largest vertex index referenced by `faces`, plus one.
        num_verts = int(faces.max()) + 1
        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(1, num_verts, 1).float() / 255.

        # FLAME region masks (contents opaque here; presumably region name ->
        # vertex indices, TODO confirm). Use a context manager so the file
        # handle is closed.
        # NOTE: unpickling can execute arbitrary code. This assumes the assets
        # directory is trusted; never point it at untrusted files.
        with open(osp.join(assets_dir, 'FLAME_masks/FLAME_masks.pkl'), 'rb') as mask_file:
            self.flame_masks = pickle.load(mask_file, encoding='latin1')

        self.register_buffer('faces', faces)

        # Per-face colors, gathered from the per-vertex colors by the
        # `face_vertices` helper.
        face_colors = face_vertices(colors, faces)
        self.register_buffer('face_colors', face_colors)

        # Keep the UVs exactly as loaded, before the remapping below.
        self.register_buffer('raw_uvcoords', uvcoords)

        # Append a constant third coordinate of 1 and map UVs from [0, 1] to
        # [-1, 1] (assuming the OBJ stores them in [0, 1]). Then flip the
        # v axis. The `* 2 - 1` creates a new tensor, so `raw_uvcoords` is
        # not modified by the in-place flip.
        uvcoords = torch.cat([uvcoords, torch.ones_like(uvcoords[:, :, 0:1])], -1)
        uvcoords = uvcoords * 2 - 1
        uvcoords[..., 1] = -uvcoords[..., 1]
        face_uvcoords = face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

    def forward(self, vertices, faces=None, landmarks=None, cameras=None, transform_matrix=None,
                focal_length=None, is_weak_cam=False, ret_image=True):
        """Render `vertices` through ``BaseMeshRenderer.forward``.

        Args:
            vertices: Mesh vertices, passed through to the base renderer.
            faces: Face indices. Defaults to the template faces with the
                leading batch dimension removed.
            landmarks: Landmark container passed to the base renderer.
                Defaults to a new empty dict on every call.
            cameras, transform_matrix, focal_length, is_weak_cam, ret_image:
                Passed through unchanged to the base renderer.

        Returns:
            Whatever ``BaseMeshRenderer.forward`` returns.
        """
        if faces is None:
            faces = self.faces.squeeze(0)
        # A `{}` default would be one dict object shared by every call. Build a
        # fresh one instead, so any mutation downstream cannot leak into later
        # calls.
        if landmarks is None:
            landmarks = {}
        return super().forward(vertices, faces, landmarks, cameras, transform_matrix,
                               focal_length, is_weak_cam, ret_image)
| |
|
|
class Renderer2(GS_BaseMeshRenderer):
    """Head-template mesh visualizer built on ``GS_BaseMeshRenderer``.

    It loads the same template assets as the ``BaseMeshRenderer``-based
    renderer: ``head_template.obj`` and ``FLAME_masks/FLAME_masks.pkl`` from
    ``assets_dir``. It registers the template topology, a flat gray per-face
    color table and the template UV data as module buffers.
    """

    def __init__(self, assets_dir, image_size=512, device='cuda', focal_length=24):
        """Load the template assets and register them as buffers.

        Args:
            assets_dir: Directory containing ``head_template.obj`` and
                ``FLAME_masks/FLAME_masks.pkl``.
            image_size: Output image size, forwarded to ``GS_BaseMeshRenderer``.
            device: Accepted for signature parity but not used here.
                NOTE(review): it is not forwarded to ``GS_BaseMeshRenderer``.
                Confirm whether the base class places itself on a device.
            focal_length: Default focal length. It is forwarded to
                ``GS_BaseMeshRenderer`` and also stored on
                ``self.focal_length``.
        """
        super().__init__(image_size, focal_length=focal_length)
        self.focal_length = focal_length

        obj_filename = osp.join(assets_dir, 'head_template.obj')
        _verts, faces, aux = load_obj(obj_filename)
        # `[None, ...]` adds a leading batch dimension to the UV coordinates,
        # the UV-face indices and the vertex-face indices.
        uvcoords = aux.verts_uvs[None, ...]
        uvfaces = faces.textures_idx[None, ...]
        faces = faces.verts_idx[None, ...]

        # One flat gray (180/255) RGB color per vertex. The vertex count is
        # taken as the largest vertex index referenced by `faces`, plus one.
        num_verts = int(faces.max()) + 1
        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(1, num_verts, 1).float() / 255.

        # FLAME region masks (contents opaque here; presumably region name ->
        # vertex indices, TODO confirm). Use a context manager so the file
        # handle is closed.
        # NOTE: unpickling can execute arbitrary code. This assumes the assets
        # directory is trusted; never point it at untrusted files.
        with open(osp.join(assets_dir, 'FLAME_masks/FLAME_masks.pkl'), 'rb') as mask_file:
            self.flame_masks = pickle.load(mask_file, encoding='latin1')

        self.register_buffer('faces', faces)

        # Per-face colors, gathered from the per-vertex colors by the
        # `face_vertices` helper.
        face_colors = face_vertices(colors, faces)
        self.register_buffer('face_colors', face_colors)

        # Keep the UVs exactly as loaded, before the remapping below.
        self.register_buffer('raw_uvcoords', uvcoords)

        # Append a constant third coordinate of 1 and map UVs from [0, 1] to
        # [-1, 1] (assuming the OBJ stores them in [0, 1]). Then flip the
        # v axis. The `* 2 - 1` creates a new tensor, so `raw_uvcoords` is
        # not modified by the in-place flip.
        uvcoords = torch.cat([uvcoords, torch.ones_like(uvcoords[:, :, 0:1])], -1)
        uvcoords = uvcoords * 2 - 1
        uvcoords[..., 1] = -uvcoords[..., 1]
        face_uvcoords = face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

    def forward(self, vertices, faces=None, landmarks=None, cameras=None, transform_matrix=None,
                focal_length=None, is_weak_cam=False, ret_image=True):
        """Render `vertices` through ``GS_BaseMeshRenderer.forward``.

        Args:
            vertices: Mesh vertices, passed through to the base renderer.
            faces: Face indices. Defaults to the template faces with the
                leading batch dimension removed.
            landmarks: Landmark container passed to the base renderer.
                Defaults to a new empty dict on every call.
            cameras, transform_matrix, focal_length, ret_image: Passed
                through positionally to the base renderer.
            is_weak_cam: Accepted for signature parity but ignored; it is
                not forwarded to the base renderer.

        Returns:
            Whatever ``GS_BaseMeshRenderer.forward`` returns.
        """
        if faces is None:
            faces = self.faces.squeeze(0)
        # A `{}` default would be one dict object shared by every call. Build a
        # fresh one instead, so any mutation downstream cannot leak into later
        # calls.
        if landmarks is None:
            landmarks = {}
        # NOTE(review): `ret_image` is passed as the 7th positional argument and
        # `is_weak_cam` is dropped. This is only correct if
        # GS_BaseMeshRenderer.forward has no `is_weak_cam` slot before
        # `ret_image`. Verify against that signature.
        return super().forward(vertices, faces, landmarks, cameras, transform_matrix,
                               focal_length, ret_image)