File size: 2,157 Bytes
0a95064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys

import torch
import torch.nn as nn
import torch.autograd as autograd
# from loguru import logger

import mesh_mesh_intersection
import mesh_mesh_intersect_cuda


class MeshMeshIntersectionFunction(autograd.Function):
    """Autograd wrapper around ``mesh_mesh_intersect_cuda.mesh_to_mesh_forward``.

    Only the forward pass is provided; ``backward`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, query_triangles, target_triangles, print_timings=False,
                max_collisions=32,
                *args, **kwargs):
        """Run the CUDA mesh-to-mesh intersection kernel.

        Args:
            ctx: autograd context. Unused, since nothing is needed for
                backward.
            query_triangles: passed positionally as the kernel's first
                argument.
            target_triangles: passed positionally as the kernel's second
                argument.
            print_timings: forwarded to the kernel as ``print_timings``.
            max_collisions: forwarded to the kernel as ``max_collisions``.
            *args, **kwargs: accepted and ignored.

        Returns:
            Tuple ``(collision_faces, collision_bcs)`` unpacked from the
            kernel's two outputs. Presumably these are colliding face
            indices and barycentric coordinates; confirm against the
            extension.
        """
        # No ctx.save_for_backward(...): backward is not implemented, so
        # there is nothing worth keeping alive on the context.
        collision_faces, collision_bcs = (
            mesh_mesh_intersect_cuda.mesh_to_mesh_forward(
                query_triangles,
                target_triangles,
                print_timings=print_timings,
                max_collisions=max_collisions,
            )
        )
        return collision_faces, collision_bcs

    @staticmethod
    def backward(ctx, grad_output, *args, **kwargs):
        """Gradients are unsupported; always raises ``NotImplementedError``."""
        raise NotImplementedError


class MeshMeshIntersection(nn.Module):
    """``nn.Module`` front end for ``MeshMeshIntersectionFunction``.

    Args:
        max_collisions: limit forwarded to the intersection kernel on every
            ``forward`` call (default 32). Its exact meaning is defined by
            the CUDA extension. It is presumably a per-query collision cap;
            TODO confirm.
    """

    def __init__(self, max_collisions=32):
        super(MeshMeshIntersection, self).__init__()
        # Kept on the module so every forward call uses the same limit.
        self.max_collisions = max_collisions

    def forward(self, query_triangles, target_triangles,
                print_timings=False):
        """Intersect ``query_triangles`` with ``target_triangles``.

        Args:
            query_triangles: forwarded unchanged to the autograd function.
            target_triangles: forwarded unchanged to the autograd function.
            print_timings: forwarded unchanged to the autograd function.

        Returns:
            Whatever ``MeshMeshIntersectionFunction.apply`` returns: the
            ``(collision_faces, collision_bcs)`` pair.
        """
        limit = self.max_collisions
        return MeshMeshIntersectionFunction.apply(
            query_triangles, target_triangles, print_timings, limit)