| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| |
| |
|
|
| import warnings |
| from logging import Logger |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import torch |
| import tqdm |
| from pytorch3d.common.compat import prod |
|
|
| from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle |
|
|
| from pytorch3d.implicitron.tools import image_utils |
|
|
| from pytorch3d.implicitron.tools.utils import cat_dataclass |
|
|
|
|
def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Helper function to preprocess the input images and optional depth maps
    to apply masking if required.

    Args:
        image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
            corresponding to the source viewpoints from which features will be extracted
        fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
            of foreground masks with values in [0, 1].
        depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/Depth images
        bg_color: RGB values for setting the background color of input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.

    Returns:
        Modified image_rgb, fg_mask, depth_map

    Raises:
        ValueError: If `image_rgb` is unbatched (3-dimensional), or if
            `mask_depths` is requested while `mask_threshold` is not positive.
    """
    if image_rgb is not None and image_rgb.ndim == 3:
        # The model operates on batched `(B, C, H, W)` tensors only.
        raise ValueError(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated."
        )

    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # Binarize the soft foreground probability into a hard {0, 1} mask.
        warnings.warn("Thresholding masks!")
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)

    if mask_images and fg_mask is not None and image_rgb is not None:
        # Replace background pixels (mask == 0) with the requested bg_color.
        warnings.warn("Masking images!")
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )

    if mask_depths and fg_mask is not None and depth_map is not None:
        # Multiplying depths by a *soft* mask would corrupt valid depth
        # values, so require that the mask was thresholded above.
        # NOTE: converted from `assert` so the check survives `python -O`.
        if mask_threshold <= 0.0:
            raise ValueError(
                "Depths should be masked only with thresholded masks"
            )
        warnings.warn("Masking depths!")
        depth_map = depth_map * fg_mask

    return image_rgb, fg_mask, depth_map
|
|
|
|
def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None:
    """
    Print a table of the loss weights.

    Args:
        loss_weights: A dict mapping each loss name to its scalar weight.
        logger: Logger the table is emitted to (at INFO level).
    """
    # Fix: the closing separator used to be glued onto the last table row
    # ("...: 1.00e+00-------") because "\n".join does not terminate the
    # final line; prepend a newline to the trailing dashes.
    loss_weights_message = (
        "-------\nloss_weights:\n"
        + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
        + "\n-------"
    )
    logger.info(loss_weights_message)
|
|
|
|
def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    """
    Compute the overall loss as the weighted sum of the individual losses.

    Losses that are missing from `preds` or carry a zero weight are
    skipped. If nothing contributes to the objective, a warning is issued
    and `None` is returned.
    """
    total = None
    for name, weight in loss_weights.items():
        if name not in preds or weight == 0.0:
            continue
        term = preds[name] * float(weight)
        total = term if total is None else total + term
    if total is None:
        warnings.warn("No main objective found.")
        return None
    assert torch.is_tensor(total)
    return total
|
|
|
|
def apply_chunked(func, chunk_generator, tensor_collator):
    """
    Run `func` on every (args, kwargs) chunk yielded by `chunk_generator`
    and merge the per-chunk outputs into a single result with
    `tensor_collator`.
    """
    processed_chunks = []
    for chunk_args, chunk_kwargs in chunk_generator:
        processed_chunks.append(func(*chunk_args, **chunk_kwargs))
    return cat_dataclass(processed_chunks, tensor_collator)
|
|
|
|
def chunk_generator(
    chunk_size: int,
    ray_bundle: ImplicitronRayBundle,
    chunked_inputs: Dict[str, torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.

    Args:
        chunk_size: Maximum number of ray-points (rays x points-per-ray)
            per chunk; must be divisible by the number of points per ray.
        ray_bundle: The bundle of rays to split. Its spatial dimensions
            are flattened into a single ray dimension before slicing.
        chunked_inputs: Dict of tensors whose trailing (spatial) dims are
            flattened and sliced alongside the rays.
        tqdm_trigger_threshold: Show a tqdm progress bar when the number
            of chunks reaches this threshold.
        *args: Extra positional arguments passed through with each chunk.
        **kwargs: Extra keyword arguments copied into each chunk's kwargs.

    Yields:
        Tuples `([ray_bundle_chunk, *args], extra_args)` where
        `ray_bundle_chunk` holds one slice of rays and `extra_args`
        contains `kwargs` plus the matching slices of `chunked_inputs`.

    Raises:
        ValueError: If `chunk_size` is not divisible by `n_pts_per_ray`.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = prod(spatial_dim)
    # Ceil-divisions (the -(-a // b) idiom): first the number of chunks
    # needed to fit all ray-points, then the number of rays per chunk.
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    # Renamed from `iter` to avoid shadowing the builtin of the same name.
    chunk_starts = range(0, n_rays, chunk_size_in_rays)
    if len(chunk_starts) >= tqdm_trigger_threshold:
        chunk_starts = tqdm.tqdm(chunk_starts)

    def _safe_slice(
        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
    ) -> Any:
        # Slice along the first dimension, passing None through unchanged.
        return tensor[start_idx:end_idx] if tensor is not None else None

    for start_idx in chunk_starts:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        # bins has one more entry per ray than lengths (bin edges).
        bins = (
            None
            if ray_bundle.bins is None
            else ray_bundle.bins.reshape(batch_size, n_rays, n_pts_per_ray + 1)[
                :, start_idx:end_idx
            ]
        )
        pixel_radii_2d = (
            None
            if ray_bundle.pixel_radii_2d is None
            else ray_bundle.pixel_radii_2d.reshape(batch_size, -1, 1)[
                :, start_idx:end_idx
            ]
        )
        # Flatten the spatial dims of every field and take the same ray
        # slice, so the chunk stays internally consistent.
        ray_bundle_chunk = ImplicitronRayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
                :, start_idx:end_idx
            ],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
            bins=bins,
            pixel_radii_2d=pixel_radii_2d,
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
        )
        extra_args = kwargs.copy()
        for k, v in chunked_inputs.items():
            # flatten(2) keeps (B, C) and merges the spatial dims, so the
            # last dim lines up with the flattened ray dimension.
            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
        yield [ray_bundle_chunk, *args], extra_args
|
|