| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import print_function |
| from __future__ import division |
|
|
| import sys |
|
|
| import torch |
| import torch.nn as nn |
| import torch.autograd as autograd |
| |
|
|
| import mesh_mesh_intersection |
| import mesh_mesh_intersect_cuda |
|
|
|
|
class MeshMeshIntersectionFunction(autograd.Function):
    """Autograd entry point for the ``mesh_mesh_intersect_cuda`` extension.

    Only the forward pass is available; ``backward`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, query_triangles, target_triangles, print_timings=False,
                max_collisions=32,
                *args, **kwargs):
        """Run the extension's ``mesh_to_mesh_forward`` on two triangle sets.

        Args:
            ctx: Autograd context object (not used here).
            query_triangles: First positional argument handed to the
                extension.
            target_triangles: Second positional argument handed to the
                extension.
            print_timings: Forwarded as the ``print_timings`` keyword.
            max_collisions: Forwarded as the ``max_collisions`` keyword.
            *args: Accepted but ignored.
            **kwargs: Accepted but ignored.

        Returns:
            The two values produced by the extension, as a
            ``(collision_faces, collision_bcs)`` tuple.
        """
        # NOTE(review): presumably face indices and barycentric coordinates
        # of the detected collisions -- confirm against the CUDA extension.
        result = mesh_mesh_intersect_cuda.mesh_to_mesh_forward(
            query_triangles,
            target_triangles,
            print_timings=print_timings,
            max_collisions=max_collisions,
        )

        # Unpacking enforces that the extension returned exactly two items.
        faces, barycentrics = result
        return faces, barycentrics

    @staticmethod
    def backward(ctx, grad_output, *args, **kwargs):
        """Gradient computation is not implemented for this operation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
|
|
|
class MeshMeshIntersection(nn.Module):
    """Module wrapper that invokes ``MeshMeshIntersectionFunction``.

    The ``max_collisions`` value given at construction time is stored on the
    instance and passed along on every forward call.
    """

    def __init__(self, max_collisions=32):
        """Store the ``max_collisions`` setting used by ``forward``.

        Args:
            max_collisions: Value later forwarded to
                ``MeshMeshIntersectionFunction.apply``.
        """
        super(MeshMeshIntersection, self).__init__()
        self.max_collisions = max_collisions

    def forward(self, query_triangles, target_triangles,
                print_timings=False):
        """Apply ``MeshMeshIntersectionFunction`` to the given triangles.

        Args:
            query_triangles: First argument passed to the function.
            target_triangles: Second argument passed to the function.
            print_timings: Third argument passed to the function.

        Returns:
            Whatever ``MeshMeshIntersectionFunction.apply`` returns.
        """
        # ``apply`` takes positional arguments only, in the same order as
        # the function's ``forward`` signature (after ``ctx``).
        call_args = (
            query_triangles,
            target_triangles,
            print_timings,
            self.max_collisions,
        )
        return MeshMeshIntersectionFunction.apply(*call_args)
|
|