# PEAR: models/modules/renderer/head_renderer.py
# Source: uploaded by BestWJH ("Upload 205 files", commit 2c68f56, verified).
import torch
import pickle
import os.path as osp
from pytorch3d.io import load_obj
from .base_renderer import BaseMeshRenderer
from ...utils.helper import face_vertices
from ...utils.graphics import GS_BaseMeshRenderer
class Renderer(BaseMeshRenderer):
    """Head-template mesh visualizer built on ``BaseMeshRenderer``.

    On construction it loads ``head_template.obj`` and
    ``FLAME_masks/FLAME_masks.pkl`` from ``assets_dir``. It registers these
    as module buffers:

    - the template topology;
    - a flat grey per-face colour table;
    - the raw and remapped UV coordinates.
    """

    def __init__(self, assets_dir, image_size=512, device='cuda', focal_length=12):
        """Load the template assets and register them as buffers.

        Args:
            assets_dir: directory containing ``head_template.obj`` and
                ``FLAME_masks/FLAME_masks.pkl``.
            image_size: forwarded to ``BaseMeshRenderer.__init__``.
            device: forwarded to ``BaseMeshRenderer.__init__``.
            focal_length: forwarded to ``BaseMeshRenderer.__init__`` and also
                stored on ``self.focal_length``.
        """
        super().__init__(assets_dir, image_size, device, focal_length=focal_length)
        obj_filename = osp.join(assets_dir, 'head_template.obj')
        self.focal_length = focal_length

        # Only the face indices and the UV data of the template are used;
        # the vertex positions returned by load_obj are discarded.
        _, face_data, aux = load_obj(obj_filename)
        uvcoords = aux.verts_uvs[None, ...]  # (N, V, 2)
        uvfaces = face_data.textures_idx[None, ...]  # (N, F, 3)
        faces = face_data.verts_idx[None, ...]

        # Shape colours for the shape overlay: one flat grey RGB value
        # (180/255) for every vertex index referenced by `faces`.
        num_verts = faces.max() + 1
        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(1, num_verts, 1).float() / 255.

        # A context manager closes the file handle deterministically.
        # NOTE: pickle.load can execute arbitrary code; assets_dir must be trusted.
        with open(osp.join(assets_dir, 'FLAME_masks/FLAME_masks.pkl'), 'rb') as mask_file:
            self.flame_masks = pickle.load(mask_file, encoding='latin1')

        self.register_buffer('faces', faces)
        face_colors = face_vertices(colors, faces)
        self.register_buffer('face_colors', face_colors)
        self.register_buffer('raw_uvcoords', uvcoords)

        # UV coords: append a constant-1 third channel, rescale with x*2 - 1
        # (maps [0, 1] to [-1, 1]; the constant channel stays 1), then
        # negate the second (v) channel.
        uvcoords = torch.cat([uvcoords, torch.ones_like(uvcoords[:, :, 0:1])], -1)  # [bz, ntv, 3]
        uvcoords = uvcoords * 2 - 1
        uvcoords[..., 1] = -uvcoords[..., 1]
        face_uvcoords = face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

    def forward(self, vertices, faces=None, landmarks=None, cameras=None,
                transform_matrix=None, focal_length=None, is_weak_cam=False,
                ret_image=True):
        """Render ``vertices`` through ``BaseMeshRenderer.forward``.

        Args:
            vertices: mesh vertices, passed through unchanged.
            faces: face indices. Defaults to the registered ``faces`` buffer
                with its leading size-1 dimension squeezed out.
            landmarks: passed through unchanged. ``None`` (the default) is
                replaced by a fresh empty dict on every call, so no mutable
                default object is shared between calls.
            cameras, transform_matrix, focal_length, is_weak_cam, ret_image:
                passed through unchanged to the base renderer.

        Returns:
            Whatever ``BaseMeshRenderer.forward`` returns.
        """
        if faces is None:
            faces = self.faces.squeeze(0)
        if landmarks is None:
            landmarks = {}
        return super().forward(vertices, faces, landmarks, cameras, transform_matrix,
                               focal_length, is_weak_cam, ret_image)
class Renderer2(GS_BaseMeshRenderer):
    """Head-template mesh renderer built on ``GS_BaseMeshRenderer``.

    On construction it loads ``head_template.obj`` and
    ``FLAME_masks/FLAME_masks.pkl`` from ``assets_dir``. It registers these
    as module buffers:

    - the template topology;
    - a flat grey per-face colour table;
    - the raw and remapped UV coordinates.
    """

    def __init__(self, assets_dir, image_size=512, device='cuda', focal_length=24):
        """Load the template assets and register them as buffers.

        Args:
            assets_dir: directory containing ``head_template.obj`` and
                ``FLAME_masks/FLAME_masks.pkl``.
            image_size: forwarded to ``GS_BaseMeshRenderer.__init__``.
            device: accepted for signature compatibility but not used here.
                NOTE(review): it is not forwarded to the base class; confirm
                whether GS_BaseMeshRenderer should receive it.
            focal_length: forwarded to ``GS_BaseMeshRenderer.__init__`` and
                also stored on ``self.focal_length``.
        """
        super().__init__(image_size, focal_length=focal_length)
        obj_filename = osp.join(assets_dir, 'head_template.obj')
        self.focal_length = focal_length

        # Only the face indices and the UV data of the template are used;
        # the vertex positions returned by load_obj are discarded.
        _, face_data, aux = load_obj(obj_filename)
        uvcoords = aux.verts_uvs[None, ...]  # (N, V, 2)
        uvfaces = face_data.textures_idx[None, ...]  # (N, F, 3)
        faces = face_data.verts_idx[None, ...]

        # Shape colours for the shape overlay: one flat grey RGB value
        # (180/255) for every vertex index referenced by `faces`.
        num_verts = faces.max() + 1
        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(1, num_verts, 1).float() / 255.

        # A context manager closes the file handle deterministically.
        # NOTE: pickle.load can execute arbitrary code; assets_dir must be trusted.
        with open(osp.join(assets_dir, 'FLAME_masks/FLAME_masks.pkl'), 'rb') as mask_file:
            self.flame_masks = pickle.load(mask_file, encoding='latin1')

        self.register_buffer('faces', faces)
        face_colors = face_vertices(colors, faces)
        self.register_buffer('face_colors', face_colors)
        self.register_buffer('raw_uvcoords', uvcoords)

        # UV coords: append a constant-1 third channel, rescale with x*2 - 1
        # (maps [0, 1] to [-1, 1]; the constant channel stays 1), then
        # negate the second (v) channel.
        uvcoords = torch.cat([uvcoords, torch.ones_like(uvcoords[:, :, 0:1])], -1)  # [bz, ntv, 3]
        uvcoords = uvcoords * 2 - 1
        uvcoords[..., 1] = -uvcoords[..., 1]
        face_uvcoords = face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

    def forward(self, vertices, faces=None, landmarks=None, cameras=None,
                transform_matrix=None, focal_length=None, is_weak_cam=False,
                ret_image=True):
        """Render ``vertices`` through ``GS_BaseMeshRenderer.forward``.

        Args:
            vertices: mesh vertices, passed through unchanged.
            faces: face indices. Defaults to the registered ``faces`` buffer
                with its leading size-1 dimension squeezed out.
            landmarks: passed through unchanged. ``None`` (the default) is
                replaced by a fresh empty dict on every call, so no mutable
                default object is shared between calls.
            cameras, transform_matrix, focal_length, ret_image: passed
                through unchanged to the base renderer.
            is_weak_cam: accepted for signature compatibility but NOT
                forwarded. NOTE(review): ``ret_image`` is passed as the 7th
                positional argument; confirm this matches
                GS_BaseMeshRenderer.forward's parameter order.

        Returns:
            Whatever ``GS_BaseMeshRenderer.forward`` returns.
        """
        if faces is None:
            faces = self.faces.squeeze(0)
        if landmarks is None:
            landmarks = {}
        return super().forward(vertices, faces, landmarks, cameras, transform_matrix,
                               focal_length, ret_image)