jboth commited on
Commit
e342598
·
verified ·
1 Parent(s): 84a3df6

Upload pytorch3d_stub/pytorch3d/renderer/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/renderer/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
- """pytorch3d.renderer stub – look_at_view_transform only."""
2
  import torch
3
  import math
4
 
 
5
  def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, degrees=True, eye=None,
6
  at=((0, 0, 0),), up=((0, 1, 0),), device="cpu"):
7
- """Compute R, T for look-at camera transform (pytorch3d convention)."""
8
  if eye is not None:
9
  if not isinstance(eye, torch.Tensor):
10
  eye = torch.tensor(eye, dtype=torch.float32, device=device)
@@ -12,11 +12,10 @@ def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, degrees=True, eye=None,
12
  eye = eye.unsqueeze(0)
13
  else:
14
  if degrees:
15
- elev_r = math.radians(elev) if isinstance(elev, (int, float)) else torch.deg2rad(torch.tensor(elev, dtype=torch.float32))
16
- azim_r = math.radians(azim) if isinstance(azim, (int, float)) else torch.deg2rad(torch.tensor(azim, dtype=torch.float32))
17
  else:
18
- elev_r = elev
19
- azim_r = azim
20
  if isinstance(elev_r, (int, float)):
21
  x = dist * math.cos(elev_r) * math.sin(azim_r)
22
  y = dist * math.sin(elev_r)
@@ -27,7 +26,6 @@ def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, degrees=True, eye=None,
27
  y = dist * torch.sin(elev_r)
28
  z = dist * torch.cos(elev_r) * torch.cos(azim_r)
29
  eye = torch.stack([x, y, z], dim=-1).unsqueeze(0).to(device)
30
-
31
  if not isinstance(at, torch.Tensor):
32
  at = torch.tensor(at, dtype=torch.float32, device=device)
33
  if at.dim() == 1:
@@ -36,13 +34,74 @@ def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, degrees=True, eye=None,
36
  up = torch.tensor(up, dtype=torch.float32, device=device)
37
  if up.dim() == 1:
38
  up = up.unsqueeze(0)
39
-
40
  z_axis = eye - at
41
  z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True).clamp(min=1e-8)
42
  x_axis = torch.cross(up.expand_as(z_axis), z_axis, dim=-1)
43
  x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True).clamp(min=1e-8)
44
  y_axis = torch.cross(z_axis, x_axis, dim=-1)
45
-
46
- R = torch.stack([x_axis, y_axis, z_axis], dim=-1) # (N, 3, 3)
47
- T = -torch.bmm(R.transpose(1, 2), eye.unsqueeze(-1)).squeeze(-1) # (N, 3)
48
  return R, T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """pytorch3d.renderer stub – minimal classes for import compatibility."""
2
  import torch
3
  import math
4
 
5
+
6
  def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, degrees=True, eye=None,
7
  at=((0, 0, 0),), up=((0, 1, 0),), device="cpu"):
 
8
  if eye is not None:
9
  if not isinstance(eye, torch.Tensor):
10
  eye = torch.tensor(eye, dtype=torch.float32, device=device)
 
12
  eye = eye.unsqueeze(0)
13
  else:
14
  if degrees:
15
+ elev_r = math.radians(elev) if isinstance(elev, (int, float)) else torch.deg2rad(torch.tensor(float(elev)))
16
+ azim_r = math.radians(azim) if isinstance(azim, (int, float)) else torch.deg2rad(torch.tensor(float(azim)))
17
  else:
18
+ elev_r, azim_r = elev, azim
 
19
  if isinstance(elev_r, (int, float)):
20
  x = dist * math.cos(elev_r) * math.sin(azim_r)
21
  y = dist * math.sin(elev_r)
 
26
  y = dist * torch.sin(elev_r)
27
  z = dist * torch.cos(elev_r) * torch.cos(azim_r)
28
  eye = torch.stack([x, y, z], dim=-1).unsqueeze(0).to(device)
 
29
  if not isinstance(at, torch.Tensor):
30
  at = torch.tensor(at, dtype=torch.float32, device=device)
31
  if at.dim() == 1:
 
34
  up = torch.tensor(up, dtype=torch.float32, device=device)
35
  if up.dim() == 1:
36
  up = up.unsqueeze(0)
 
37
  z_axis = eye - at
38
  z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True).clamp(min=1e-8)
39
  x_axis = torch.cross(up.expand_as(z_axis), z_axis, dim=-1)
40
  x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True).clamp(min=1e-8)
41
  y_axis = torch.cross(z_axis, x_axis, dim=-1)
42
+ R = torch.stack([x_axis, y_axis, z_axis], dim=-1)
43
+ T = -torch.bmm(R.transpose(1, 2), eye.unsqueeze(-1)).squeeze(-1)
 
44
  return R, T
45
+
46
+
47
+ class PerspectiveCameras:
48
+ def __init__(self, focal_length=None, principal_point=None, R=None, T=None,
49
+ image_size=None, device="cpu", in_ndc=True, **kwargs):
50
+ self.device = device
51
+ self.focal_length = focal_length
52
+ self.principal_point = principal_point
53
+ self.R = R if R is not None else torch.eye(3, device=device).unsqueeze(0)
54
+ self.T = T if T is not None else torch.zeros(1, 3, device=device)
55
+ self.image_size = image_size
56
+ self.in_ndc = in_ndc
57
+
58
+ def to(self, device):
59
+ self.device = device
60
+ return self
61
+
62
+
63
+ class RasterizationSettings:
64
+ def __init__(self, image_size=256, blur_radius=0.0, faces_per_pixel=1, **kwargs):
65
+ self.image_size = image_size
66
+ self.blur_radius = blur_radius
67
+ self.faces_per_pixel = faces_per_pixel
68
+
69
+
70
+ class BlendParams:
71
+ def __init__(self, sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0)):
72
+ self.sigma = sigma
73
+ self.gamma = gamma
74
+ self.background_color = background_color
75
+
76
+
77
+ class SoftSilhouetteShader:
78
+ def __init__(self, blend_params=None):
79
+ self.blend_params = blend_params or BlendParams()
80
+
81
+
82
+ class MeshRasterizer(torch.nn.Module):
83
+ def __init__(self, cameras=None, raster_settings=None):
84
+ super().__init__()
85
+ self.cameras = cameras
86
+ self.raster_settings = raster_settings or RasterizationSettings()
87
+
88
+ def forward(self, meshes, **kwargs):
89
+ raise NotImplementedError("pytorch3d.renderer stub: MeshRasterizer.forward not implemented")
90
+
91
+
92
+ class MeshRenderer(torch.nn.Module):
93
+ def __init__(self, rasterizer=None, shader=None):
94
+ super().__init__()
95
+ self.rasterizer = rasterizer
96
+ self.shader = shader
97
+
98
+ def forward(self, meshes, **kwargs):
99
+ raise NotImplementedError("pytorch3d.renderer stub: MeshRenderer.forward not implemented")
100
+
101
+
102
+ class TexturesVertex:
103
+ def __init__(self, verts_features=None):
104
+ self.verts_features_list = verts_features if isinstance(verts_features, list) else [verts_features]
105
+
106
+ def to(self, device):
107
+ return self