# Author: jboth
# Uploaded pytorch3d_stub/pytorch3d/structures/__init__.py with huggingface_hub
# Commit: 8fc4142 (verified)
"""pytorch3d.structures stub with catch-all."""
import torch
import warnings
def __getattr__(name):
    """Return a placeholder class for any name this stub module does not define.

    This is the module-level ``__getattr__`` hook, so it only runs when normal
    attribute lookup on the module fails.

    Args:
        name: Attribute name being looked up on the module.

    Returns:
        A no-op class named ``name``. Its constructor accepts and ignores any
        arguments, and its ``to()`` returns the instance itself. The class is
        cached in the module namespace, so repeated lookups of the same name
        return the same class object.

    Raises:
        AttributeError: For dunder names, so dunder probes on the module see a
            normal "missing attribute" result instead of a dummy class.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(name)
    warnings.warn(f"pytorch3d.structures stub: {name} not implemented", stacklevel=2)

    class _Dummy:
        def __init__(self, *a, **kw):
            pass

        def to(self, *a, **kw):
            return self

    _Dummy.__name__ = _Dummy.__qualname__ = name
    # Cache the class in the module globals. Later lookups then find it
    # directly and skip this hook, so every import of ``name`` sees the same
    # class (isinstance/identity checks work) and the warning fires only once.
    globals()[name] = _Dummy
    return _Dummy
class Meshes:
    """Minimal stand-in for ``pytorch3d.structures.Meshes``.

    Stores each mesh's vertices and faces as separate tensors. It exposes a
    small subset of the pytorch3d API: list/packed accessors, per-mesh counts,
    ``to`` and ``len``.

    Args:
        verts: Accepted forms:
            - a list or tuple of per-mesh vertex arrays;
            - a 2-D tensor, treated as a single mesh;
            - a 3-D tensor whose first dimension indexes meshes (split into
              one tensor per mesh).
            Anything else yields no meshes. Non-tensor list/tuple entries are
            converted to ``float32`` tensors.
        faces: Same forms as ``verts``; non-tensor entries become ``long``
            tensors.
        textures: Stored unchanged on ``self.textures``.
    """

    def __init__(self, verts=None, faces=None, textures=None):
        self._verts_list = self._as_tensor_list(verts, torch.float32)
        self._faces_list = self._as_tensor_list(faces, torch.long)
        self.textures = textures

    @staticmethod
    def _as_tensor_list(value, dtype):
        """Normalize a list/tuple/tensor constructor argument into a list of tensors."""
        if isinstance(value, (list, tuple)):
            return [
                item if isinstance(item, torch.Tensor) else torch.tensor(item, dtype=dtype)
                for item in value
            ]
        if isinstance(value, torch.Tensor):
            # A 3-D tensor is a batch whose leading dim indexes meshes. Split it
            # so packing concatenates meshes instead of keeping a batch axis.
            return list(value.unbind(0)) if value.dim() == 3 else [value]
        return []

    def verts_list(self):
        """Return the list of per-mesh vertex tensors."""
        return self._verts_list

    def faces_list(self):
        """Return the list of per-mesh face tensors."""
        return self._faces_list

    def verts_packed(self):
        """Concatenate all vertex tensors along dim 0; (0, 3) float tensor if empty."""
        return torch.cat(self._verts_list, dim=0) if self._verts_list else torch.zeros(0, 3)

    def faces_packed(self):
        """Concatenate all face tensors along dim 0; (0, 3) long tensor if empty.

        Face indices are not offset per mesh.
        """
        return (
            torch.cat(self._faces_list, dim=0)
            if self._faces_list
            else torch.zeros(0, 3, dtype=torch.long)
        )

    def num_verts_per_mesh(self):
        """Return an int64 tensor with the vertex count of each mesh."""
        # Explicit dtype so an empty Meshes yields int64, not float32.
        return torch.tensor([v.shape[0] for v in self._verts_list], dtype=torch.long)

    def num_faces_per_mesh(self):
        """Return an int64 tensor with the face count of each mesh."""
        return torch.tensor([f.shape[0] for f in self._faces_list], dtype=torch.long)

    def to(self, device):
        """Move all vertex/face tensors to ``device`` in place and return self."""
        self._verts_list = [v.to(device) for v in self._verts_list]
        self._faces_list = [f.to(device) for f in self._faces_list]
        return self

    def __len__(self):
        return len(self._verts_list)
class Pointclouds:
    """Minimal stand-in for ``pytorch3d.structures.Pointclouds``.

    In this stub, ``points_list``, ``features_list`` and ``normals_list`` are
    plain list attributes, not methods.

    NOTE(review): the positional parameter order here is (points, features,
    normals). pytorch3d's is believed to be (points, normals, features), so
    callers should pass ``features``/``normals`` by keyword. Confirm against
    the pytorch3d docs.

    Args:
        points: Accepted forms:
            - a list, stored as-is;
            - a tuple, converted to a list;
            - a 3-D tensor whose first dimension indexes clouds (split into
              one tensor per cloud);
            - ``None``, meaning no clouds;
            - any other single value, wrapped as one cloud.
        features: Same forms as ``points``.
        normals: Same forms as ``points``.
    """

    def __init__(self, points=None, features=None, normals=None):
        self.points_list = self._as_list(points)
        self.features_list = self._as_list(features)
        self.normals_list = self._as_list(normals)

    @staticmethod
    def _as_list(value):
        """Normalize a constructor argument into a per-cloud list."""
        if value is None:
            return []
        if isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        if isinstance(value, torch.Tensor) and value.dim() == 3:
            # Batched tensor: leading dim indexes clouds, so split it.
            return list(value.unbind(0))
        return [value]

    def to(self, device):
        """No-op device move; returns self without moving any tensors."""
        return self

    def __len__(self):
        return len(self.points_list)