jboth commited on
Commit
e92bb9f
·
verified ·
1 Parent(s): 8d991db

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -1,173 +1,110 @@
1
- """pytorch3d.transforms stub pure PyTorch implementations."""
2
  import torch
3
- import math
4
 
5
- def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
6
- w1, x1, y1, z1 = q1.unbind(-1)
7
- w2, x2, y2, z2 = q2.unbind(-1)
8
- return torch.stack([
9
- w1*w2 - x1*x2 - y1*y2 - z1*z2,
10
- w1*x2 + x1*w2 + y1*z2 - z1*y2,
11
- w1*y2 - x1*z2 + y1*w2 + z1*x2,
12
- w1*z2 + x1*y2 - y1*x2 + z1*w2,
13
- ], dim=-1)
14
 
15
- def quaternion_invert(q: torch.Tensor) -> torch.Tensor:
16
- scaling = torch.tensor([1, -1, -1, -1], device=q.device, dtype=q.dtype)
17
- return q * scaling / (q * q).sum(dim=-1, keepdim=True).clamp(min=1e-10)
 
 
 
 
18
 
19
- def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
20
- r, i, j, k = quaternions.unbind(-1)
 
21
  two_s = 2.0 / (quaternions * quaternions).sum(-1)
22
- return torch.stack([
23
  1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
24
  two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
25
  two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
26
- ], dim=-1).reshape(quaternions.shape[:-1] + (3, 3))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
29
- angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
30
- half_angles = 0.5 * angles
31
- eps = 1e-6
32
- small = (angles.abs() < eps).squeeze(-1)
33
- sin_half = torch.sin(half_angles)
34
- cos_half = torch.cos(half_angles)
35
- k = torch.where(angles > eps, sin_half / angles, 0.5 * torch.ones_like(angles))
36
- w = cos_half.squeeze(-1)
37
- xyz = axis_angle * k
38
- return torch.cat([w.unsqueeze(-1), xyz], dim=-1)
 
 
39
 
40
 
41
  class Transform3d:
42
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
 
 
43
  if matrix is not None:
44
- self._matrix = matrix.to(dtype=dtype, device=device)
45
  else:
46
  self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
47
-
48
- def get_matrix(self) -> torch.Tensor:
49
  return self._matrix
50
-
51
- def _compose(self, other_matrix: torch.Tensor) -> "Transform3d":
52
- new_matrix = torch.bmm(self._matrix, other_matrix)
53
- return Transform3d(matrix=new_matrix)
54
-
55
- def scale(self, s) -> "Transform3d":
56
- if isinstance(s, (int, float)):
57
- s = torch.tensor([[s, s, s]], dtype=self._matrix.dtype, device=self._matrix.device)
58
- elif isinstance(s, torch.Tensor):
59
- if s.dim() == 1:
60
- s = s.unsqueeze(0)
61
- S = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
62
- S[:, 0, 0] = s[:, 0]
63
- S[:, 1, 1] = s[:, 1]
64
- S[:, 2, 2] = s[:, 2]
65
- return self._compose(S)
66
-
67
- def rotate(self, R) -> "Transform3d":
68
- if isinstance(R, torch.Tensor):
69
- if R.dim() == 2:
70
- R = R.unsqueeze(0)
71
- M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
72
- M[:, :3, :3] = R
73
- return self._compose(M)
74
- return self
75
-
76
- def translate(self, t) -> "Transform3d":
77
- if isinstance(t, torch.Tensor):
78
- if t.dim() == 1:
79
- t = t.unsqueeze(0)
80
- M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
81
- M[:, 3, :3] = t
82
- return self._compose(M)
83
- return self
84
-
85
- def compose(self, other: "Transform3d") -> "Transform3d":
86
- return self._compose(other._matrix)
87
-
88
-
89
- class Rotate(Transform3d):
90
- def __init__(self, R, dtype=torch.float32, device="cpu"):
91
- if isinstance(R, torch.Tensor):
92
- if R.dim() == 2:
93
- R = R.unsqueeze(0)
94
- M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
95
- M[:, :3, :3] = R.to(dtype=dtype, device=device)
96
- super().__init__(dtype=dtype, device=device, matrix=M)
97
  else:
98
- super().__init__(dtype=dtype, device=device)
99
-
100
- class Translate(Transform3d):
101
- def __init__(self, *args, dtype=torch.float32, device="cpu"):
102
- if len(args) == 1 and isinstance(args[0], torch.Tensor):
103
- t = args[0]
104
- if t.dim() == 1:
105
- t = t.unsqueeze(0)
106
- elif len(args) == 3:
107
- t = torch.tensor([[args[0], args[1], args[2]]], dtype=dtype, device=device)
 
 
 
 
108
  else:
109
- super().__init__(dtype=dtype, device=device)
110
- return
111
- M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
112
- M[:, 3, :3] = t.to(dtype=dtype, device=device)
113
- super().__init__(dtype=dtype, device=device, matrix=M)
114
-
115
- class Scale(Transform3d):
116
- def __init__(self, s, dtype=torch.float32, device="cpu"):
117
- if isinstance(s, (int, float)):
118
- s = torch.tensor([[s, s, s]], dtype=dtype, device=device)
119
- elif isinstance(s, torch.Tensor):
120
- if s.dim() == 0:
121
- s = s.reshape(1, 1).expand(1, 3)
122
- elif s.dim() == 1 and s.shape[0] == 3:
123
- s = s.unsqueeze(0)
124
- elif s.dim() == 1 and s.shape[0] != 3:
125
- s = s.unsqueeze(-1).expand(-1, 3)
126
- M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
127
- M[:, 0, 0] = s[:, 0]
128
- M[:, 1, 1] = s[:, 1]
129
- M[:, 2, 2] = s[:, 2]
130
- super().__init__(dtype=dtype, device=device, matrix=M)
131
-
132
- def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
133
- """Convert rotation matrix to quaternion (w, x, y, z)."""
134
- if matrix.dim() == 2:
135
- matrix = matrix.unsqueeze(0)
136
- batch_dim = matrix.shape[:-2]
137
- m = matrix.reshape(-1, 3, 3)
138
-
139
- trace = m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2]
140
- q = torch.zeros(m.shape[0], 4, dtype=m.dtype, device=m.device)
141
-
142
- s = torch.sqrt((trace + 1.0).clamp(min=1e-10)) * 2 # s = 4*w
143
- q[:, 0] = 0.25 * s
144
- q[:, 1] = (m[:, 2, 1] - m[:, 1, 2]) / s.clamp(min=1e-10)
145
- q[:, 2] = (m[:, 0, 2] - m[:, 2, 0]) / s.clamp(min=1e-10)
146
- q[:, 3] = (m[:, 1, 0] - m[:, 0, 1]) / s.clamp(min=1e-10)
147
-
148
- # Handle degenerate cases
149
- mask1 = (m[:, 0, 0] > m[:, 1, 1]) & (m[:, 0, 0] > m[:, 2, 2]) & (trace <= 0)
150
- if mask1.any():
151
- s1 = torch.sqrt((1.0 + m[mask1, 0, 0] - m[mask1, 1, 1] - m[mask1, 2, 2]).clamp(min=1e-10)) * 2
152
- q[mask1, 0] = (m[mask1, 2, 1] - m[mask1, 1, 2]) / s1.clamp(min=1e-10)
153
- q[mask1, 1] = 0.25 * s1
154
- q[mask1, 2] = (m[mask1, 0, 1] + m[mask1, 1, 0]) / s1.clamp(min=1e-10)
155
- q[mask1, 3] = (m[mask1, 0, 2] + m[mask1, 2, 0]) / s1.clamp(min=1e-10)
156
-
157
- mask2 = (m[:, 1, 1] > m[:, 2, 2]) & ~mask1 & (trace <= 0)
158
- if mask2.any():
159
- s2 = torch.sqrt((1.0 + m[mask2, 1, 1] - m[mask2, 0, 0] - m[mask2, 2, 2]).clamp(min=1e-10)) * 2
160
- q[mask2, 0] = (m[mask2, 0, 2] - m[mask2, 2, 0]) / s2.clamp(min=1e-10)
161
- q[mask2, 1] = (m[mask2, 0, 1] + m[mask2, 1, 0]) / s2.clamp(min=1e-10)
162
- q[mask2, 2] = 0.25 * s2
163
- q[mask2, 3] = (m[mask2, 1, 2] + m[mask2, 2, 1]) / s2.clamp(min=1e-10)
164
-
165
- mask3 = ~mask1 & ~mask2 & (trace <= 0)
166
- if mask3.any():
167
- s3 = torch.sqrt((1.0 + m[mask3, 2, 2] - m[mask3, 0, 0] - m[mask3, 1, 1]).clamp(min=1e-10)) * 2
168
- q[mask3, 0] = (m[mask3, 1, 0] - m[mask3, 0, 1]) / s3.clamp(min=1e-10)
169
- q[mask3, 1] = (m[mask3, 0, 2] + m[mask3, 2, 0]) / s3.clamp(min=1e-10)
170
- q[mask3, 2] = (m[mask3, 1, 2] + m[mask3, 2, 1]) / s3.clamp(min=1e-10)
171
- q[mask3, 3] = 0.25 * s3
172
-
173
- return q.reshape(batch_dim + (4,))
 
1
+ """pytorch3d.transforms stub with catch-all."""
2
  import torch
3
+ import warnings
4
 
 
 
 
 
 
 
 
 
 
5
 
6
+ def __getattr__(name):
7
+ if name.startswith("_"):
8
+ raise AttributeError(name)
9
+ warnings.warn(f"pytorch3d.transforms stub: {name} not implemented", stacklevel=2)
10
+ def _dummy(*a, **kw): return None
11
+ _dummy.__name__ = name
12
+ return _dummy
13
 
14
+
15
+ def quaternion_to_matrix(quaternions):
16
+ r, i, j, k = quaternions[..., 0], quaternions[..., 1], quaternions[..., 2], quaternions[..., 3]
17
  two_s = 2.0 / (quaternions * quaternions).sum(-1)
18
+ o = torch.stack([
19
  1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
20
  two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
21
  two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
22
+ ], -1)
23
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
24
+
25
+
26
+ def matrix_to_quaternion(matrix):
27
+ if matrix.shape[-2:] != (3, 3):
28
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}")
29
+ batch_dim = matrix.shape[:-2]
30
+ m00, m01, m02 = matrix[..., 0, 0], matrix[..., 0, 1], matrix[..., 0, 2]
31
+ m10, m11, m12 = matrix[..., 1, 0], matrix[..., 1, 1], matrix[..., 1, 2]
32
+ m20, m21, m22 = matrix[..., 2, 0], matrix[..., 2, 1], matrix[..., 2, 2]
33
+ trace = m00 + m11 + m22
34
+ qw = torch.sqrt(torch.clamp(trace + 1, min=1e-8)) / 2
35
+ qx = (m21 - m12) / (4 * qw + 1e-8)
36
+ qy = (m02 - m20) / (4 * qw + 1e-8)
37
+ qz = (m10 - m01) / (4 * qw + 1e-8)
38
+ return torch.stack([qw, qx, qy, qz], dim=-1)
39
+
40
 
41
+ def quaternion_multiply(a, b):
42
+ aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
43
+ bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
44
+ ow = aw*bw - ax*bx - ay*by - az*bz
45
+ ox = aw*bx + ax*bw + ay*bz - az*by
46
+ oy = aw*by - ax*bz + ay*bw + az*bx
47
+ oz = aw*bz + ax*by - ay*bx + az*bw
48
+ return torch.stack([ow, ox, oy, oz], dim=-1)
49
+
50
+
51
+ def quaternion_invert(quaternion):
52
+ scaling = torch.tensor([1, -1, -1, -1], dtype=quaternion.dtype, device=quaternion.device)
53
+ return quaternion * scaling
54
 
55
 
56
  class Transform3d:
57
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
58
+ self.device = device
59
+ self.dtype = dtype
60
  if matrix is not None:
61
+ self._matrix = matrix.to(device=device, dtype=dtype)
62
  else:
63
  self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
64
+ def get_matrix(self):
 
65
  return self._matrix
66
+ def compose(self, *others):
67
+ m = self._matrix
68
+ for o in others:
69
+ m = m @ o.get_matrix()
70
+ return Transform3d(matrix=m, device=self.device, dtype=self.dtype)
71
+ def transform_points(self, points):
72
+ if points.dim() == 2:
73
+ points = points.unsqueeze(0)
74
+ ones = torch.ones(*points.shape[:-1], 1, dtype=points.dtype, device=points.device)
75
+ pts4 = torch.cat([points, ones], dim=-1)
76
+ out = torch.bmm(pts4, self._matrix.expand(pts4.shape[0], -1, -1).transpose(-2, -1))
77
+ return out[..., :3]
78
+ def translate(self, x, y=None, z=None):
79
+ if isinstance(x, torch.Tensor) and x.dim() >= 1 and x.shape[-1] == 3:
80
+ t = x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  else:
82
+ t = torch.tensor([[x, y, z]], dtype=self.dtype, device=self.device)
83
+ if t.dim() == 1: t = t.unsqueeze(0)
84
+ T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
85
+ T[:, :3, 3] = t
86
+ new_m = self._matrix @ T
87
+ return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
88
+ def scale(self, x, y=None, z=None):
89
+ if isinstance(x, torch.Tensor):
90
+ if x.dim() == 0:
91
+ s = x.expand(3)
92
+ elif x.shape[-1] == 3:
93
+ s = x.squeeze()
94
+ else:
95
+ s = x.expand(3)
96
  else:
97
+ if y is None: y = x
98
+ if z is None: z = x
99
+ s = torch.tensor([x, y, z], dtype=self.dtype, device=self.device)
100
+ if s.dim() == 1: s = s.unsqueeze(0)
101
+ S = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
102
+ S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
103
+ new_m = self._matrix @ S
104
+ return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
105
+ def rotate(self, R):
106
+ if R.dim() == 2: R = R.unsqueeze(0)
107
+ T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
108
+ T[:, :3, :3] = R
109
+ new_m = self._matrix @ T
110
+ return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)