# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Standard library imports
import base64
import io
import logging
import math
import pickle
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
from hydra.utils import instantiate
from PIL import Image
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.transforms import (
    Transform3d,
    se3_exp_map,
    se3_log_map,
    so3_relative_angle,
)

# Local imports
import models
from util.camera_transform import pose_encoding_to_camera

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class PoseDiffusionModel(nn.Module):
    """Diffusion-based camera pose estimation model.

    Wires together three hydra-instantiated components: an image feature
    extractor, a diffuser, and a denoiser (the denoiser is installed as the
    diffuser's model). At inference time, image features condition the
    diffuser's sampling of a pose encoding, which is then converted into a
    PyTorch3D camera object.
    """

    def __init__(
        self,
        pose_encoding_type: str,
        IMAGE_FEATURE_EXTRACTOR: Dict,
        DIFFUSER: Dict,
        DENOISER: Dict,
    ):
        """Initializes a PoseDiffusion model.

        Args:
            pose_encoding_type (str):
                Defines the encoding type for extrinsics and intrinsics.
                Currently, only `"absT_quaR_logFL"` is supported -
                a concatenation of the translation vector,
                rotation quaternion, and logarithm of focal length.
            IMAGE_FEATURE_EXTRACTOR (Dict):
                Hydra configuration for the image feature extractor.
            DIFFUSER (Dict):
                Hydra configuration for the diffuser.
            DENOISER (Dict):
                Hydra configuration for the denoiser; must expose a
                `target_dim` attribute once instantiated.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type

        # _recursive_=False: instantiate only the top-level object and let
        # each component handle its own nested config.
        self.image_feature_extractor = instantiate(
            IMAGE_FEATURE_EXTRACTOR, _recursive_=False
        )
        self.diffuser = instantiate(DIFFUSER, _recursive_=False)

        denoiser = instantiate(DENOISER, _recursive_=False)
        # The denoiser is the network the diffuser runs at each step.
        self.diffuser.model = denoiser
        # Dimensionality of the pose encoding the diffuser samples.
        self.target_dim = denoiser.target_dim

    def forward(
        self,
        image: torch.Tensor,
        gt_cameras: Optional[CamerasBase] = None,
        sequence_name: Optional[List[str]] = None,
        cond_fn=None,
        cond_start_step=0,
    ):
        """
        Forward pass of the PoseDiffusionModel.

        Args:
            image (torch.Tensor):
                Input image tensor; presumably Bx3xHxW (per-frame images of
                one sequence) — a batch dimension is prepended to the
                extracted features below. TODO(review): confirm shape
                contract against the feature extractor.
            gt_cameras (Optional[CamerasBase], optional):
                Ground-truth camera object. Accepted for interface
                compatibility but UNUSED in this sampling-only forward pass.
                Defaults to None.
            sequence_name (Optional[List[str]], optional):
                List of sequence names. Also UNUSED here. Defaults to None.
            cond_fn ([type], optional):
                Conditional function. Wrapper for GGS or other functions.
            cond_start_step (int, optional):
                The sampling step to start using conditional function.

        Returns:
            PerspectiveCameras: PyTorch3D camera object decoded from the
            sampled pose encoding.
        """
        z = self.image_feature_extractor(image)
        # Add a leading batch dim: features become (1, N, C) so the sampler
        # treats the N frames as one sequence.
        z = z.unsqueeze(0)

        B, N, _ = z.shape
        target_shape = [B, N, self.target_dim]

        # Sample a pose encoding conditioned on the image features; the
        # per-step samples are currently unused.
        pose_encoding, pose_encoding_diffusion_samples = self.diffuser.sample(
            shape=target_shape,
            z=z,
            cond_fn=cond_fn,
            cond_start_step=cond_start_step,
        )

        # Convert the encoded representation to PyTorch3D cameras.
        pred_cameras = pose_encoding_to_camera(
            pose_encoding, pose_encoding_type=self.pose_encoding_type
        )

        return pred_cameras