# part of the code from
# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py
import torch
import torch.nn.functional as F
from scipy.io import loadmat
import numpy as np
# import config
import pytorch3d
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    PerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures,
    DirectionalLights
)
from pytorch3d.renderer import TexturesVertex, SoftPhongShader
from pytorch3d.io import load_objs_as_meshes

MESH_COLOR_0 = [0, 172, 223]
MESH_COLOR_1 = [172, 223, 0]
'''
Explanation of the shift between projection results from opendr and pytorch3d:
    (0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr).
    Imagine you have an image of size 4:
        the middle of the first pixel is 0,
        the middle of the last pixel is 3,
        => the middle of the image would be 1.5 and not 2!
    So in order to go from pytorch3d predictions to opendr we would calculate:
        p_odr = p_p3d * (128/127.5)
To reproject points (p3d) by hand according to this pytorch3d renderer, we would
do the following steps:
    1.) build the camera matrix
        K = np.array([[flength, 0, c_x],
                      [0, flength, c_y],
                      [0, 0, 1]], float)
    2.) we don't need to add extrinsics, as the mesh comes with its translation
        (which is added within smal_pytorch); all 3d points are already in the
        camera coordinate system
        -> projection reduces to p2d_proj = K*p3d
    3.) convert to pytorch3d conventions (0 in the middle of the first pixel)
        p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.)
renderer.py - project_points_p3d: shows an example of what is described above,
but with the same focal length for the whole batch
'''
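
# A minimal sketch (not part of the original file) of the manual projection
# described in the docstring above. The helper name and the per-sample
# focal_length argument are assumptions for illustration; the body mirrors
# steps 1.)-3.) for a batch of camera-space points and square images.
def project_points_p3d_sketch(image_size, focal_length, points):
    # points: (bs, n_points, 3), already in camera coordinates
    # focal_length: (bs, 1); image_size: scalar side length of the square image
    bs = points.shape[0]
    c = image_size / 2.
    # 1.) build the camera matrix K per batch element
    K = torch.zeros((bs, 3, 3), device=points.device, dtype=points.dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 0]
    K[:, 0, 2] = c
    K[:, 1, 2] = c
    K[:, 2, 2] = 1.
    # 2.) no extrinsics needed: p2d_proj = K * p3d, then perspective division
    p_hom = torch.einsum('bij,bnj->bni', K, points)
    p2d_proj = p_hom[:, :, :2] / p_hom[:, :, 2:3]
    # 3.) convert to pytorch3d conventions (0 in the middle of the first pixel)
    return p2d_proj / image_size * (image_size - 1.)
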
class SilhRenderer(torch.nn.Module):
    def __init__(self, image_size, adapt_R_wldo=False):
        super(SilhRenderer, self).__init__()
        # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315
        # adapt_R=True is True for all my experiments
        # image_size: one number, integer
        # -----
        # set mesh color
        self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0))
        self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1))
        # prepare extrinsics, which in our case don't change
        R = torch.Tensor(np.eye(3)).float()[None, :, :]
        T = torch.Tensor(np.zeros((1, 3))).float()
        if adapt_R_wldo:
            R[0, 0, 0] = -1
        else:  # used for all my own experiments
            R[0, 0, 0] = -1
            R[0, 1, 1] = -1
        self.register_buffer('R', R)
        self.register_buffer('T', T)
        # prepare the part of the intrinsics which does not change either
        # principal_point_prep = torch.Tensor([self.image_size / 2., self.image_size / 2.]).float()[None, :].float().to(device)
        # image_size_prep = torch.Tensor([self.image_size, self.image_size]).float()[None, :].float().to(device)
        self.img_size_scalar = image_size
        self.register_buffer('image_size', torch.Tensor([image_size, image_size]).float()[None, :])
        self.register_buffer('principal_point', torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :])
        # Rasterization settings for differentiable rendering, where the blur_radius
        # initialization is based on Liu et al., 'Soft Rasterizer: A Differentiable
        # Renderer for Image-based 3D Reasoning', ICCV 2019
        self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
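        # note: blur_radius = log(1/delta - 1) * sigma with delta = 1e-4, i.e. faces are
        # rasterized out to the distance at which their sigmoid contribution has
        # decayed to delta (cf. the Soft Rasterizer formulation referenced above)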
        self.raster_settings_soft = RasterizationSettings(
            image_size=image_size,  # 128
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params.sigma,
            faces_per_pixel=100)  # 50
        # rasterization settings for the differentiable body part segmentation renderer
        self.blend_params_parts = BlendParams(sigma=2*1e-4, gamma=1e-4)
        self.raster_settings_soft_parts = RasterizationSettings(
            image_size=image_size,  # 128
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params_parts.sigma,
            faces_per_pixel=60)  # 50
        # settings for the visualization renderer
        self.raster_settings_vis = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1)

    def _get_cam(self, focal_lengths):
        device = focal_lengths.device
        bs = focal_lengths.shape[0]
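        # the intrinsics here are given in screen space (pixels); newer pytorch3d
        # versions make this explicit via in_ndc=False, which is why the camera
        # construction below depends on the installed version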
        if pytorch3d.__version__ == '0.2.5':
            cameras = PerspectiveCameras(device=device,
                focal_length=focal_lengths.repeat((1, 2)),
                principal_point=self.principal_point.repeat((bs, 1)),
                R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
                image_size=self.image_size.repeat((bs, 1)))
        elif pytorch3d.__version__ == '0.6.1':
            cameras = PerspectiveCameras(device=device, in_ndc=False,
                focal_length=focal_lengths.repeat((1, 2)),
                principal_point=self.principal_point.repeat((bs, 1)),
                R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
                image_size=self.image_size.repeat((bs, 1)))
        else:
            raise ValueError('unsupported pytorch3d version ' + pytorch3d.__version__ +
                             '; this part depends on the version, code was developed with 0.2.5')
        return cameras

    def _get_visualization_from_mesh(self, mesh, cameras, lights=None):
        # color renderer for visualization
        with torch.no_grad():
            device = mesh.device
            # renderer for visualization
            if lights is None:
                lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
            vis_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    cameras=cameras,
                    raster_settings=self.raster_settings_vis),
                shader=HardPhongShader(
                    device=device,
                    cameras=cameras,
                    lights=lights))
            # render image:
            visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :]
        return visualization

    def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False):
        tex = torch.ones_like(vertices) * self.mesh_color_0  # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        mesh = Meshes(verts=vertices, faces=faces, textures=textures)
        cameras = self._get_cam(focal_lengths)
        # NEW: use the rasterizer to check vertex visibility
        # see: https://github.com/facebookresearch/pytorch3d/issues/126
        # Get a rasterizer
        if soft:
            rasterizer = MeshRasterizer(cameras=cameras,
                raster_settings=self.raster_settings_soft)
        else:
            rasterizer = MeshRasterizer(cameras=cameras,
                raster_settings=self.raster_settings_vis)
        # Get the output from rasterization
        fragments = rasterizer(mesh)
        # pix_to_face is of shape (N, H, W, 1)
        pix_to_face = fragments.pix_to_face
        # (F, 3) where F is the total number of faces across all the meshes in the batch
        packed_faces = mesh.faces_packed()
        # (V, 3) where V is the total number of verts across all the meshes in the batch
        packed_verts = mesh.verts_packed()
        vertex_visibility_map = torch.zeros(packed_verts.shape[0], device=vertices.device)  # (V,)
        # Indices of unique visible faces; drop the -1 entries which mark background pixels
        visible_faces = pix_to_face.unique()  # (num_visible_faces,)
        visible_faces = visible_faces[visible_faces >= 0]
        # Get indices of unique visible verts using the vertex indices in the faces
        visible_verts_idx = packed_faces[visible_faces]  # (num_visible_faces, 3)
        unique_visible_verts_idx = torch.unique(visible_verts_idx)  # (num_visible_verts,)
        # Update visibility indicator to 1 for all visible vertices
        vertex_visibility_map[unique_visible_verts_idx] = 1.0
        # since all meshes have the same number of vertices, we can reshape the result
        bs = vertices.shape[0]
        vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1))
        return pix_to_face, vertex_visibility_map_resh
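
    # Usage sketch (illustrative, variable names assumed):
    #     pix2face, vis_map = self.calculate_vertex_visibility(verts, faces, focal_lengths)
    # vis_map[b, v] is 1.0 iff vertex v of sample b lies on at least one rasterized face.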
    def get_torch_meshes(self, vertices, faces, color=0):
        # create pytorch3d mesh
        if color == 0:
            mesh_color = self.mesh_color_0
        else:
            mesh_color = self.mesh_color_1
        tex = torch.ones_like(vertices) * mesh_color  # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        mesh = Meshes(verts=vertices, faces=faces, textures=textures)
        return mesh

    def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0):
        # vertices: torch.Size([bs, 3889, 3])
        # faces: torch.Size([bs, 7774, 3]), int
        # focal_lengths: torch.Size([bs, 1])
        device = vertices.device
        # create cameras
        cameras = self._get_cam(focal_lengths)
        # create pytorch3d mesh
        if color == 0:
            mesh_color = self.mesh_color_0  # blue
        elif color == 1:
            mesh_color = self.mesh_color_1
        elif color == 2:
            MESH_COLOR_2 = [240, 250, 240]  # white
            mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
        elif color == 3:
            # MESH_COLOR_3 = [223, 0, 172]  # pink
            # MESH_COLOR_3 = [245, 245, 220]  # beige
            MESH_COLOR_3 = [166, 173, 164]
            mesh_color = torch.FloatTensor(MESH_COLOR_3).to(device)
        else:
            MESH_COLOR_2 = [240, 250, 240]  # white (default)
            mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
        tex = torch.ones_like(vertices) * mesh_color  # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        mesh = Meshes(verts=vertices, faces=faces, textures=textures)
        # render mesh (no gradients)
        # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
        # lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]])
        lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]])
        visualization = self._get_visualization_from_mesh(mesh, cameras, lights=lights)
        return visualization

    def project_points(self, points, focal_lengths=None, cameras=None):
        # points: torch.Size([bs, n_points, 3])
        # either focal_lengths or cameras is needed:
        #   focal_lengths: torch.Size([bs, 1])
        #   cameras: pytorch3d camera, for example PerspectiveCameras()
        bs = points.shape[0]
        device = points.device
        screen_size = self.image_size.repeat((bs, 1))
        if cameras is None:
            cameras = self._get_cam(focal_lengths)
        if pytorch3d.__version__ == '0.2.5':
            # used in the original virtual environment (for the cvpr BARC submission)
            proj_points_orig = cameras.transform_points_screen(points, screen_size)[:, :, [1, 0]]
        elif pytorch3d.__version__ == '0.6.1':
            proj_points_orig = cameras.transform_points_screen(points)[:, :, [1, 0]]
        else:
            raise ValueError('unsupported pytorch3d version ' + pytorch3d.__version__ +
                             '; this part depends on the version, code was developed with 0.2.5')
        # flip, otherwise the 1st and 2nd row are exchanged compared to the ground truth
        proj_points = torch.flip(proj_points_orig, [2])
        # --- project points 'manually'
        # j_proj = project_points_p3d(image_size, focal_length, points, device)
        return proj_points
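
    # Example (hypothetical numbers, following the module docstring): with
    # image_size=128, a point on the optical axis lands at pixel (63.5, 63.5),
    # the middle of the image under the pytorch3d convention, rather than at 64.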
    def forward(self, vertices, points, faces, focal_lengths, color=None):
        # vertices: torch.Size([bs, 3889, 3])
        # points: torch.Size([bs, n_points, 3]) (or None)
        # faces: torch.Size([bs, 7774, 3]), int
        # focal_lengths: torch.Size([bs, 1])
        # color: if None we don't render a visualization, else it should
        #   be either 0 or 1
        # ---> important: results are around 0.5 pixels off compared to chumpy!
        #   see the module docstring at the top of this file for an explanation
        # create cameras
        cameras = self._get_cam(focal_lengths)
        # create pytorch3d mesh
        if color is None or color == 0:
            mesh_color = self.mesh_color_0
        else:
            mesh_color = self.mesh_color_1
        tex = torch.ones_like(vertices) * mesh_color  # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        mesh = Meshes(verts=vertices, faces=faces, textures=textures)
        # silhouette renderer
        renderer_silh = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=self.raster_settings_soft),
            shader=SoftSilhouetteShader(blend_params=self.blend_params))
        # render silhouette
        silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1)
        # project points
        if points is None:
            proj_points = None
        else:
            proj_points = self.project_points(points=points, cameras=cameras)
        if color is not None:
            # color renderer for visualization (no gradients)
            visualization = self._get_visualization_from_mesh(mesh, cameras)
            return silh_images, proj_points, visualization
        else:
            return silh_images, proj_points
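
# A minimal usage sketch (not part of the original file): renders a dummy
# one-triangle mesh instead of a real SMAL mesh; the shapes follow the
# comments in forward() above.
if __name__ == '__main__':
    renderer = SilhRenderer(image_size=128)
    verts = torch.tensor([[[-0.5, -0.5, 5.], [0.5, -0.5, 5.], [0., 0.5, 5.]]])  # (1, 3, 3)
    faces = torch.tensor([[[0, 1, 2]]])  # (1, 1, 3), int
    focal = torch.tensor([[500.]])       # (1, 1)
    silh, proj, vis = renderer(verts, verts, faces, focal, color=0)
    print(silh.shape, proj.shape, vis.shape)  # (1,1,128,128), (1,3,2), (1,3,128,128)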