Spaces:
Paused
Paused
Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub
Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py
CHANGED
|
@@ -1,173 +1,110 @@
|
|
| 1 |
-
"""pytorch3d.transforms stub
|
| 2 |
import torch
|
| 3 |
-
import
|
| 4 |
|
| 5 |
-
def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
|
| 6 |
-
w1, x1, y1, z1 = q1.unbind(-1)
|
| 7 |
-
w2, x2, y2, z2 = q2.unbind(-1)
|
| 8 |
-
return torch.stack([
|
| 9 |
-
w1*w2 - x1*x2 - y1*y2 - z1*z2,
|
| 10 |
-
w1*x2 + x1*w2 + y1*z2 - z1*y2,
|
| 11 |
-
w1*y2 - x1*z2 + y1*w2 + z1*x2,
|
| 12 |
-
w1*z2 + x1*y2 - y1*x2 + z1*w2,
|
| 13 |
-
], dim=-1)
|
| 14 |
|
| 15 |
-
def
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
| 21 |
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 22 |
-
|
| 23 |
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
| 24 |
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
| 25 |
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
| 26 |
-
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
def
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
class Transform3d:
|
| 42 |
def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
|
|
|
|
|
|
|
| 43 |
if matrix is not None:
|
| 44 |
-
self._matrix = matrix.to(
|
| 45 |
else:
|
| 46 |
self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
|
| 47 |
-
|
| 48 |
-
def get_matrix(self) -> torch.Tensor:
|
| 49 |
return self._matrix
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def
|
| 56 |
-
if
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
return self._compose(S)
|
| 66 |
-
|
| 67 |
-
def rotate(self, R) -> "Transform3d":
|
| 68 |
-
if isinstance(R, torch.Tensor):
|
| 69 |
-
if R.dim() == 2:
|
| 70 |
-
R = R.unsqueeze(0)
|
| 71 |
-
M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
|
| 72 |
-
M[:, :3, :3] = R
|
| 73 |
-
return self._compose(M)
|
| 74 |
-
return self
|
| 75 |
-
|
| 76 |
-
def translate(self, t) -> "Transform3d":
|
| 77 |
-
if isinstance(t, torch.Tensor):
|
| 78 |
-
if t.dim() == 1:
|
| 79 |
-
t = t.unsqueeze(0)
|
| 80 |
-
M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
|
| 81 |
-
M[:, 3, :3] = t
|
| 82 |
-
return self._compose(M)
|
| 83 |
-
return self
|
| 84 |
-
|
| 85 |
-
def compose(self, other: "Transform3d") -> "Transform3d":
|
| 86 |
-
return self._compose(other._matrix)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class Rotate(Transform3d):
|
| 90 |
-
def __init__(self, R, dtype=torch.float32, device="cpu"):
|
| 91 |
-
if isinstance(R, torch.Tensor):
|
| 92 |
-
if R.dim() == 2:
|
| 93 |
-
R = R.unsqueeze(0)
|
| 94 |
-
M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
|
| 95 |
-
M[:, :3, :3] = R.to(dtype=dtype, device=device)
|
| 96 |
-
super().__init__(dtype=dtype, device=device, matrix=M)
|
| 97 |
else:
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
else:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
s = s.unsqueeze(0)
|
| 124 |
-
elif s.dim() == 1 and s.shape[0] != 3:
|
| 125 |
-
s = s.unsqueeze(-1).expand(-1, 3)
|
| 126 |
-
M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
|
| 127 |
-
M[:, 0, 0] = s[:, 0]
|
| 128 |
-
M[:, 1, 1] = s[:, 1]
|
| 129 |
-
M[:, 2, 2] = s[:, 2]
|
| 130 |
-
super().__init__(dtype=dtype, device=device, matrix=M)
|
| 131 |
-
|
| 132 |
-
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
| 133 |
-
"""Convert rotation matrix to quaternion (w, x, y, z)."""
|
| 134 |
-
if matrix.dim() == 2:
|
| 135 |
-
matrix = matrix.unsqueeze(0)
|
| 136 |
-
batch_dim = matrix.shape[:-2]
|
| 137 |
-
m = matrix.reshape(-1, 3, 3)
|
| 138 |
-
|
| 139 |
-
trace = m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2]
|
| 140 |
-
q = torch.zeros(m.shape[0], 4, dtype=m.dtype, device=m.device)
|
| 141 |
-
|
| 142 |
-
s = torch.sqrt((trace + 1.0).clamp(min=1e-10)) * 2 # s = 4*w
|
| 143 |
-
q[:, 0] = 0.25 * s
|
| 144 |
-
q[:, 1] = (m[:, 2, 1] - m[:, 1, 2]) / s.clamp(min=1e-10)
|
| 145 |
-
q[:, 2] = (m[:, 0, 2] - m[:, 2, 0]) / s.clamp(min=1e-10)
|
| 146 |
-
q[:, 3] = (m[:, 1, 0] - m[:, 0, 1]) / s.clamp(min=1e-10)
|
| 147 |
-
|
| 148 |
-
# Handle degenerate cases
|
| 149 |
-
mask1 = (m[:, 0, 0] > m[:, 1, 1]) & (m[:, 0, 0] > m[:, 2, 2]) & (trace <= 0)
|
| 150 |
-
if mask1.any():
|
| 151 |
-
s1 = torch.sqrt((1.0 + m[mask1, 0, 0] - m[mask1, 1, 1] - m[mask1, 2, 2]).clamp(min=1e-10)) * 2
|
| 152 |
-
q[mask1, 0] = (m[mask1, 2, 1] - m[mask1, 1, 2]) / s1.clamp(min=1e-10)
|
| 153 |
-
q[mask1, 1] = 0.25 * s1
|
| 154 |
-
q[mask1, 2] = (m[mask1, 0, 1] + m[mask1, 1, 0]) / s1.clamp(min=1e-10)
|
| 155 |
-
q[mask1, 3] = (m[mask1, 0, 2] + m[mask1, 2, 0]) / s1.clamp(min=1e-10)
|
| 156 |
-
|
| 157 |
-
mask2 = (m[:, 1, 1] > m[:, 2, 2]) & ~mask1 & (trace <= 0)
|
| 158 |
-
if mask2.any():
|
| 159 |
-
s2 = torch.sqrt((1.0 + m[mask2, 1, 1] - m[mask2, 0, 0] - m[mask2, 2, 2]).clamp(min=1e-10)) * 2
|
| 160 |
-
q[mask2, 0] = (m[mask2, 0, 2] - m[mask2, 2, 0]) / s2.clamp(min=1e-10)
|
| 161 |
-
q[mask2, 1] = (m[mask2, 0, 1] + m[mask2, 1, 0]) / s2.clamp(min=1e-10)
|
| 162 |
-
q[mask2, 2] = 0.25 * s2
|
| 163 |
-
q[mask2, 3] = (m[mask2, 1, 2] + m[mask2, 2, 1]) / s2.clamp(min=1e-10)
|
| 164 |
-
|
| 165 |
-
mask3 = ~mask1 & ~mask2 & (trace <= 0)
|
| 166 |
-
if mask3.any():
|
| 167 |
-
s3 = torch.sqrt((1.0 + m[mask3, 2, 2] - m[mask3, 0, 0] - m[mask3, 1, 1]).clamp(min=1e-10)) * 2
|
| 168 |
-
q[mask3, 0] = (m[mask3, 1, 0] - m[mask3, 0, 1]) / s3.clamp(min=1e-10)
|
| 169 |
-
q[mask3, 1] = (m[mask3, 0, 2] + m[mask3, 2, 0]) / s3.clamp(min=1e-10)
|
| 170 |
-
q[mask3, 2] = (m[mask3, 1, 2] + m[mask3, 2, 1]) / s3.clamp(min=1e-10)
|
| 171 |
-
q[mask3, 3] = 0.25 * s3
|
| 172 |
-
|
| 173 |
-
return q.reshape(batch_dim + (4,))
|
|
|
|
| 1 |
+
"""pytorch3d.transforms stub with catch-all."""
|
| 2 |
import torch
|
| 3 |
+
import warnings
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
def __getattr__(name):
|
| 7 |
+
if name.startswith("_"):
|
| 8 |
+
raise AttributeError(name)
|
| 9 |
+
warnings.warn(f"pytorch3d.transforms stub: {name} not implemented", stacklevel=2)
|
| 10 |
+
def _dummy(*a, **kw): return None
|
| 11 |
+
_dummy.__name__ = name
|
| 12 |
+
return _dummy
|
| 13 |
|
| 14 |
+
|
| 15 |
+
def quaternion_to_matrix(quaternions):
|
| 16 |
+
r, i, j, k = quaternions[..., 0], quaternions[..., 1], quaternions[..., 2], quaternions[..., 3]
|
| 17 |
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 18 |
+
o = torch.stack([
|
| 19 |
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
| 20 |
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
| 21 |
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
| 22 |
+
], -1)
|
| 23 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def matrix_to_quaternion(matrix):
|
| 27 |
+
if matrix.shape[-2:] != (3, 3):
|
| 28 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}")
|
| 29 |
+
batch_dim = matrix.shape[:-2]
|
| 30 |
+
m00, m01, m02 = matrix[..., 0, 0], matrix[..., 0, 1], matrix[..., 0, 2]
|
| 31 |
+
m10, m11, m12 = matrix[..., 1, 0], matrix[..., 1, 1], matrix[..., 1, 2]
|
| 32 |
+
m20, m21, m22 = matrix[..., 2, 0], matrix[..., 2, 1], matrix[..., 2, 2]
|
| 33 |
+
trace = m00 + m11 + m22
|
| 34 |
+
qw = torch.sqrt(torch.clamp(trace + 1, min=1e-8)) / 2
|
| 35 |
+
qx = (m21 - m12) / (4 * qw + 1e-8)
|
| 36 |
+
qy = (m02 - m20) / (4 * qw + 1e-8)
|
| 37 |
+
qz = (m10 - m01) / (4 * qw + 1e-8)
|
| 38 |
+
return torch.stack([qw, qx, qy, qz], dim=-1)
|
| 39 |
+
|
| 40 |
|
| 41 |
+
def quaternion_multiply(a, b):
|
| 42 |
+
aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
| 43 |
+
bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
| 44 |
+
ow = aw*bw - ax*bx - ay*by - az*bz
|
| 45 |
+
ox = aw*bx + ax*bw + ay*bz - az*by
|
| 46 |
+
oy = aw*by - ax*bz + ay*bw + az*bx
|
| 47 |
+
oz = aw*bz + ax*by - ay*bx + az*bw
|
| 48 |
+
return torch.stack([ow, ox, oy, oz], dim=-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def quaternion_invert(quaternion):
|
| 52 |
+
scaling = torch.tensor([1, -1, -1, -1], dtype=quaternion.dtype, device=quaternion.device)
|
| 53 |
+
return quaternion * scaling
|
| 54 |
|
| 55 |
|
| 56 |
class Transform3d:
|
| 57 |
def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
|
| 58 |
+
self.device = device
|
| 59 |
+
self.dtype = dtype
|
| 60 |
if matrix is not None:
|
| 61 |
+
self._matrix = matrix.to(device=device, dtype=dtype)
|
| 62 |
else:
|
| 63 |
self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
|
| 64 |
+
def get_matrix(self):
|
|
|
|
| 65 |
return self._matrix
|
| 66 |
+
def compose(self, *others):
|
| 67 |
+
m = self._matrix
|
| 68 |
+
for o in others:
|
| 69 |
+
m = m @ o.get_matrix()
|
| 70 |
+
return Transform3d(matrix=m, device=self.device, dtype=self.dtype)
|
| 71 |
+
def transform_points(self, points):
|
| 72 |
+
if points.dim() == 2:
|
| 73 |
+
points = points.unsqueeze(0)
|
| 74 |
+
ones = torch.ones(*points.shape[:-1], 1, dtype=points.dtype, device=points.device)
|
| 75 |
+
pts4 = torch.cat([points, ones], dim=-1)
|
| 76 |
+
out = torch.bmm(pts4, self._matrix.expand(pts4.shape[0], -1, -1).transpose(-2, -1))
|
| 77 |
+
return out[..., :3]
|
| 78 |
+
def translate(self, x, y=None, z=None):
|
| 79 |
+
if isinstance(x, torch.Tensor) and x.dim() >= 1 and x.shape[-1] == 3:
|
| 80 |
+
t = x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
+
t = torch.tensor([[x, y, z]], dtype=self.dtype, device=self.device)
|
| 83 |
+
if t.dim() == 1: t = t.unsqueeze(0)
|
| 84 |
+
T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
|
| 85 |
+
T[:, :3, 3] = t
|
| 86 |
+
new_m = self._matrix @ T
|
| 87 |
+
return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
|
| 88 |
+
def scale(self, x, y=None, z=None):
|
| 89 |
+
if isinstance(x, torch.Tensor):
|
| 90 |
+
if x.dim() == 0:
|
| 91 |
+
s = x.expand(3)
|
| 92 |
+
elif x.shape[-1] == 3:
|
| 93 |
+
s = x.squeeze()
|
| 94 |
+
else:
|
| 95 |
+
s = x.expand(3)
|
| 96 |
else:
|
| 97 |
+
if y is None: y = x
|
| 98 |
+
if z is None: z = x
|
| 99 |
+
s = torch.tensor([x, y, z], dtype=self.dtype, device=self.device)
|
| 100 |
+
if s.dim() == 1: s = s.unsqueeze(0)
|
| 101 |
+
S = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
|
| 102 |
+
S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
|
| 103 |
+
new_m = self._matrix @ S
|
| 104 |
+
return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
|
| 105 |
+
def rotate(self, R):
|
| 106 |
+
if R.dim() == 2: R = R.unsqueeze(0)
|
| 107 |
+
T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
|
| 108 |
+
T[:, :3, :3] = R
|
| 109 |
+
new_m = self._matrix @ T
|
| 110 |
+
return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|