File size: 5,696 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from trimesh.graph import face_adjacency
import torch 
import torch.nn as nn

class PrismRegularizationLoss(nn.Module):
    """
    Calculate the loss based on the PriMo energy, as described in the paper:
    PriMo: Coupled Prisms for Intuitive Surface Modeling

    For every pair of faces that share an edge, the side faces of the two
    prisms built over that edge are compared. Their squared difference is
    integrated over the bilinear side patch (see ``compute_energy``), weighted
    by ``|edge|^2 / (area_1 + area_2)``, and summed over all shared edges.

    Args:
        primo_h: prism half-thickness; each prism corner is the edge vertex
            offset by ``+/- primo_h`` along its (rotated) vertex normal.
    """

    def __init__(self, primo_h):
        super().__init__()
        self.h = primo_h

        # Coefficients 2^{-|i-k| - |j-l|} for every pair of corners (i, j), (k, l),
        # laid out as index 4 * a + b where a = 2*i + j and b = 2*k + l.
        corners = torch.tensor([(i, j) for i in range(2) for j in range(2)])  # (4, 2)
        pair_a = corners.repeat_interleave(4, dim=0)  # (16, 2): corner a = idx // 4
        pair_b = corners.repeat(4, 1)  # (16, 2): corner b = idx % 4
        manhattan = (pair_a - pair_b).abs().sum(dim=1)  # (16,)
        coeff = torch.tensor(2.0).pow(-manhattan)[None, :]  # (1, 16)
        # A non-persistent buffer follows module.to(...) but stays out of state_dict,
        # so checkpoints saved before this change still load.
        self.register_buffer("coeff", coeff, persistent=False)

    def forward(self, transformed_prism, rotations, verts, faces, normals):
        """
        Compute the weighted PriMo coupling energy of a single mesh.

        Args:
            transformed_prism: per-face vertex positions after the per-face
                transformation, indexed ``[face, local_vertex]`` -> xyz
                (expected (n_faces, 3, 3)).
            rotations: per-face matrices indexed by face id, presumably
                (n_faces, 3, 3) - TODO confirm.
            verts: template vertex positions; flattened to (-1, 3).
            faces: template triangle indices; flattened to (-1, 3).
            normals: per-vertex normals of the template; flattened to (-1, 3).

        Returns:
            Scalar tensor: the weighted energy summed over all shared edges.
        """
        # TODO: add batch support; for now a single mesh (batch size 1) is assumed.
        bs = 1
        verts = verts.reshape(-1, 3)
        normals = normals.reshape(-1, 3)
        # Flatten like verts/normals so a leading batch dim of size 1 is also accepted.
        faces = faces.reshape(-1, 3)

        face_areas = self.get_face_areas(verts, faces)  # (n_faces,)

        # Pairs of faces sharing an edge, and the two vertex ids of that edge.
        # trimesh works on NumPy arrays, so this step always goes through the CPU.
        face_ids, edges = face_adjacency(faces.cpu().numpy(), return_edges=True)  # (n_edges, 2), (n_edges, 2)
        face_ids = torch.from_numpy(face_ids).to(verts.device)
        edges = torch.from_numpy(edges).to(verts.device)
        face_id1, face_id2 = face_ids[:, 0], face_ids[:, 1]  # (n_edges,)

        # Normals are per vertex (the edge endpoints); rotations are per face.
        normals1, normals2 = normals[edges[:, 0]], normals[edges[:, 1]]  # (n_edges, 3)
        rotations1, rotations2 = rotations[face_id1], rotations[face_id2]  # (n_edges, 3, 3)

        # Edge endpoints as positioned by each of the two transformed prisms.
        local_ids = self.get_verts_id_face(faces, edges, face_ids)  # (n_edges, 4)
        verts1_p1 = transformed_prism[face_id1, local_ids[:, 0]]  # (n_edges, 3)
        verts2_p1 = transformed_prism[face_id1, local_ids[:, 1]]
        verts1_p2 = transformed_prism[face_id2, local_ids[:, 2]]
        verts2_p2 = transformed_prism[face_id2, local_ids[:, 3]]

        # Rotate the vertex normals by each face's rotation.
        # TODO(review): this is the row-vector product n^T R (== R^T n); confirm it
        # matches the convention used to build `rotations`.
        prism1_n1 = (normals1[:, None] @ rotations1).squeeze(1)
        prism1_n2 = (normals2[:, None] @ rotations1).squeeze(1)
        prism2_n1 = (normals1[:, None] @ rotations2).squeeze(1)
        prism2_n2 = (normals2[:, None] @ rotations2).squeeze(1)

        # Side faces of prism 1 and prism 2 over the shared edge, (n_edges, 4, 3) each.
        side1 = self._side_face(verts1_p1, verts2_p1, prism1_n1, prism1_n2)
        side2 = self._side_face(verts1_p2, verts2_p2, prism2_n1, prism2_n2)
        diff = side1 - side2
        energy = self.compute_energy(diff, diff)  # (n_edges,)

        # Edge weight: squared edge length over the summed areas of the two faces.
        edge_len_sq = torch.norm(verts[edges[:, 0]] - verts[edges[:, 1]], dim=1).square()
        weight = edge_len_sq / (face_areas[face_id1] + face_areas[face_id2])  # (n_edges,)

        # TODO: divide by the real batch size once batching is supported.
        return (energy * weight).sum() / bs

    def _side_face(self, v_a, v_b, n_a, n_b):
        """Stack the 4 corners of a prism side face over edge (a, b).

        Corner order is (a + h*n_a, b + h*n_b, a - h*n_a, b - h*n_b), i.e. row
        index 2*i + j with i selecting the +/- offset and j the edge endpoint,
        which is the layout ``compute_energy`` expects. Output: (n_edges, 4, 3).
        """
        off_a = n_a * self.h
        off_b = n_b * self.h
        return torch.stack((v_a + off_a, v_b + off_b, v_a - off_a, v_b - off_b), dim=1)

    def compute_energy(self, A, B):
        """
        Computes the formula sum_{i,j,k,l=0}^{1} a_{ij} b_{kl} 2^{-|i - k| - |j - l|} / 9.

        A and B are (..., 4, 3) tensors whose rows are the corners
        a_00, a_01, a_10, a_11 (row index 2*i + j). With A == B == D this equals
        the integral over the unit square of ||D(u, v)||^2 for the bilinear patch
        D(u, v) interpolating those corners.

        Returns:
            Tensor of shape (...,): one energy value per leading index.
        """
        # Local copy: the module itself is not mutated during forward.
        coeff = self.coeff.to(A.device)  # (1, 16)
        # Dot product of every corner pair: pair_dots[..., a, b] = A[..., a, :] . B[..., b, :]
        pair_dots = (A.unsqueeze(-2) * B.unsqueeze(-3)).sum(dim=-1)  # (..., 4, 4)
        # Row-major flattening gives index 4 * a + b, matching the layout of coeff.
        energy = (pair_dots.flatten(-2) * coeff).sum(dim=-1)
        return energy / 9

    def get_face_areas(self, verts, faces):
        """Return the area of each triangle as (n_faces,), computed as half the cross-product norm."""
        v1, v2, v3 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
        return 0.5 * torch.cross(v2 - v1, v3 - v1, dim=-1).norm(dim=1)

    @staticmethod
    def _local_index(face_verts, vertex_ids):
        """Position (0, 1 or 2) of each vertex id within its face row, or -1 if absent."""
        return torch.where(
            face_verts[:, 0] == vertex_ids, 0,
            torch.where(
                face_verts[:, 1] == vertex_ids, 1,
                torch.where(face_verts[:, 2] == vertex_ids, 2, -1),
            ),
        )

    def get_verts_id_face(self, F, E, Q):
        """
        Locate each shared edge's endpoints inside the two faces that share it.

        Args:
            F: (n_faces, 3) triangle vertex indices.
            E: (n_edges, 2) vertex ids of each shared edge.
            Q: (n_edges, 2) ids of the two faces sharing each edge.

        Returns:
            (n_edges, 4) long tensor on F's device holding the local positions of
            E[:, 0] and E[:, 1] in face Q[:, 0], followed by the same for face
            Q[:, 1]. A vertex missing from its face yields -1, which as an index
            would silently pick the last corner, so inputs must be consistent.
        """
        faces_a = F[Q[:, 0]]  # (n_edges, 3)
        faces_b = F[Q[:, 1]]  # (n_edges, 3)
        # Built directly from the inputs (no CPU scratch tensor) to avoid a device round-trip.
        Z = torch.stack(
            (
                self._local_index(faces_a, E[:, 0]),
                self._local_index(faces_a, E[:, 1]),
                self._local_index(faces_b, E[:, 0]),
                self._local_index(faces_b, E[:, 1]),
            ),
            dim=1,
        )
        return Z.to(dtype=torch.long, device=F.device)