Spaces:
Paused
Paused
Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub
Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py
CHANGED
|
@@ -128,3 +128,46 @@ class Scale(Transform3d):
|
|
| 128 |
M[:, 1, 1] = s[:, 1]
|
| 129 |
M[:, 2, 2] = s[:, 2]
|
| 130 |
super().__init__(dtype=dtype, device=device, matrix=M)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
M[:, 1, 1] = s[:, 1]
|
| 129 |
M[:, 2, 2] = s[:, 2]
|
| 130 |
super().__init__(dtype=dtype, device=device, matrix=M)
|
| 131 |
+
|
| 132 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
| 133 |
+
"""Convert rotation matrix to quaternion (w, x, y, z)."""
|
| 134 |
+
if matrix.dim() == 2:
|
| 135 |
+
matrix = matrix.unsqueeze(0)
|
| 136 |
+
batch_dim = matrix.shape[:-2]
|
| 137 |
+
m = matrix.reshape(-1, 3, 3)
|
| 138 |
+
|
| 139 |
+
trace = m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2]
|
| 140 |
+
q = torch.zeros(m.shape[0], 4, dtype=m.dtype, device=m.device)
|
| 141 |
+
|
| 142 |
+
s = torch.sqrt((trace + 1.0).clamp(min=1e-10)) * 2 # s = 4*w
|
| 143 |
+
q[:, 0] = 0.25 * s
|
| 144 |
+
q[:, 1] = (m[:, 2, 1] - m[:, 1, 2]) / s.clamp(min=1e-10)
|
| 145 |
+
q[:, 2] = (m[:, 0, 2] - m[:, 2, 0]) / s.clamp(min=1e-10)
|
| 146 |
+
q[:, 3] = (m[:, 1, 0] - m[:, 0, 1]) / s.clamp(min=1e-10)
|
| 147 |
+
|
| 148 |
+
# Handle degenerate cases
|
| 149 |
+
mask1 = (m[:, 0, 0] > m[:, 1, 1]) & (m[:, 0, 0] > m[:, 2, 2]) & (trace <= 0)
|
| 150 |
+
if mask1.any():
|
| 151 |
+
s1 = torch.sqrt((1.0 + m[mask1, 0, 0] - m[mask1, 1, 1] - m[mask1, 2, 2]).clamp(min=1e-10)) * 2
|
| 152 |
+
q[mask1, 0] = (m[mask1, 2, 1] - m[mask1, 1, 2]) / s1.clamp(min=1e-10)
|
| 153 |
+
q[mask1, 1] = 0.25 * s1
|
| 154 |
+
q[mask1, 2] = (m[mask1, 0, 1] + m[mask1, 1, 0]) / s1.clamp(min=1e-10)
|
| 155 |
+
q[mask1, 3] = (m[mask1, 0, 2] + m[mask1, 2, 0]) / s1.clamp(min=1e-10)
|
| 156 |
+
|
| 157 |
+
mask2 = (m[:, 1, 1] > m[:, 2, 2]) & ~mask1 & (trace <= 0)
|
| 158 |
+
if mask2.any():
|
| 159 |
+
s2 = torch.sqrt((1.0 + m[mask2, 1, 1] - m[mask2, 0, 0] - m[mask2, 2, 2]).clamp(min=1e-10)) * 2
|
| 160 |
+
q[mask2, 0] = (m[mask2, 0, 2] - m[mask2, 2, 0]) / s2.clamp(min=1e-10)
|
| 161 |
+
q[mask2, 1] = (m[mask2, 0, 1] + m[mask2, 1, 0]) / s2.clamp(min=1e-10)
|
| 162 |
+
q[mask2, 2] = 0.25 * s2
|
| 163 |
+
q[mask2, 3] = (m[mask2, 1, 2] + m[mask2, 2, 1]) / s2.clamp(min=1e-10)
|
| 164 |
+
|
| 165 |
+
mask3 = ~mask1 & ~mask2 & (trace <= 0)
|
| 166 |
+
if mask3.any():
|
| 167 |
+
s3 = torch.sqrt((1.0 + m[mask3, 2, 2] - m[mask3, 0, 0] - m[mask3, 1, 1]).clamp(min=1e-10)) * 2
|
| 168 |
+
q[mask3, 0] = (m[mask3, 1, 0] - m[mask3, 0, 1]) / s3.clamp(min=1e-10)
|
| 169 |
+
q[mask3, 1] = (m[mask3, 0, 2] + m[mask3, 2, 0]) / s3.clamp(min=1e-10)
|
| 170 |
+
q[mask3, 2] = (m[mask3, 1, 2] + m[mask3, 2, 1]) / s3.clamp(min=1e-10)
|
| 171 |
+
q[mask3, 3] = 0.25 * s3
|
| 172 |
+
|
| 173 |
+
return q.reshape(batch_dim + (4,))
|