| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch import nn |
| import trimesh |
| import math |
| from typing import NewType |
| from pytorch3d.structures import Meshes |
| from pytorch3d.renderer.mesh import rasterize_meshes |
|
|
# Type alias used in the signatures below. typing.NewType only affects static
# type checking; at runtime Tensor(x) simply returns x unchanged.
Tensor = NewType('Tensor', torch.Tensor)
|
|
|
|
def solid_angles(points: Tensor,
                 triangles: Tensor,
                 thresh: float = 1e-8) -> Tensor:
    ''' Compute the signed solid angle subtended by each triangle at each point

    Implements the closed-form expression from:
        A. Van Oosterom and J. Strackee, "The Solid Angle of a Plane
        Triangle", IEEE Transactions on Biomedical Engineering,
        vol. BME-30, no. 2, February 1983.

    With a, b, c the triangle vertices expressed relative to the query point:
        tan(omega / 2) = a . (b x c) /
            (|a||b||c| + (a . b)|c| + (a . c)|b| + (b . c)|a|)

    Parameters
    ----------
    points: BxQx3
        Tensor of input query points
    triangles: BxFx3x3
        Target triangles (three vertices of three coordinates each)
    thresh: float
        Accepted for API compatibility; not used by the computation.

    Returns
    -------
    solid_angles: BxQxF
        Signed solid angle between every query point and every triangle;
        the sign follows the triangle's vertex winding.
    '''
    # BxQxFx3x3: every triangle expressed relative to every query point.
    rel = triangles[:, None] - points[:, :, None, None]

    # BxQxFx3: distance from the query point to each triangle vertex.
    lengths = torch.norm(rel, dim=-1)

    a = rel[:, :, :, 0]
    b = rel[:, :, :, 1]
    c = rel[:, :, :, 2]

    # Scalar triple product a . (b x c), BxQxF.
    triple = (a * torch.cross(b, c, dim=-1)).sum(dim=-1)

    # Pairwise dot products between the relative vertices, each BxQxF.
    ab = (a * b).sum(dim=-1)
    bc = (b * c).sum(dim=-1)
    ac = (a * c).sum(dim=-1)
    # Drop the large BxQxFx3x3 intermediate (and its views) early.
    del rel, a, b, c

    len_a = lengths[:, :, :, 0]
    len_b = lengths[:, :, :, 1]
    len_c = lengths[:, :, :, 2]
    denom = (lengths.prod(dim=-1) + ab * len_c +
             ac * len_b + bc * len_a)
    del ab, bc, ac, lengths, len_a, len_b, len_c

    # atan2 keeps the correct quadrant, so each half-angle is in (-pi, pi].
    half_angle = torch.atan2(triple, denom)
    del triple, denom

    # Hand cached allocator blocks from the intermediates back to the device.
    torch.cuda.empty_cache()

    return 2 * half_angle
|
|
|
|
def winding_numbers(points: Tensor,
                    triangles: Tensor,
                    thresh: float = 1e-8) -> Tensor:
    ''' Compute generalized winding numbers for inside/outside tests

    The winding number of a query point is the sum of the signed solid
    angles of all triangles seen from that point, divided by the solid angle
    of the full sphere (4 * pi). For a closed, consistently outward-oriented
    mesh this is ~1 for interior points and ~0 for exterior points.

    References:
        Robust Inside-Outside Segmentation using Generalized Winding Numbers
        Alec Jacobson, Ladislav Kavan, Olga Sorkine-Hornung

        Fast Winding Numbers for Soups and Clouds, SIGGRAPH 2018
        Gavin Barill, Neil G. Dickson, Ryan Schmidt, David I.W. Levin,
        and Alec Jacobson

    Parameters
    ----------
    points: BxQx3
        Tensor of input query points
    triangles: BxFx3x3
        Target triangles
    thresh: float
        Forwarded unchanged to solid_angles.

    Returns
    -------
    winding_numbers: BxQ
        A tensor containing the generalized winding numbers
    '''
    inv_full_sphere = 1 / (4 * math.pi)
    per_triangle = solid_angles(points, triangles, thresh=thresh)
    return inv_full_sphere * per_triangle.sum(dim=-1)
|
|
|
|
def batch_contains(verts, faces, points):
    ''' Return +1 / -1 inside/outside signs of points for a batch of meshes

    Every mesh is moved to the CPU, wrapped in a ``trimesh.Trimesh`` and
    queried with ``Trimesh.contains``.

    Parameters
    ----------
    verts: tensor
        Batched mesh vertices, indexed per mesh as ``verts[i]``
        (presumably BxVx3 — confirm against callers).
    faces: tensor
        Batched face indices, indexed per mesh as ``faces[i]``.
    points: tensor
        Batched query points; ``points.shape[1]`` is the number of queries
        per mesh (presumably BxNx3 — confirm).

    Returns
    -------
    signs: BxN
        CPU float tensor holding 1.0 where the point lies inside mesh i and
        -1.0 otherwise. Inputs are detached, so no gradient flows through.
    '''
    num_meshes = verts.shape[0]
    num_points = points.shape[1]

    verts_cpu, faces_cpu, points_cpu = (
        t.detach().cpu() for t in (verts, faces, points))

    inside = torch.zeros(num_meshes, num_points)
    for b in range(num_meshes):
        mesh = trimesh.Trimesh(verts_cpu[b], faces_cpu[b])
        # Boolean membership becomes 1.0 / 0.0 in the float buffer.
        inside[b] = torch.as_tensor(mesh.contains(points_cpu[b]))

    # Map {0, 1} -> {-1, +1}.
    return 2.0 * (inside - 0.5)
|
|
|
|
def dict2obj(d):
    ''' Recursively convert a dict into an object with attribute access

    Nested dicts become nested objects. Any non-dict value, including lists
    (even lists that contain dicts), is returned unchanged.

    >>> cfg = dict2obj({'a': 1, 'b': {'c': 2}})
    >>> cfg.b.c
    2
    '''
    if isinstance(d, dict):
        class C(object):
            pass

        obj = C()
        obj.__dict__.update({key: dict2obj(d[key]) for key in d})
        return obj
    return d
|
|
|
|
def face_vertices(vertices, faces):
    """
    Gather, for every face, the values of its three vertices.

    :param vertices: [batch size, number of vertices, D] per-vertex values
        (D is 3 for positions, but any trailing size is gathered)
    :param faces: [batch size, number of faces, 3] integer vertex indices,
        local to each batch element
    :return: [batch size, number of faces, 3, D]
    """
    nv = vertices.shape[1]
    # The batch size is taken from ``faces``; a mismatch with ``vertices``
    # surfaces as a reshape/index error below.
    bs = faces.shape[0]

    # Shift each batch element's local indices into the flattened vertex
    # array. Cast to int64 *before* adding, so int32 faces cannot overflow
    # once bs * nv grows large.
    offsets = torch.arange(bs, dtype=torch.long, device=vertices.device) * nv
    flat_faces = faces.long() + offsets[:, None, None]

    flat_vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
    return flat_vertices[flat_faces]
|
|
|
|
class Pytorch3dRasterizer(nn.Module):
    """ Rasterize meshes with PyTorch3D and interpolate face attributes

    Borrowed from https://github.com/facebookresearch/pytorch3d

    Notice:
        x, y, z are in image space, normalized.
        Only square images can be rendered (a single ``image_size``).
    """

    def __init__(self, image_size=224):
        """
        Store the fixed raster settings used by every forward pass.

        Parameters
        ----------
        image_size: int
            Height and width, in pixels, of the rendered image.
        """
        super().__init__()
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': True,
            # NOTE(review): stored for reference, but forward() does not pass
            # it to rasterize_meshes, so back faces are currently NOT culled.
            'cull_backfaces': True,
        }
        # Attribute-style access (raster_settings.image_size, ...).
        self.raster_settings = dict2obj(raster_settings)

    def forward(self, vertices, faces, attributes=None):
        """
        Rasterize the meshes and interpolate ``attributes`` at every pixel.

        Parameters
        ----------
        vertices: tensor
            Mesh vertices in normalized image space (presumably BxVx3 —
            confirm). x and y are negated before rasterization.
        faces: tensor
            Vertex indices per face; converted with ``.long()``.
        attributes: tensor
            Per-face, per-corner values with layout (B, F, 3, D), e.g. the
            output of ``face_vertices``. Required despite the ``None``
            default, which is kept for signature compatibility.

        Returns
        -------
        tensor of shape N x (D + 1) x H x W
            The first D channels hold the barycentric interpolation of
            ``attributes`` for the face rasterized at each pixel (k = 0),
            zeroed on background pixels; the last channel is the 0/1
            coverage mask.

        Raises
        ------
        TypeError
            If ``attributes`` is None.
        """
        if attributes is None:
            raise TypeError(
                'Pytorch3dRasterizer.forward requires per-face attributes')

        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())
        raster_settings = self.raster_settings
        pix_to_face, _zbuf, bary_coords, _dists = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        vismask = (pix_to_face > -1).float()

        D = attributes.shape[-1]
        # Flatten (batch, face) into one axis so the packed face indices in
        # pix_to_face index straight into it. reshape avoids the extra copy
        # made by clone() and also handles non-viewable strides.
        face_attrs = attributes.reshape(
            attributes.shape[0] * attributes.shape[1], 3, D)

        N, H, W, K, _ = bary_coords.shape
        # -1 marks background pixels; point them at face 0 for the gather
        # and zero their values afterwards.
        mask = pix_to_face == -1
        safe_pix_to_face = pix_to_face.masked_fill(mask, 0)
        idx = safe_pix_to_face.view(N * H * W * K, 1, 1).expand(
            N * H * W * K, 3, D)
        pixel_face_vals = face_attrs.gather(0, idx).view(N, H, W, K, 3, D)

        # Barycentric interpolation over the three face corners.
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals = pixel_vals.masked_fill(mask[..., None], 0)

        # Keep the k = 0 face and move channels first: N x D x H x W.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        return torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
|
|