| |
| |
| |
| |
| |
|
|
| |
|
|
| from os.path import dirname, join, realpath |
| from typing import Optional, Tuple |
|
|
| import torch |
| from pytorch3d.implicitron.tools.config import registry, run_auto_creation |
| from pytorch3d.io import IO |
| from pytorch3d.renderer import ( |
| AmbientLights, |
| BlendParams, |
| CamerasBase, |
| FoVPerspectiveCameras, |
| HardPhongShader, |
| look_at_view_transform, |
| MeshRasterizer, |
| MeshRendererWithFragments, |
| PointLights, |
| RasterizationSettings, |
| ) |
| from pytorch3d.structures.meshes import Meshes |
|
|
| from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory |
| from .single_sequence_dataset import SingleSceneDataset |
| from .utils import DATASET_TYPE_KNOWN |
|
|
|
|
@registry.register
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):
    """
    A simple single-scene dataset based on PyTorch3D renders of a mesh.
    Provides `num_views` renders of the mesh as train, with no val
    and test. The renders are generated from viewpoints sampled at uniformly
    distributed azimuth intervals. The elevation is kept constant so that the
    camera's vertical position coincides with the equator.

    By default, uses Keenan Crane's cow model, and the camera locations are
    set to make sense for that.

    Although the rendering used to generate this dataset will use a GPU
    if one is available, the data it produces is on the CPU just like
    the data returned by implicitron's other dataset map providers.
    This is because both datasets and models can be large, so implicitron's
    training loop expects data on the CPU and only moves
    what it needs to the device.

    For a more detailed explanation of this code, please refer to the
    docs/tutorials/fit_textured_mesh.ipynb notebook.

    Members:
        num_views: The number of generated renders.
        data_file: The folder that contains the mesh file. By default, finds
            the cow mesh in the same repo as this code.
        azimuth_range: number of degrees on each side of the start position to
            take samples
        distance: distance from camera centres to the origin.
        resolution: the common height and width of the output images.
        use_point_light: whether to use a particular point light as opposed
            to ambient white.
        gpu_idx: which gpu to use for rendering the mesh.
        path_manager_factory: (Optional) An object that generates an instance of
            PathManager that can translate provided file paths.
        path_manager_factory_class_type: The class type of `path_manager_factory`.
    """

    # NOTE: these annotated fields are read by implicitron's configuration
    # system (registry / run_auto_creation); keep names and defaults stable.
    num_views: int = 40
    data_file: Optional[str] = None
    azimuth_range: float = 180
    distance: float = 2.7
    resolution: int = 128
    use_point_light: bool = True
    gpu_idx: Optional[int] = 0
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"

    def get_dataset_map(self) -> DatasetMap:
        # All renders go to the train split; val and test are intentionally
        # absent for this single-scene dataset.
        return DatasetMap(train=self.train_dataset, val=None, test=None)

    def get_all_train_cameras(self) -> CamerasBase:
        # Cameras were moved to the CPU in __post_init__, matching the
        # convention of the other dataset map providers.
        return self.poses

    def __post_init__(self) -> None:
        super().__init__()
        run_auto_creation(self)
        # Render on a GPU when one is available and gpu_idx is set; the
        # produced data is moved back to the CPU below regardless.
        if torch.cuda.is_available() and self.gpu_idx is not None:
            device = torch.device(f"cuda:{self.gpu_idx}")
        else:
            device = torch.device("cpu")
        # Default to the cow mesh shipped with the repo's tutorial data
        # (four dirname() calls walk up from this file to the repo root).
        if self.data_file is None:
            data_file = join(
                dirname(dirname(dirname(dirname(realpath(__file__))))),
                "docs",
                "tutorials",
                "data",
                "cow_mesh",
                "cow.obj",
            )
        else:
            data_file = self.data_file
        io = IO(path_manager=self.path_manager_factory.get())
        mesh = io.load_mesh(data_file, device=device)
        poses, images, masks = _generate_cow_renders(
            num_views=self.num_views,
            mesh=mesh,
            azimuth_range=self.azimuth_range,
            distance=self.distance,
            resolution=self.resolution,
            device=device,
            use_point_light=self.use_point_light,
        )
        # Everything the dataset exposes lives on the CPU; the training loop
        # moves only what it needs to the device.
        self.poses = poses.cpu()
        self.train_dataset = SingleSceneDataset(
            object_name="cow",
            # (H, W, C) renders -> (C, H, W) images, one tensor per frame.
            images=list(images.permute(0, 3, 1, 2).cpu()),
            # Add a channel dimension to each (H, W) silhouette.
            fg_probabilities=list(masks[:, None].cpu()),
            poses=[self.poses[i] for i in range(len(poses))],
            frame_types=[DATASET_TYPE_KNOWN] * len(poses),
            eval_batches=None,
        )
|
|
|
|
@torch.no_grad()
def _generate_cow_renders(
    *,
    num_views: int,
    mesh: Meshes,
    azimuth_range: float,
    distance: float,
    resolution: int,
    device: torch.device,
    use_point_light: bool,
) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
    """
    Render `mesh` from `num_views` viewpoints spread over an azimuth arc
    at zero elevation.

    Note: the input mesh is normalized in place (centered at the origin
    and scaled so its largest coordinate extent is 1).

    Returns:
        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
            images are rendered.
        images: A tensor of shape `(num_views, height, width, 3)` containing
            the rendered images.
        silhouettes: A tensor of shape `(num_views, height, width)` containing
            the rendered silhouettes.
    """
    # Normalize the mesh in place: move the centroid to the origin and
    # shrink the largest coordinate extent down to 1.
    vertices = mesh.verts_packed()
    num_verts = vertices.shape[0]
    centroid = vertices.mean(0)
    extent = (vertices - centroid).abs().max(0)[0].max()
    mesh.offset_verts_(-(centroid.expand(num_verts, 3)))
    mesh.scale_verts_(1.0 / float(extent))

    # All cameras sit on the equator; azimuths sweep symmetrically about
    # the 180-degree mark.
    elevations = torch.zeros(num_views)
    azimuths = 180.0 + torch.linspace(-azimuth_range, azimuth_range, num_views)

    # Either ambient white light or a single point light at (0, 0, -3).
    if not use_point_light:
        lights = AmbientLights(device=device)
    else:
        lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

    # One camera per viewpoint, all looking at the origin.
    rotations, translations = look_at_view_transform(
        dist=distance, elev=elevations, azim=azimuths
    )
    cameras = FoVPerspectiveCameras(device=device, R=rotations, T=translations)

    # Crisp rasterization: no blur, a single face per pixel.
    raster_settings = RasterizationSettings(
        image_size=resolution, blur_radius=0.0, faces_per_pixel=1
    )

    # Hard Phong shading over a black background.
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=HardPhongShader(
            device=device, cameras=cameras, lights=lights, blend_params=blend_params
        ),
    )

    # Replicate the mesh so every view is rendered in one batched pass.
    batched_meshes = mesh.extend(num_views)
    images, fragments = renderer(batched_meshes, cameras=cameras, lights=lights)

    # A pixel is inside the silhouette iff a face was rasterized there
    # (pix_to_face is -1 for background pixels).
    silhouettes = (fragments.pix_to_face[..., 0] >= 0).float()

    # Drop the alpha channel from the RGBA renders.
    return cameras, images[..., :3], silhouettes
|
|