jboth commited on
Commit
a32b1b2
·
verified ·
1 Parent(s): 56fbaea

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -128,3 +128,46 @@ class Scale(Transform3d):
128
  M[:, 1, 1] = s[:, 1]
129
  M[:, 2, 2] = s[:, 2]
130
  super().__init__(dtype=dtype, device=device, matrix=M)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  M[:, 1, 1] = s[:, 1]
129
  M[:, 2, 2] = s[:, 2]
130
  super().__init__(dtype=dtype, device=device, matrix=M)
131
+
132
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
133
+ """Convert rotation matrix to quaternion (w, x, y, z)."""
134
+ if matrix.dim() == 2:
135
+ matrix = matrix.unsqueeze(0)
136
+ batch_dim = matrix.shape[:-2]
137
+ m = matrix.reshape(-1, 3, 3)
138
+
139
+ trace = m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2]
140
+ q = torch.zeros(m.shape[0], 4, dtype=m.dtype, device=m.device)
141
+
142
+ s = torch.sqrt((trace + 1.0).clamp(min=1e-10)) * 2 # s = 4*w
143
+ q[:, 0] = 0.25 * s
144
+ q[:, 1] = (m[:, 2, 1] - m[:, 1, 2]) / s.clamp(min=1e-10)
145
+ q[:, 2] = (m[:, 0, 2] - m[:, 2, 0]) / s.clamp(min=1e-10)
146
+ q[:, 3] = (m[:, 1, 0] - m[:, 0, 1]) / s.clamp(min=1e-10)
147
+
148
+ # Handle degenerate cases
149
+ mask1 = (m[:, 0, 0] > m[:, 1, 1]) & (m[:, 0, 0] > m[:, 2, 2]) & (trace <= 0)
150
+ if mask1.any():
151
+ s1 = torch.sqrt((1.0 + m[mask1, 0, 0] - m[mask1, 1, 1] - m[mask1, 2, 2]).clamp(min=1e-10)) * 2
152
+ q[mask1, 0] = (m[mask1, 2, 1] - m[mask1, 1, 2]) / s1.clamp(min=1e-10)
153
+ q[mask1, 1] = 0.25 * s1
154
+ q[mask1, 2] = (m[mask1, 0, 1] + m[mask1, 1, 0]) / s1.clamp(min=1e-10)
155
+ q[mask1, 3] = (m[mask1, 0, 2] + m[mask1, 2, 0]) / s1.clamp(min=1e-10)
156
+
157
+ mask2 = (m[:, 1, 1] > m[:, 2, 2]) & ~mask1 & (trace <= 0)
158
+ if mask2.any():
159
+ s2 = torch.sqrt((1.0 + m[mask2, 1, 1] - m[mask2, 0, 0] - m[mask2, 2, 2]).clamp(min=1e-10)) * 2
160
+ q[mask2, 0] = (m[mask2, 0, 2] - m[mask2, 2, 0]) / s2.clamp(min=1e-10)
161
+ q[mask2, 1] = (m[mask2, 0, 1] + m[mask2, 1, 0]) / s2.clamp(min=1e-10)
162
+ q[mask2, 2] = 0.25 * s2
163
+ q[mask2, 3] = (m[mask2, 1, 2] + m[mask2, 2, 1]) / s2.clamp(min=1e-10)
164
+
165
+ mask3 = ~mask1 & ~mask2 & (trace <= 0)
166
+ if mask3.any():
167
+ s3 = torch.sqrt((1.0 + m[mask3, 2, 2] - m[mask3, 0, 0] - m[mask3, 1, 1]).clamp(min=1e-10)) * 2
168
+ q[mask3, 0] = (m[mask3, 1, 0] - m[mask3, 0, 1]) / s3.clamp(min=1e-10)
169
+ q[mask3, 1] = (m[mask3, 0, 2] + m[mask3, 2, 0]) / s3.clamp(min=1e-10)
170
+ q[mask3, 2] = (m[mask3, 1, 2] + m[mask3, 2, 1]) / s3.clamp(min=1e-10)
171
+ q[mask3, 3] = 0.25 * s3
172
+
173
+ return q.reshape(batch_dim + (4,))