jboth commited on
Commit
8707d60
·
verified ·
1 Parent(s): ccebd17

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -53,6 +53,31 @@ def quaternion_invert(quaternion):
53
  return quaternion * scaling
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class Transform3d:
57
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
58
  if matrix is not None:
@@ -164,3 +189,56 @@ class Transform3d:
164
  if m.dim() == 2:
165
  m = m.unsqueeze(0)
166
  return Transform3d(matrix=m, device=str(self._matrix.device), dtype=self._matrix.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return quaternion * scaling
54
 
55
 
56
+ def axis_angle_to_quaternion(axis_angle):
57
+ """Convert rotations given as axis/angle to quaternions.
58
+ Args:
59
+ axis_angle: Rotations given as a vector in axis angle form,
60
+ as a tensor of shape (..., 3), where the magnitude is
61
+ the angle turned anticlockwise in radians around the
62
+ vector's direction.
63
+ Returns:
64
+ quaternions with real parts first, as tensor of shape (..., 4).
65
+ """
66
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
67
+ half_angles = angles * 0.5
68
+ eps = 1e-6
69
+ small_angles = angles.abs() < eps
70
+ sin_half = torch.where(small_angles, 0.5 * torch.ones_like(angles), torch.sin(half_angles) / angles)
71
+ cos_half = torch.where(small_angles, torch.ones_like(angles), torch.cos(half_angles))
72
+ quaternions = torch.cat([cos_half, axis_angle * sin_half], dim=-1)
73
+ return quaternions
74
+
75
+
76
+ def axis_angle_to_matrix(axis_angle):
77
+ """Convert rotations given as axis/angle to rotation matrices."""
78
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
79
+
80
+
81
  class Transform3d:
82
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
83
  if matrix is not None:
 
189
  if m.dim() == 2:
190
  m = m.unsqueeze(0)
191
  return Transform3d(matrix=m, device=str(self._matrix.device), dtype=self._matrix.dtype)
192
+
193
+
194
+ class Rotate(Transform3d):
195
+ """Transform3d initialized with a rotation matrix."""
196
+ def __init__(self, R=None, dtype=torch.float32, device="cpu"):
197
+ super().__init__(dtype=dtype, device=device)
198
+ if R is not None:
199
+ if isinstance(R, torch.Tensor):
200
+ R = R.to(device=device, dtype=dtype)
201
+ if R.dim() == 2:
202
+ R = R.unsqueeze(0)
203
+ T = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
204
+ T[:, :3, :3] = R
205
+ self._matrix = T
206
+ self.device = self._matrix.device
207
+
208
+
209
+ class Translate(Transform3d):
210
+ """Transform3d initialized with a translation."""
211
+ def __init__(self, x=0.0, y=0.0, z=0.0, dtype=torch.float32, device="cpu"):
212
+ super().__init__(dtype=dtype, device=device)
213
+ if isinstance(x, torch.Tensor):
214
+ t = x.to(device=device, dtype=dtype)
215
+ if t.dim() == 1:
216
+ t = t.unsqueeze(0)
217
+ else:
218
+ t = torch.tensor([[x, y, z]], dtype=dtype, device=device)
219
+ T = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
220
+ T[:, :3, 3] = t
221
+ self._matrix = T
222
+ self.device = self._matrix.device
223
+
224
+
225
+ class Scale(Transform3d):
226
+ """Transform3d initialized with a scale."""
227
+ def __init__(self, x=1.0, y=None, z=None, dtype=torch.float32, device="cpu"):
228
+ super().__init__(dtype=dtype, device=device)
229
+ if isinstance(x, torch.Tensor):
230
+ s = x.to(device=device, dtype=dtype)
231
+ if s.dim() == 0:
232
+ s = s.expand(3).unsqueeze(0)
233
+ elif s.dim() == 1 and s.shape[0] == 3:
234
+ s = s.unsqueeze(0)
235
+ elif s.dim() == 1 and s.shape[0] == 1:
236
+ s = s.expand(3).unsqueeze(0)
237
+ else:
238
+ if y is None: y = x
239
+ if z is None: z = x
240
+ s = torch.tensor([[x, y, z]], dtype=dtype, device=device)
241
+ S = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
242
+ S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
243
+ self._matrix = S
244
+ self.device = self._matrix.device