Spaces:
Runtime error
Runtime error
import os

import numpy as np
import torch
import torch.nn as nn

import geo_transform
import render_util
def compute_tri_normal(geometry, tris):
    """Return one unit normal per triangle.

    Args:
        geometry: vertex positions, indexed as ``(batch, point, xyz)`` (the
            function permutes it to ``(batch, xyz, point)`` internally).
        tris: integer index tensor of shape ``(num_tris, 3)``; each row lists
            the three vertex indices of a triangle.

    Returns:
        Tensor of shape ``(batch, num_tris, 3)`` holding the normalised cross
        product ``(v1 - v0) x (v2 - v0)`` for every triangle.
    """
    coords = geometry.permute(0, 2, 1)
    corner_a, corner_b, corner_c = (
        torch.index_select(coords, 2, tris[:, k]) for k in range(3)
    )
    edge_ab = corner_b - corner_a
    edge_ac = corner_c - corner_a
    # Coordinates live on dim 1 after the permute, so cross/normalise there.
    face_normal = torch.cross(edge_ab, edge_ac, dim=1)
    return nn.functional.normalize(face_normal, dim=1).permute(0, 2, 1)
class Compute_normal_base(torch.autograd.Function):
    """Autograd wrapper around ``render_util.normal_base_forward/backward``.

    Maps a tensor of normals to a per-normal basis tensor. The caller in
    ``Render_Land`` batch-multiplies the result with ``light.reshape(-1, 9, 3)``,
    so the output presumably has 9 basis values per normal (e.g. a
    spherical-harmonics basis) -- NOTE(review): confirm against render_util.

    ``forward``/``backward`` are static methods, as required by the
    ``torch.autograd.Function`` API; use ``Compute_normal_base.apply(normal)``.
    """

    @staticmethod
    def forward(ctx, normal):
        """Evaluate the basis for ``normal`` and keep ``normal`` for backward."""
        (normal_b,) = render_util.normal_base_forward(normal)
        ctx.save_for_backward(normal)
        return normal_b

    @staticmethod
    def backward(ctx, grad_normal_b):
        """Propagate the basis gradient back to the input normals."""
        (normal,) = ctx.saved_tensors
        (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
        return grad_normal
class Normal_Base(torch.nn.Module):
    """Parameter-free ``nn.Module`` front end for ``Compute_normal_base``.

    Holds no state of its own; calling the module simply routes the input
    normals through the custom autograd function.
    """

    def forward(self, normal):
        """Return ``Compute_normal_base.apply(normal)``."""
        basis = Compute_normal_base.apply(normal)
        return basis
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
    """Pose and project a batch of meshes and derive per-vertex visibility.

    Args:
        geometry: vertex positions; ``geometry.shape[1]`` is taken as the
            number of points per mesh (presumably ``(batch, points, 3)`` --
            confirm).
        euler, trans: rotation / translation forwarded to
            ``geo_transform.euler_trans_geo``.
        cam: camera parameters forwarded to ``geo_transform.proj_geo``.
        tris: ``(num_tris, 3)`` triangle index tensor.
        vert_tris: for each vertex, the index of one triangle whose normal is
            used as that vertex's normal.
        ori_img: image tensor; only ``shape[0]``, ``shape[1] * shape[2]`` and
            its device are used (to size ``pixel_valid``).

    Returns:
        ``(rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid)`` where
        ``is_visible`` is ``(batch, points)`` and ``pixel_valid`` is a float32
        zero tensor of shape ``(ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2])``.
    """
    n_points = geometry.shape[1]
    posed = geo_transform.euler_trans_geo(geometry, euler, trans)
    projected = geo_transform.proj_geo(posed, cam)
    face_normals = compute_tri_normal(posed, tris)
    # NOTE(review): `tris` is shifted by -1 at load time in Render_Land but
    # `vert_tris` is not; presumably vert_tris.txt is already 0-based -- confirm.
    vertex_normals = torch.index_select(face_normals, 1, vert_tris)

    # Negated cosine between each vertex normal and the unit vector from the
    # origin to the vertex (presumably the camera sits at the origin).
    view_dirs = nn.functional.normalize(posed.reshape(-1, 3, 1))
    cos_facing = -torch.bmm(vertex_normals.reshape(-1, 1, 3), view_dirs)
    cos_facing = cos_facing.reshape(-1, n_points)
    # Values below 0.01 (back-facing or grazing) are flagged with -1;
    # all other entries keep their cosine value.
    visibility = cos_facing.masked_fill(cos_facing < 0.01, -1)

    batch, height, width = ori_img.shape[0], ori_img.shape[1], ori_img.shape[2]
    pixel_valid = torch.zeros(
        (batch, height * width), dtype=torch.float32, device=ori_img.device
    )
    return posed, projected, face_normals, visibility, pixel_valid
class Render_Face(torch.autograd.Function):
    """Differentiable face rasteriser backed by ``render_util``.

    Call as ``Render_Face.apply(proj_geo, texture, nbl, ori_img, is_visible,
    tri_inds, pixel_valid)``; returns ``(render, real)``. Gradients are
    produced for ``proj_geo``, ``texture`` and ``nbl`` only; the remaining four
    inputs are treated as constants (``None`` gradients).

    ``forward``/``backward`` are static methods, as required by the
    ``torch.autograd.Function`` API.
    """

    @staticmethod
    def forward(
        ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        """Rasterise the mesh with ``render_util.render_face_forward``.

        ``ori_img`` is unpacked as ``(batch, h, w, _)`` and flattened to
        ``(batch, h * w, 3)`` before being handed to the kernel together with
        an int32 ``[h, w]`` size pair per batch item. ``pixel_valid`` is
        allocated as zeros by ``preprocess_render`` and later used as a mask;
        presumably the kernel fills it in place -- NOTE(review): confirm.
        """
        batch_size, h, w, _ = ori_img.shape
        flat_img = ori_img.view(batch_size, -1, 3)
        # One (h, w) pair per batch item, flattened: [h, w, h, w, ...].
        ori_size = torch.tensor(
            [h, w], dtype=torch.int32, device=flat_img.device
        ).repeat(batch_size)
        tri_index, tri_coord, render, real = render_util.render_face_forward(
            proj_geo, flat_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
        )
        # Everything the backward kernel needs, including the per-pixel
        # triangle ids / coordinates produced by the forward pass.
        ctx.save_for_backward(
            flat_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
        )
        return render, real

    @staticmethod
    def backward(ctx, grad_render, grad_real):
        """Return gradients for the seven forward inputs (last four are None)."""
        (
            ori_img,
            ori_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        ) = ctx.saved_tensors
        grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
            grad_render,
            grad_real,
            ori_img,
            ori_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        )
        return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
class Render_RGB(nn.Module):
    """Stateless ``nn.Module`` front end for ``Render_Face``.

    Forwards its seven inputs unchanged to ``Render_Face.apply`` and returns
    whatever that returns (the ``(render, real)`` pair).
    """

    def forward(
        self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        face_inputs = (
            proj_geo,
            texture,
            nbl,
            ori_img,
            is_visible,
            tri_inds,
            pixel_valid,
        )
        return Render_Face.apply(*face_inputs)
def cal_land(proj_geo, is_visible, lands_info, land_num):
    """Gather the 2-D projected positions of ``land_num`` landmarks per item.

    ``render_util.update_contour`` picks the landmark vertex indices (using
    ``lands_info`` and ``is_visible``). Those indices address rows of
    ``proj_geo`` flattened to ``(-1, 3)``, i.e. they span the whole batch.

    Returns:
        Tensor of shape ``(batch, land_num, 2)`` with the first two
        coordinates of each selected projected vertex.
    """
    (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
    flat_points = proj_geo.reshape(-1, 3)
    picked = flat_points.index_select(0, land_index)
    xy = picked[:, :2]
    return xy.reshape(-1, land_num, 2)
class Render_Land(nn.Module):
    """Render a 3DMM mesh and compute photometric / landmark fitting losses.

    Mesh topology tables are loaded once from ``data_dir`` and placed on
    ``device``. NOTE: they are stored as plain tensor attributes (not
    registered buffers), so ``module.to(...)`` will not move them.
    """

    def __init__(self, data_dir="../data/3DMM", device="cuda"):
        """Load the topology tables and build the helper modules.

        Args:
            data_dir: directory containing ``lands_info.txt``, ``tris.txt`` and
                ``vert_tris.txt``. The default is the original hard-coded
                path, still resolved against the current working directory.
            device: device for the tables and sub-modules. The default
                ``"cuda"`` matches the previous ``.cuda()`` calls.
        """
        super(Render_Land, self).__init__()
        lands_info = np.loadtxt(
            os.path.join(data_dir, "lands_info.txt"), dtype=np.int32
        )
        self.lands_info = torch.as_tensor(lands_info).to(device)
        tris = np.loadtxt(os.path.join(data_dir, "tris.txt"), dtype=np.int64)
        # Shift triangle indices down by one (presumably 1-based in the file).
        self.tris = torch.as_tensor(tris).to(device) - 1
        vert_tris = np.loadtxt(
            os.path.join(data_dir, "vert_tris.txt"), dtype=np.int64
        )
        # NOTE(review): no -1 shift here, unlike `tris` -- confirm the file is 0-based.
        self.vert_tris = torch.as_tensor(vert_tris).to(device)
        self.normal_baser = Normal_Base().to(device)
        self.renderer = Render_RGB().to(device)

    @staticmethod
    def _image_size_tensor(batch_size, h, w, device):
        """Return int32 ``[h, w, h, w, ...]`` -- one ``(h, w)`` pair per item."""
        return torch.tensor([h, w], dtype=torch.int32, device=device).repeat(
            batch_size
        )

    def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
        """Render the posed mesh with a flat grey texture.

        Uses a constant texture value of 200 and only the first colour column
        of ``light.reshape(-1, 9, 3)``, replicated to all three channels.

        Returns:
            uint8 tensor of shape ``(batch, h, w, 3)``, where ``ori_img`` is
            unpacked as ``(batch, h, w, _)``.
        """
        batch_size, h, w, _ = ori_img.shape
        # preprocess_render sizes its pixel mask from the 4-D image shape, so
        # pass the unflattened image (the mask is not used here).
        _, proj_geo, rot_tri_normal, _, _ = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        flat_img = ori_img.view(batch_size, -1, 3)
        ori_size = self._image_size_tensor(batch_size, h, w, ori_img.device)

        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        grey_light = light.reshape(-1, 9, 3)[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
        nbl = torch.bmm(tri_nb, grey_light)
        texture = torch.ones_like(geometry) * 200
        (render,) = render_util.render_mesh(
            proj_geo, flat_img, ori_size, texture, nbl, self.tris
        )
        # Clamp before the uint8 cast: out-of-range floats would otherwise
        # wrap around (float->integer overflow is undefined).
        return render.view(batch_size, h, w, 3).clamp(0, 255).byte()

    def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
        """Compute the photometric and landmark losses.

        Returns:
            ``(col_dis, lan_dis)``:
            ``col_dis`` is the mean per-pixel L2 colour distance between the
            rendered and the real image, weighted by ``pixel_valid`` and
            normalised by ``mean(pixel_valid) + 1e-5``;
            ``lan_dis`` is the mean 2-D L2 distance between the projected
            landmarks and ``lands`` (``lands.shape[1]`` landmarks per item).
        """
        _, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
        render, real = self.renderer(
            proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
        )
        proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])

        batch_size = ori_img.shape[0]
        col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
            batch_size, -1
        )
        col_dis = torch.mean(col_minus * pixel_valid) / (
            torch.mean(pixel_valid) + 0.00001
        )
        land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1)
        lan_dis = torch.mean(land_dists)
        return col_dis, lan_dis