jboth commited on
Commit
0bec954
·
verified ·
1 Parent(s): d9d5913

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- """Minimal pytorch3d.transforms stub – quaternion operations only."""
2
  import torch
 
3
 
4
  def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
5
- """Hamilton product of two quaternions (w, x, y, z)."""
6
  w1, x1, y1, z1 = q1.unbind(-1)
7
  w2, x2, y2, z2 = q2.unbind(-1)
8
  return torch.stack([
@@ -13,6 +13,118 @@ def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
13
  ], dim=-1)
14
 
15
  def quaternion_invert(q: torch.Tensor) -> torch.Tensor:
16
- """Invert a quaternion (w, x, y, z) -> (w, -x, -y, -z) / ||q||^2."""
17
  scaling = torch.tensor([1, -1, -1, -1], device=q.device, dtype=q.dtype)
18
  return q * scaling / (q * q).sum(dim=-1, keepdim=True).clamp(min=1e-10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """pytorch3d.transforms stub – pure PyTorch implementations."""
2
  import torch
3
+ import math
4
 
5
  def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
 
6
  w1, x1, y1, z1 = q1.unbind(-1)
7
  w2, x2, y2, z2 = q2.unbind(-1)
8
  return torch.stack([
 
13
  ], dim=-1)
14
 
15
  def quaternion_invert(q: torch.Tensor) -> torch.Tensor:
 
16
  scaling = torch.tensor([1, -1, -1, -1], device=q.device, dtype=q.dtype)
17
  return q * scaling / (q * q).sum(dim=-1, keepdim=True).clamp(min=1e-10)
18
+
19
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
20
+ r, i, j, k = quaternions.unbind(-1)
21
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
22
+ return torch.stack([
23
+ 1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
24
+ two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
25
+ two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
26
+ ], dim=-1).reshape(quaternions.shape[:-1] + (3, 3))
27
+
28
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
29
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
30
+ half_angles = 0.5 * angles
31
+ eps = 1e-6
32
+ small = (angles.abs() < eps).squeeze(-1)
33
+ sin_half = torch.sin(half_angles)
34
+ cos_half = torch.cos(half_angles)
35
+ k = torch.where(angles > eps, sin_half / angles, 0.5 * torch.ones_like(angles))
36
+ w = cos_half.squeeze(-1)
37
+ xyz = axis_angle * k
38
+ return torch.cat([w.unsqueeze(-1), xyz], dim=-1)
39
+
40
+
41
+ class Transform3d:
42
+ def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
43
+ if matrix is not None:
44
+ self._matrix = matrix.to(dtype=dtype, device=device)
45
+ else:
46
+ self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
47
+
48
+ def get_matrix(self) -> torch.Tensor:
49
+ return self._matrix
50
+
51
+ def _compose(self, other_matrix: torch.Tensor) -> "Transform3d":
52
+ new_matrix = torch.bmm(self._matrix, other_matrix)
53
+ return Transform3d(matrix=new_matrix)
54
+
55
+ def scale(self, s) -> "Transform3d":
56
+ if isinstance(s, (int, float)):
57
+ s = torch.tensor([[s, s, s]], dtype=self._matrix.dtype, device=self._matrix.device)
58
+ elif isinstance(s, torch.Tensor):
59
+ if s.dim() == 1:
60
+ s = s.unsqueeze(0)
61
+ S = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
62
+ S[:, 0, 0] = s[:, 0]
63
+ S[:, 1, 1] = s[:, 1]
64
+ S[:, 2, 2] = s[:, 2]
65
+ return self._compose(S)
66
+
67
+ def rotate(self, R) -> "Transform3d":
68
+ if isinstance(R, torch.Tensor):
69
+ if R.dim() == 2:
70
+ R = R.unsqueeze(0)
71
+ M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
72
+ M[:, :3, :3] = R
73
+ return self._compose(M)
74
+ return self
75
+
76
+ def translate(self, t) -> "Transform3d":
77
+ if isinstance(t, torch.Tensor):
78
+ if t.dim() == 1:
79
+ t = t.unsqueeze(0)
80
+ M = torch.eye(4, dtype=self._matrix.dtype, device=self._matrix.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
81
+ M[:, 3, :3] = t
82
+ return self._compose(M)
83
+ return self
84
+
85
+ def compose(self, other: "Transform3d") -> "Transform3d":
86
+ return self._compose(other._matrix)
87
+
88
+
89
+ class Rotate(Transform3d):
90
+ def __init__(self, R, dtype=torch.float32, device="cpu"):
91
+ if isinstance(R, torch.Tensor):
92
+ if R.dim() == 2:
93
+ R = R.unsqueeze(0)
94
+ M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
95
+ M[:, :3, :3] = R.to(dtype=dtype, device=device)
96
+ super().__init__(dtype=dtype, device=device, matrix=M)
97
+ else:
98
+ super().__init__(dtype=dtype, device=device)
99
+
100
+ class Translate(Transform3d):
101
+ def __init__(self, *args, dtype=torch.float32, device="cpu"):
102
+ if len(args) == 1 and isinstance(args[0], torch.Tensor):
103
+ t = args[0]
104
+ if t.dim() == 1:
105
+ t = t.unsqueeze(0)
106
+ elif len(args) == 3:
107
+ t = torch.tensor([[args[0], args[1], args[2]]], dtype=dtype, device=device)
108
+ else:
109
+ super().__init__(dtype=dtype, device=device)
110
+ return
111
+ M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
112
+ M[:, 3, :3] = t.to(dtype=dtype, device=device)
113
+ super().__init__(dtype=dtype, device=device, matrix=M)
114
+
115
+ class Scale(Transform3d):
116
+ def __init__(self, s, dtype=torch.float32, device="cpu"):
117
+ if isinstance(s, (int, float)):
118
+ s = torch.tensor([[s, s, s]], dtype=dtype, device=device)
119
+ elif isinstance(s, torch.Tensor):
120
+ if s.dim() == 0:
121
+ s = s.reshape(1, 1).expand(1, 3)
122
+ elif s.dim() == 1 and s.shape[0] == 3:
123
+ s = s.unsqueeze(0)
124
+ elif s.dim() == 1 and s.shape[0] != 3:
125
+ s = s.unsqueeze(-1).expand(-1, 3)
126
+ M = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
127
+ M[:, 0, 0] = s[:, 0]
128
+ M[:, 1, 1] = s[:, 1]
129
+ M[:, 2, 2] = s[:, 2]
130
+ super().__init__(dtype=dtype, device=device, matrix=M)