# Source: PEAR/pytorch3d/implicitron/dataset/visualize.py
# (uploaded by BestWJH in "Upload 455 files", commit 94dc344)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import cast, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds
from .frame_data import FrameData
from .json_index_dataset import JsonIndexDataset
def get_implicitron_sequence_pointcloud(
    dataset: JsonIndexDataset,
    sequence_name: Optional[str] = None,
    mask_points: bool = True,
    max_frames: int = -1,
    num_workers: int = 0,
    load_dataset_point_cloud: bool = False,
) -> Tuple[Pointclouds, FrameData]:
    """
    Make a point cloud from the frames of a dataset, optionally restricted to
    a single sequence and to a random subset of its frames.

    Args:
        dataset: The dataset to load frames from. It must have
            `load_depths=True`. It must also have `load_masks=True` when
            `mask_points=True`, and `load_point_clouds=True` when
            `load_dataset_point_cloud=True`.
        sequence_name: If given, only the frames whose frame annotation has
            this `sequence_name` are used.
        mask_points: If True and the loaded frames carry `fg_probability`,
            it is thresholded at 0.5 and passed as the mask to
            `get_rgbd_point_cloud`.
        max_frames: If positive and smaller than the number of selected
            frames, a random subset of `max_frames` frames is used. The subset
            keeps the original dataset order.
        num_workers: Number of worker processes used by the DataLoader.
        load_dataset_point_cloud: If True, return the `sequence_point_cloud`
            loaded by the dataset instead of building one with
            `get_rgbd_point_cloud` from cameras, RGB images and depth maps.

    Returns:
        point_cloud: The resulting point cloud.
        frame_data: The selected frames collated into a single batch.

    Raises:
        ValueError: If the dataset is empty, lacks a loading flag required by
            the chosen options, or contains no frames for `sequence_name`.
    """
    if len(dataset) == 0:
        raise ValueError("The dataset is empty.")
    if not dataset.load_depths:
        raise ValueError("The dataset has to load depths (dataset.load_depths=True).")
    if mask_points and not dataset.load_masks:
        raise ValueError(
            "For mask_points=True, the dataset has to load masks"
            + " (dataset.load_masks=True)."
        )
    # Validate this before loading any frames: the load below collates the
    # whole selection, so a misconfiguration should fail before that work.
    if load_dataset_point_cloud and not dataset.load_point_clouds:
        raise ValueError(
            "For load_dataset_point_cloud=True, the dataset has to"
            + " load point clouds (dataset.load_point_clouds=True)."
        )

    # setup the indices of frames loaded from the dataset db
    sequence_entries = list(range(len(dataset)))
    if sequence_name is not None:
        sequence_entries = [
            ei
            for ei in sequence_entries
            # pyre-ignore[16]
            if dataset.frame_annots[ei]["frame_annotation"].sequence_name
            == sequence_name
        ]
        if len(sequence_entries) == 0:
            raise ValueError(
                f'There are no dataset entries for sequence name "{sequence_name}".'
            )

    # subsample loaded frames if needed; sorting the random picks keeps the
    # selected frames in their original dataset order
    if 0 < max_frames < len(sequence_entries):
        picked = torch.randperm(len(sequence_entries))[:max_frames].sort().values
        sequence_entries = [sequence_entries[i] for i in picked.tolist()]

    # take only the part of the dataset corresponding to the sequence entries
    sequence_dataset = torch.utils.data.Subset(dataset, sequence_entries)

    # load the required part of the dataset as a single collated batch
    loader = torch.utils.data.DataLoader(
        sequence_dataset,
        batch_size=len(sequence_dataset),
        shuffle=False,
        num_workers=num_workers,
        collate_fn=dataset.frame_data_type.collate,
    )
    frame_data = next(iter(loader))  # there's only one batch

    # scene point cloud
    if load_dataset_point_cloud:
        point_cloud = frame_data.sequence_point_cloud
    else:
        # Masking is skipped (mask=None) when the frames carry no
        # fg_probability, even if mask_points=True.
        mask = None
        if mask_points and frame_data.fg_probability is not None:
            mask = (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
        point_cloud = get_rgbd_point_cloud(
            frame_data.camera,
            frame_data.image_rgb,
            frame_data.depth_map,
            mask,
        )

    return point_cloud, frame_data