| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import os |
| | from pytorch3d.structures import Meshes |
| | from pytorch3d.renderer import ( |
| | look_at_view_transform, |
| | PerspectiveCameras, |
| | FoVPerspectiveCameras, |
| | PointLights, |
| | DirectionalLights, |
| | Materials, |
| | RasterizationSettings, |
| | MeshRenderer, |
| | MeshRasterizer, |
| | SoftPhongShader, |
| | TexturesUV, |
| | TexturesVertex, |
| | blending, |
| | ) |
| |
|
| | from pytorch3d.ops import interpolate_face_attributes |
| |
|
| | from pytorch3d.renderer.blending import ( |
| | BlendParams, |
| | hard_rgb_blend, |
| | sigmoid_alpha_blend, |
| | softmax_rgb_blend, |
| | ) |
| |
|
| |
|
class SoftSimpleShader(nn.Module):
    """
    Unlit shader that soft-blends sampled texture colors.

    No lighting model is evaluated. The per-pixel colors are the texels
    returned by ``meshes.sample_textures(fragments)``, aggregated over all
    faces per pixel with ``softmax_rgb_blend``. In this file it is used by
    ``Render_3DMM``, whose vertex colors are already shaded by
    ``Illumination_layer`` before rendering.

    ``lights`` and ``materials`` are stored (and moved by ``to``) only so the
    shader has the same interface as the other PyTorch3D shaders; ``forward``
    does not read them.

    Args:
        device: Device used to build the default ``PointLights`` and
            ``Materials`` when none are supplied.
        cameras: Cameras whose ``znear``/``zfar`` are passed to the blend.
            May be ``None`` here if they are given to ``forward`` instead.
        lights: Optional lights object (not used by ``forward``).
        materials: Optional materials object (not used by ``forward``).
        blend_params: Optional ``BlendParams``; defaults to ``BlendParams()``.
    """

    def __init__(
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        """Move the non-``nn.Module`` members (cameras, materials, lights) to ``device``.

        ``cameras`` may legitimately be ``None`` when they are supplied per call
        to ``forward``; in that case it is left as ``None`` instead of raising.

        Returns:
            ``self``, so calls can be chained.
        """
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """Shade rasterized fragments by sampling textures and soft blending.

        Args:
            fragments: Rasterizer output for ``meshes``.
            meshes: Meshes whose textures are sampled at ``fragments``.
            **kwargs: Optional per-call overrides: ``cameras``,
                ``blend_params``, ``znear`` and ``zfar``.

        Returns:
            The image tensor produced by ``softmax_rgb_blend``.

        Raises:
            ValueError: If no cameras were given at construction or in ``kwargs``.
        """
        # Validate first so a missing camera fails before any texture sampling.
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of SoftSimpleShader"
            )
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        # Explicit kwargs win; otherwise use the camera's clipping planes, or
        # fixed defaults for camera types that do not define them.
        znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
        return softmax_rgb_blend(
            texels, fragments, blend_params, znear=znear, zfar=zfar
        )
| |
|
| |
|
class Render_3DMM(nn.Module):
    """
    Differentiable renderer for batches of 3DMM face meshes.

    Per-vertex colors are shaded with spherical-harmonics lighting
    (``Illumination_layer``, 9 coefficients per color channel). The shaded
    mesh is then rasterized with PyTorch3D using the unlit, soft-blending
    ``SoftSimpleShader``.

    The mesh topology is loaded from ``3DMM/topology_info.npy`` next to this
    file:
      - ``tris`` holds triangle vertex indices.
      - ``vert_tris`` holds, for each vertex, indices of triangles whose
        normals are summed into that vertex's normal (presumably its incident
        faces; confirm against the asset).

    Args:
        focal: Focal length in pixels, used to derive the camera field of view.
        img_h: Output image height in pixels.
        img_w: Output image width in pixels.
        batch_size: Number of cameras created for the renderer.
        device: Device for the renderer and the topology tensors.
    """

    def __init__(
        self,
        focal=1015,
        img_h=500,
        img_w=500,
        batch_size=1,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.focal = focal
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.renderer = self.get_render(batch_size)

        # The .npy file holds a pickled dict (hence allow_pickle + .item()).
        # It is a package-local asset, not user-supplied input.
        dir_path = os.path.dirname(os.path.realpath(__file__))
        topo_info = np.load(
            os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
        ).item()
        self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
        self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)

    def compute_normal(self, geometry):
        """Return unit per-vertex normals for a batch of meshes.

        Args:
            geometry: Vertex positions indexed as ``(batch, vertex, xyz)``.

        Returns:
            Tensor ``(batch, len(vert_tris), 3)``. Each row is the normalized
            sum of the unit normals of the triangles listed in ``vert_tris``
            for that vertex. A vertex whose summed normal is zero gets a zero
            vector rather than NaN.
        """
        v1 = geometry.index_select(1, self.tris[:, 0])
        v2 = geometry.index_select(1, self.tris[:, 1])
        v3 = geometry.index_select(1, self.tris[:, 2])
        face_normals = nn.functional.normalize(
            torch.cross(v2 - v1, v3 - v1, dim=2), dim=2
        )
        vertex_normals = face_normals[:, self.vert_tris, :].sum(2)
        # normalize() clamps the norm at eps, so a zero summed normal yields
        # zeros instead of 0/0 = NaN; otherwise it equals v / ||v||.
        return nn.functional.normalize(vertex_normals, dim=2)

    def get_render(self, batch_size=1):
        """Build the PyTorch3D ``MeshRenderer`` used by ``forward``.

        The camera is a ``FoVPerspectiveCameras`` with:
          - rotation from ``look_at_view_transform(10, 0, 0)``;
          - zero translation (the look-at translation is discarded);
          - ``znear=0.01`` and ``zfar=20``.

        Args:
            batch_size: Number of identical cameras to create.

        Returns:
            The renderer, moved to ``self.device``.
        """
        R, _ = look_at_view_transform(10, 0, 0)
        R = R.repeat(batch_size, 1, 1)
        T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)

        # Angle (degrees) subtended by an image img_w pixels wide at a focal
        # length of `focal` pixels. Uses the exact float half-width so odd
        # widths are not truncated.
        half_width = self.img_w * 0.5
        fov = 2 * np.arctan(half_width / self.focal) * 180.0 / np.pi

        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            znear=0.01,
            zfar=20,
            fov=fov,
        )
        # Ambient-only light. SoftSimpleShader does not evaluate lighting, so
        # this is carried for the shader interface only.
        lights = PointLights(
            device=self.device,
            location=[[0.0, 0.0, 1e5]],
            ambient_color=[[1, 1, 1]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[0.0, 0.0, 0.0]],
        )
        sigma = 1e-4
        # NOTE(review): 1e-4 and the /18 factor are tuning constants whose
        # derivation is not documented here.
        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
            faces_per_pixel=2,
            perspective_correct=False,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
            shader=SoftSimpleShader(
                lights=lights, blend_params=blend_params, cameras=cameras
            ),
        )
        return renderer.to(self.device)

    @staticmethod
    def Illumination_layer(face_texture, norm, gamma):
        """Shade per-vertex colors with spherical-harmonics lighting.

        Args:
            face_texture: Per-vertex base colors, ``(batch, num_vertex, 3)``.
            norm: Vertex normals; reshaped to ``(batch * num_vertex, 3)``.
            gamma: SH coefficients, reshaped to ``(batch, 3, 9)``, i.e. 9 per
                color channel. 0.8 is added to each channel's first (constant)
                coefficient, so an all-zero ``gamma`` gives uniform light.
                The caller's tensor is not modified.

        Returns:
            ``face_texture * lighting`` with ``lighting`` of shape
            ``(batch, num_vertex, 3)``.
        """
        n_b, num_vertex, _ = face_texture.size()
        # clone() so the in-place offset does not write through the view into
        # the caller's tensor.
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8
        gamma = gamma.permute(0, 2, 1)  # (batch, 9, 3)

        # Per-band scale factors (a*) and basis normalization constants
        # (c*, d0), presumably from the standard irradiance SH formulation.
        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        norm = norm.reshape(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
        # Constant basis term built from nx so it shares the normals'
        # dtype and device.
        Y0 = torch.ones_like(nx) * a0 * c0

        H = torch.stack(
            [
                Y0,
                -a1 * c1 * ny,
                a1 * c1 * nz,
                -a1 * c1 * nx,
                a2 * c2 * nx * ny,
                -a2 * c2 * ny * nz,
                a2 * c2 * d0 * (3 * nz.pow(2) - 1),
                -a2 * c2 * nx * nz,
                a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)),
            ],
            1,
        )
        Y = H.view(n_b, num_vertex, 9)
        lighting = Y.bmm(gamma)  # (batch, num_vertex, 3)
        return face_texture * lighting

    def forward(self, rott_geometry, texture, diffuse_sh):
        """Render SH-lit, vertex-colored meshes.

        Args:
            rott_geometry: Vertex positions ``(batch, num_vertex, 3)``.
                Presumably already posed, as the name suggests; TODO confirm.
            texture: Per-vertex base colors ``(batch, num_vertex, 3)``.
            diffuse_sh: SH lighting coefficients, 27 per batch element.

        Returns:
            Renderer output clamped to ``[0, 255]``.
        """
        vertex_normals = self.compute_normal(rott_geometry)
        vertex_colors = self.Illumination_layer(texture, vertex_normals, diffuse_sh)
        # Keep face indices integral. PyTorch3D Meshes works with int64
        # faces, and a float32 round trip would lose precision above 2**24.
        faces = self.tris.long().repeat(rott_geometry.shape[0], 1, 1)
        mesh = Meshes(rott_geometry, faces, TexturesVertex(vertex_colors))
        rendered_img = self.renderer(mesh)
        return torch.clamp(rendered_img, 0, 255)
| |
|