infinity1096
initial commit
c8b42eb
"""
Adaptors for the UniCeption Prediction Heads.
"""
from functools import lru_cache
from math import isfinite
from typing import List, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from uniception.models.prediction_heads import (
AdaptorInput,
AdaptorOutput,
Covariance2DAdaptorOutput,
MaskAdaptorOutput,
RegressionAdaptorOutput,
RegressionWithConfidenceAdaptorOutput,
RegressionWithConfidenceAndMaskAdaptorOutput,
RegressionWithMaskAdaptorOutput,
UniCeptionAdaptorBase,
)
class FlowAdaptor(UniCeptionAdaptorBase):
    # Upper bound on the number of coordinate grids memoized per adaptor instance.
    _COORDINATE_BIAS_CACHE_SIZE = 10

    def __init__(
        self,
        name: str,
        flow_mean: Union[Tuple[float, float], List[float]],
        flow_std: Union[Tuple[float, float], List[float]],
        base_shape: Tuple[int, int],
        scale_strategy: str,
        output_normalized_coordinate: bool = False,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Flow head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            flow_mean (Tuple[float, float] or List[float]): (2,) Mean of the flow, in (X, Y) order.
            flow_std (Tuple[float, float] or List[float]): (2,) Standard deviation of the flow, in (X, Y) order.
            base_shape (Tuple[int, int]): Base HW shape that the flow mean and std refer to.
            scale_strategy (str): Strategy for scaling the flow, either
                - none: No scaling, network will be unnormalized with the given mean and std for all input shapes
                - scale_width: scale the output for "none" by actual width divided by base width for both X and Y
                - scale_height: scale the output for "none" by actual height divided by base height for both X and Y
                - scale_both: scale the output for "none" by actual dimension / base dimension individually for X and Y
            output_normalized_coordinate (bool): If True, the mean/std un-normalization is skipped. The input is
                instead mapped to pixel coordinates and the (X, Y) coordinate of each output pixel is subtracted
                from it. In other words, the network predicts the pixel position that the source pixel will land
                on in the target image, rather than the flow.
        """
        super().__init__(name, required_channels=2, *args, **kwargs)

        self.name: str = name

        # Materialize as lists first so tuples, lists and other length-2 iterables are all accepted
        flow_mean = list(flow_mean)
        flow_std = list(flow_std)

        flow_mean = torch.tensor(flow_mean, dtype=torch.float32)
        assert flow_mean.shape == (2,), f"Flow mean must be a 2D tensor, got {flow_mean.shape}"

        flow_std = torch.tensor(flow_std, dtype=torch.float32)
        assert flow_std.shape == (2,), f"Flow std must be a 2D tensor, got {flow_std.shape}"

        # Stored as (1, 2, 1, 1) so they broadcast over B x 2 x H x W features
        self.register_buffer("flow_mean", flow_mean.view(1, 2, 1, 1))
        self.register_buffer("flow_std", flow_std.view(1, 2, 1, 1))

        self.base_shape = list(base_shape)
        self.scale_strategy = scale_strategy
        self.output_normalized_coordinate = output_normalized_coordinate

        # Per-instance memo of coordinate grids, keyed by (output shape, device).
        # Kept on the instance (not in a class-level lru_cache) so the cache dies with the adaptor.
        self._coordinate_bias_cache = {}

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the FlowAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.
        """
        x = adaptor_input.adaptor_feature

        # Check the number of channels to avoid passing BHWC features
        _, C, _, _ = x.shape
        assert C == 2, f"FlowAdaptor requires BCHW format with 2 channels, got {C} channels"

        output_shape = adaptor_input.output_shape_hw

        if not self.output_normalized_coordinate:
            x_scale, y_scale = self._get_xy_scale(output_shape)

            # The same (X, Y) scale applies to both mean and std, so build the tensor once
            xy_scale = torch.tensor([x_scale, y_scale], dtype=torch.float32, device=x.device).view(1, 2, 1, 1)
            flow_mean = self.flow_mean * xy_scale
            flow_std = self.flow_std * xy_scale

            # Unnormalize the flow
            x = x * flow_std + flow_mean
        else:
            # Map the prediction to pixel coordinates using the output (W, H) ...
            wh_normalizer = torch.tensor(output_shape[::-1], dtype=torch.float32, device=x.device).view(1, 2, 1, 1)
            x = 0.5 * (x + 1) * wh_normalizer + 0.5

            # ... then subtract each pixel's own coordinate to turn the target position into a flow
            coords = self._get_coordinate_bias(output_shape, x.device)
            x = x - coords

        return RegressionAdaptorOutput(value=x)

    def _get_xy_scale(self, output_shape: Tuple[int, int]):
        """
        Get the scaling factor for the X and Y dimensions.

        Args:
            output_shape (Tuple[int, int]): HW Shape of the output.

        Returns:
            Tuple[float, float]: Scaling factors for X and Y dimensions.

        Raises:
            ValueError: If the configured scale strategy is not one of the supported values.
        """
        if self.scale_strategy == "none":
            return 1.0, 1.0
        elif self.scale_strategy == "scale_width":
            return output_shape[1] / self.base_shape[1], output_shape[1] / self.base_shape[1]
        elif self.scale_strategy == "scale_height":
            return output_shape[0] / self.base_shape[0], output_shape[0] / self.base_shape[0]
        elif self.scale_strategy == "scale_both":
            return output_shape[1] / self.base_shape[1], output_shape[0] / self.base_shape[0]
        else:
            raise ValueError(f"Invalid scaling strategy: {self.scale_strategy}")

    def _get_coordinate_bias(self, output_shape: Tuple[int, int], device: Union[str, torch.device]):
        """
        Get the (X, Y) coordinate image for the given output shape.

        Results are memoized per instance (at most ``_COORDINATE_BIAS_CACHE_SIZE`` entries,
        least recently used evicted first). The returned tensor is shared between calls and
        must not be modified in place.

        Args:
            output_shape (Tuple[int, int]): HW Shape of the output. Lists are accepted as well.
            device: device to store the tensor on

        Returns:
            torch.Tensor: (2, H, W) tensor with X and Y coordinates, at device. This coordinate value will
                          include 0.5 px offset - i.e. the center of the top-left pixel is (0.5, 0.5).
        """
        # setdefault also covers instances that were created without running this __init__
        cache = self.__dict__.setdefault("_coordinate_bias_cache", {})
        key = (tuple(output_shape), str(device))

        coords = cache.pop(key, None)
        if coords is None:
            H, W = output_shape
            coords = torch.stack(
                torch.meshgrid(
                    torch.arange(0, W, device=device, dtype=torch.float32) + 0.5,
                    torch.arange(0, H, device=device, dtype=torch.float32) + 0.5,
                    indexing="xy",
                ),
                dim=0,
            )
            if len(cache) >= self._COORDINATE_BIAS_CACHE_SIZE:
                # dicts preserve insertion order, so the first key is the least recently used
                cache.pop(next(iter(cache)))

        # (Re)insert at the end to mark the entry as most recently used
        cache[key] = coords
        return coords
class ScaleAdaptor(UniCeptionAdaptorBase):
    def __init__(self, name: str, mode: str, vmin: float = 0, vmax: float = np.inf, *args, **kwargs):
        """
        Adaptor for scale prediction in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the scale prediction, either "linear", "square" or "exp". Scales the predicted scaling factor accordingly.
            vmin (float): Minimum value of the scale prediction after scaling.
            vmax (float): Maximum value of the scale prediction after scaling.
        """
        super().__init__(name, required_channels=1, *args, **kwargs)

        self.mode = mode
        self.vmin = vmin
        self.vmax = vmax

        # Clipping is skipped only when both bounds are infinite
        self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the ScaleAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x 1 x ...)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not "linear", "square" or "exp".
        """
        predicted_scale_factor = adaptor_input.adaptor_feature

        if self.mode == "linear":
            output_scale_factor = predicted_scale_factor
        elif self.mode == "square":
            output_scale_factor = predicted_scale_factor.square()
        elif self.mode == "exp":
            output_scale_factor = torch.exp(predicted_scale_factor)
        else:
            # Fail loudly, as the other adaptors do, instead of propagating None
            raise ValueError(f"Invalid mode: {self.mode}")

        if not self.no_bounds:
            output_scale_factor = output_scale_factor.clip(self.vmin, self.vmax)

        return AdaptorOutput(value=output_scale_factor)
class DepthAdaptor(UniCeptionAdaptorBase):
    def __init__(self, name: str, mode: str, vmin: float = 0, vmax: float = np.inf, *args, **kwargs):
        """
        Adaptor for the Depth head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): How the raw prediction is turned into depth: "linear" (as is),
                "square" (squared) or "exp" (exponentiated).
            vmin (float): Lower clipping bound applied after the mode transform.
            vmax (float): Upper clipping bound applied after the mode transform.
        """
        super().__init__(name, required_channels=1, *args, **kwargs)

        self.mode, self.vmin, self.vmax = mode, vmin, vmax
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the DepthAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not supported.
        """
        feature = adaptor_input.adaptor_feature
        mode = self.mode

        if mode == "exp":
            depth = torch.exp(feature)
        elif mode == "square":
            depth = feature**2
        elif mode == "linear":
            depth = feature
        else:
            raise ValueError(f"Invalid mode: {self.mode}")

        if self.no_bounds:
            return RegressionAdaptorOutput(value=depth)
        return RegressionAdaptorOutput(value=depth.clip(self.vmin, self.vmax))
class PointMapAdaptor(UniCeptionAdaptorBase):
    def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs):
        """
        Adaptor for the PointMap head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the point map, one of
                - linear: the prediction is used as is
                - square: the distance of each point to the origin is squared
                - exp: the distance d of each point to the origin is replaced by expm1(d)
                - z_exp: the last channel is exponentiated and the first two channels are multiplied by it
            vmin (float): Minimum value of the point map after scaling.
            vmax (float): Maximum value of the point map after scaling.
        """
        super().__init__(name, required_channels=3, *args, **kwargs)

        self.mode, self.vmin, self.vmax = mode, vmin, vmax
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the PointMapAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not supported.
        """
        points = adaptor_input.adaptor_feature
        mode = self.mode

        if mode == "linear":
            remapped = points
        elif mode in ("square", "exp"):
            # Split each point into a direction and a distance to the origin;
            # the clip guards the division for points at the origin
            dist = points.norm(dim=1, keepdim=True)
            direction = points / dist.clip(min=1e-8)
            # Re-apply the direction with the remapped distance
            new_dist = dist.square() if mode == "square" else torch.expm1(dist)
            remapped = direction * new_dist
        elif mode == "z_exp":
            xy, z = points.split([2, 1], dim=1)
            z = torch.exp(z)
            remapped = torch.cat([xy * z, z], dim=1)
        else:
            raise ValueError(f"Invalid mode: {self.mode}")

        if self.no_bounds:
            return RegressionAdaptorOutput(value=remapped)
        return RegressionAdaptorOutput(value=remapped.clip(self.vmin, self.vmax))
class RayOriginsAdaptor(UniCeptionAdaptorBase):
    def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs):
        """
        Adaptor for the RayOrigins head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the ray origins, either "linear", "square" or "exp". "linear" keeps the
                prediction as is; the other modes remap the distance of each origin to the world origin.
            vmin (float): Minimum value of the ray origins after scaling.
            vmax (float): Maximum value of the ray origins after scaling.
        """
        super().__init__(name, required_channels=3, *args, **kwargs)

        self.mode, self.vmin, self.vmax = mode, vmin, vmax
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayOriginsAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not supported.
        """
        origins = adaptor_input.adaptor_feature

        if self.mode == "linear":
            scaled = origins
        else:
            # Direction/distance decomposition; the clip guards the division at the origin
            dist = origins.norm(dim=1, keepdim=True)
            direction = origins / dist.clip(min=1e-8)

            if self.mode == "square":
                scaled = direction * dist.square()
            elif self.mode == "exp":
                scaled = direction * torch.expm1(dist)
            else:
                raise ValueError(f"Invalid mode: {self.mode}")

        if self.no_bounds:
            return RegressionAdaptorOutput(value=scaled)
        return RegressionAdaptorOutput(value=scaled.clip(self.vmin, self.vmax))
class RayDirectionsAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        mode: str,
        normalize_to_unit_sphere: bool,
        normalize_to_unit_image_plane: bool,
        vmin: float = -np.inf,
        vmax: float = np.inf,
        clamp_min_of_z_dir: bool = False,
        z_dir_min: float = 1,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayDirections head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the ray directions. Currently only supports "linear".
            normalize_to_unit_sphere (bool): If True, the ray directions are normalized to unit vectors.
                Takes precedence over normalize_to_unit_image_plane when both are set.
            normalize_to_unit_image_plane (bool): If True, the ray directions are divided by their z component
                so that the z component becomes 1.
            vmin (float): Minimum value of the ray directions after scaling & before any sort of normalization. (default: -inf)
            vmax (float): Maximum value of the ray directions after scaling & before any sort of normalization. (default: inf)
            clamp_min_of_z_dir (bool): If True, the z component is clamped from below before normalization. (default: False)
            z_dir_min (float): Lower bound used when clamp_min_of_z_dir is True. (default: 1)
        """
        super().__init__(name, required_channels=3, *args, **kwargs)

        self.mode = mode
        self.normalize_to_unit_sphere = normalize_to_unit_sphere
        self.normalize_to_unit_image_plane = normalize_to_unit_image_plane
        self.vmin, self.vmax = vmin, vmax
        self.clamp_min_of_z_dir = clamp_min_of_z_dir
        self.z_dir_min = z_dir_min
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayDirectionsAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not "linear".
        """
        directions = adaptor_input.adaptor_feature

        if self.mode != "linear":
            raise ValueError(f"Invalid mode: {self.mode}")

        if not self.no_bounds:
            directions = directions.clip(self.vmin, self.vmax)

        if self.clamp_min_of_z_dir:
            # Only the z channel is clamped; x and y pass through unchanged
            xy = directions[:, :2]
            z = torch.clamp(directions[:, 2:3], min=self.z_dir_min)
            directions = torch.cat((xy, z), dim=1)

        if self.normalize_to_unit_sphere:
            # Unit-length rays; the clip guards the division for zero vectors
            length = directions.norm(dim=1, keepdim=True).clip(min=1e-8)
            directions = directions / length
        elif self.normalize_to_unit_image_plane:
            # Rescale so that the z component equals 1
            directions = directions / directions[:, 2:3]

        return RegressionAdaptorOutput(value=directions)
class RayDirectionsPlusDepthAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayDirections + Depth head in UniCeption.

        The first 3 channels are handled by a RayDirectionsAdaptor and the last channel by a
        DepthAdaptor; the ``ray_directions_*`` and ``depth_*`` arguments are forwarded to them.
        """
        super().__init__(name, required_channels=4, *args, **kwargs)

        self.ray_directions_adaptor = RayDirectionsAdaptor(
            name=name,
            mode=ray_directions_mode,
            normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            vmin=ray_directions_vmin,
            vmax=ray_directions_vmax,
            clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            z_dir_min=ray_directions_z_dir_min,
        )
        self.depth_adaptor = DepthAdaptor(name=name, mode=depth_mode, vmin=depth_vmin, vmax=depth_vmax)

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayDirectionsPlusDepthAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor, with ray directions and depth concatenated along the channel dimension.
        """
        shape_hw = adaptor_input.output_shape_hw
        directions_feat, depth_feat = torch.split(adaptor_input.adaptor_feature, [3, 1], dim=1)

        directions_out = self.ray_directions_adaptor(
            AdaptorInput(adaptor_feature=directions_feat, output_shape_hw=shape_hw)
        )
        depth_out = self.depth_adaptor(AdaptorInput(adaptor_feature=depth_feat, output_shape_hw=shape_hw))

        combined = torch.cat([directions_out.value, depth_out.value], dim=1)
        return RegressionAdaptorOutput(value=combined)
class CamTranslationAdaptor(UniCeptionAdaptorBase):
    def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs):
        """
        Adaptor for the Camera Translation or Pose head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the camera translation, either "linear", "square" or "exp". "linear" keeps the
                prediction as is; the other modes remap the distance of the camera to the world origin.
            vmin (float): Minimum value of the camera translation after scaling.
            vmax (float): Maximum value of the camera translation after scaling.
        """
        super().__init__(name, required_channels=3, *args, **kwargs)

        self.mode, self.vmin, self.vmax = mode, vmin, vmax
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the CamTranslationAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not supported.
        """
        translation = adaptor_input.adaptor_feature

        if self.mode == "linear":
            scaled = translation
        else:
            # Direction/distance decomposition; the clip guards the division at the origin
            dist = translation.norm(dim=1, keepdim=True)
            direction = translation / dist.clip(min=1e-8)

            if self.mode == "square":
                scaled = direction * dist.square()
            elif self.mode == "exp":
                scaled = direction * torch.expm1(dist)
            else:
                raise ValueError(f"Invalid mode: {self.mode}")

        if self.no_bounds:
            return AdaptorOutput(value=scaled)
        return AdaptorOutput(value=scaled.clip(self.vmin, self.vmax))
class QuaternionsAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self, name: str, mode: str, normalize: bool, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs
    ):
        """
        Adaptor for the Quaternions or Pose head in UniCeption.
        Notation of the quaternions: (x, y, z, w)

        Args:
            name (str): Name of the adaptor.
            mode (str): Mode of the quaternions. Currently only supports "linear".
            normalize (bool): If True, the quaternions are normalized to unit quaternions.
            vmin (float): Minimum value of the quaternions, applied before the optional normalization.
            vmax (float): Maximum value of the quaternions, applied before the optional normalization.
        """
        super().__init__(name, required_channels=4, *args, **kwargs)

        self.mode = mode
        self.normalize = normalize
        self.vmin, self.vmax = vmin, vmax
        # Only a fully infinite range disables clipping
        self.no_bounds = vmin == float("-inf") and vmax == float("inf")

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the QuaternionsAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured mode is not "linear".
        """
        quats = adaptor_input.adaptor_feature

        if self.mode != "linear":
            raise ValueError(f"Invalid mode: {self.mode}")

        if not self.no_bounds:
            quats = quats.clip(self.vmin, self.vmax)

        if self.normalize:
            # Unit quaternions; the clip guards the division for all-zero predictions
            magnitude = quats.norm(dim=1, keepdim=True).clip(min=1e-8)
            quats = quats / magnitude

        return AdaptorOutput(value=quats)
class CamTranslationPlusQuatsAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        # Cam translation adaptor
        cam_trans_mode: str,
        cam_trans_vmin: float,
        cam_trans_vmax: float,
        # Quaternions adaptor
        quaternions_mode: str,
        quaternions_normalize: bool,
        quaternions_vmin: float,
        quaternions_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Camera Translation + Quaternions head in UniCeption.

        The first 3 channels are handled by a CamTranslationAdaptor and the remaining 4 by a
        QuaternionsAdaptor; the ``cam_trans_*`` and ``quaternions_*`` arguments are forwarded to them.
        """
        super().__init__(name, required_channels=7, *args, **kwargs)

        self.cam_trans_adaptor = CamTranslationAdaptor(
            name=name, mode=cam_trans_mode, vmin=cam_trans_vmin, vmax=cam_trans_vmax
        )
        self.quaternions_adaptor = QuaternionsAdaptor(
            name=name,
            mode=quaternions_mode,
            normalize=quaternions_normalize,
            vmin=quaternions_vmin,
            vmax=quaternions_vmax,
        )

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the CamTranslationPlusQuatsAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...)

        Returns:
            AdaptorOutput: Output of the adaptor, with translation and quaternions concatenated along the channel dimension.
        """
        shape_hw = adaptor_input.output_shape_hw
        trans_feat, quat_feat = torch.split(adaptor_input.adaptor_feature, [3, 4], dim=1)

        trans_out = self.cam_trans_adaptor(AdaptorInput(adaptor_feature=trans_feat, output_shape_hw=shape_hw))
        quat_out = self.quaternions_adaptor(AdaptorInput(adaptor_feature=quat_feat, output_shape_hw=shape_hw))

        pose = torch.cat([trans_out.value, quat_out.value], dim=1)
        return AdaptorOutput(value=pose)
class RayMapAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        # Ray origins adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) head in UniCeption.

        The first 3 channels are handled by a RayOriginsAdaptor and the last 3 by a
        RayDirectionsAdaptor; the ``ray_origins_*`` and ``ray_directions_*`` arguments are forwarded to them.
        """
        super().__init__(name, required_channels=6, *args, **kwargs)

        self.ray_origins_adaptor = RayOriginsAdaptor(name, ray_origins_mode, ray_origins_vmin, ray_origins_vmax)
        self.ray_directions_adaptor = RayDirectionsAdaptor(
            name,
            ray_directions_mode,
            ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin,
            ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min,
        )

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayMapAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) with C == 6

        Returns:
            AdaptorOutput: Output of the adaptor, with ray origins and ray directions concatenated along the channel dimension.
        """
        # Explicit sizes make torch.split reject anything but 6 channels; a plain chunk size of 3
        # would accept 4 or 5 channels and pass a truncated tensor on as the ray directions.
        ray_origins, ray_directions = torch.split(adaptor_input.adaptor_feature, [3, 3], dim=1)

        ray_origins_adaptor_input = AdaptorInput(
            adaptor_feature=ray_origins, output_shape_hw=adaptor_input.output_shape_hw
        )
        ray_directions_adaptor_input = AdaptorInput(
            adaptor_feature=ray_directions, output_shape_hw=adaptor_input.output_shape_hw
        )

        output_ray_origins = self.ray_origins_adaptor(ray_origins_adaptor_input)
        output_ray_directions = self.ray_directions_adaptor(ray_directions_adaptor_input)

        output_rays = torch.cat([output_ray_origins.value, output_ray_directions.value], dim=1)
        return RegressionAdaptorOutput(value=output_rays)
class RayMapPlusDepthAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        # Ray origins adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth head in UniCeption.

        Channels are consumed in the order origins (3), directions (3), depth (1); each group is
        handled by its own sub-adaptor configured from the matching prefixed arguments.
        """
        super().__init__(name, required_channels=7, *args, **kwargs)

        self.ray_origins_adaptor = RayOriginsAdaptor(
            name=name, mode=ray_origins_mode, vmin=ray_origins_vmin, vmax=ray_origins_vmax
        )
        self.ray_directions_adaptor = RayDirectionsAdaptor(
            name=name,
            mode=ray_directions_mode,
            normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            vmin=ray_directions_vmin,
            vmax=ray_directions_vmax,
            clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            z_dir_min=ray_directions_z_dir_min,
        )
        self.depth_adaptor = DepthAdaptor(name=name, mode=depth_mode, vmin=depth_vmin, vmax=depth_vmax)

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayMapPlusDepthAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor, with origins, directions and depth concatenated along the channel dimension.
        """
        shape_hw = adaptor_input.output_shape_hw
        origins_feat, directions_feat, depth_feat = torch.split(adaptor_input.adaptor_feature, [3, 3, 1], dim=1)

        origins_out = self.ray_origins_adaptor(AdaptorInput(adaptor_feature=origins_feat, output_shape_hw=shape_hw))
        directions_out = self.ray_directions_adaptor(
            AdaptorInput(adaptor_feature=directions_feat, output_shape_hw=shape_hw)
        )
        depth_out = self.depth_adaptor(AdaptorInput(adaptor_feature=depth_feat, output_shape_hw=shape_hw))

        combined = torch.cat([origins_out.value, directions_out.value, depth_out.value], dim=1)
        return RegressionAdaptorOutput(value=combined)
class RayMapPlusDepthPlusQuatsAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        # Ray origins adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Quaternions adaptor
        quaternions_mode: str,
        quaternions_normalize: bool,
        quaternions_vmin: float,
        quaternions_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions head in UniCeption.

        Channels are consumed in the order origins (3), directions (3), depth (1), quaternions (4);
        each group is handled by its own sub-adaptor configured from the matching prefixed arguments.
        """
        super().__init__(name, required_channels=11, *args, **kwargs)

        self.ray_origins_adaptor = RayOriginsAdaptor(
            name=name, mode=ray_origins_mode, vmin=ray_origins_vmin, vmax=ray_origins_vmax
        )
        self.ray_directions_adaptor = RayDirectionsAdaptor(
            name=name,
            mode=ray_directions_mode,
            normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            vmin=ray_directions_vmin,
            vmax=ray_directions_vmax,
            clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            z_dir_min=ray_directions_z_dir_min,
        )
        self.depth_adaptor = DepthAdaptor(name=name, mode=depth_mode, vmin=depth_vmin, vmax=depth_vmax)
        self.quaternions_adaptor = QuaternionsAdaptor(
            name=name,
            mode=quaternions_mode,
            normalize=quaternions_normalize,
            vmin=quaternions_vmin,
            vmax=quaternions_vmax,
        )

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the RayMapPlusDepthPlusQuatsAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor, with origins, directions, depth and quaternions
                concatenated along the channel dimension.
        """
        shape_hw = adaptor_input.output_shape_hw
        origins_feat, directions_feat, depth_feat, quats_feat = torch.split(
            adaptor_input.adaptor_feature, [3, 3, 1, 4], dim=1
        )

        origins_out = self.ray_origins_adaptor(AdaptorInput(adaptor_feature=origins_feat, output_shape_hw=shape_hw))
        directions_out = self.ray_directions_adaptor(
            AdaptorInput(adaptor_feature=directions_feat, output_shape_hw=shape_hw)
        )
        depth_out = self.depth_adaptor(AdaptorInput(adaptor_feature=depth_feat, output_shape_hw=shape_hw))
        quats_out = self.quaternions_adaptor(AdaptorInput(adaptor_feature=quats_feat, output_shape_hw=shape_hw))

        parts = [origins_out.value, directions_out.value, depth_out.value, quats_out.value]
        return RegressionAdaptorOutput(value=torch.cat(parts, dim=1))
class ConfidenceAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        confidence_type: str,
        vmin: float,
        vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Confidence head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            confidence_type (str): Type of the confidence, either
                - exp: Exponential confidence, vmin + min(exp(x), vmax - vmin)
                - sigmoid: Sigmoid confidence, rescaled to the range (vmin, vmax)
                - softmax: Softmax over the spatial positions of each channel, multiplied by H * W
                  (vmin and vmax are not used by this type)
            vmin (float): Minimum value of the confidence.
            vmax (float): Maximum value of the confidence.
        """
        super().__init__(name, required_channels=1, *args, **kwargs)

        self.confidence_type = confidence_type
        self.vmin = vmin
        self.vmax = vmax

        assert vmin < vmax, "vmin must be less than vmax"

        if confidence_type == "sigmoid":
            assert isfinite(vmin) and isfinite(vmax), "vmin and vmax must be finite for sigmoid confidence"
            # NOTE(review): assumed to belong to the sigmoid branch - confirm against the original indentation
            assert vmin >= 0

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the ConfidenceAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W)

        Returns:
            AdaptorOutput: Output of the adaptor.

        Raises:
            ValueError: If the configured confidence type is not "exp", "sigmoid" or "softmax".
        """
        x = adaptor_input.adaptor_feature

        if self.confidence_type == "exp":
            # Cap exp(x) so that the result never exceeds vmax
            confidence = self.vmin + x.exp().clip(max=self.vmax - self.vmin)
        elif self.confidence_type == "sigmoid":
            confidence = torch.sigmoid(x)
            confidence = confidence * (self.vmax - self.vmin) + self.vmin
        elif self.confidence_type == "softmax":
            B, C, H, W = x.shape
            # Softmax over all H * W positions; scaling by H * W makes the spatial mean equal to 1
            confidence = torch.nn.functional.softmax(x.reshape(B, C, -1), dim=-1).reshape(B, C, H, W) * (H * W)
        else:
            # Fail loudly instead of implicitly returning None
            raise ValueError(f"Invalid confidence_type: {self.confidence_type}")

        return RegressionAdaptorOutput(value=confidence)
class Covariance2DAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        parametrization: str = "exp_tanh",
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Covariance2D head in UniCeption.

        Args:
            name (str): Name of the adaptor.
            parametrization (str): How the 3 input channels are mapped to a covariance.
                Only "exp_tanh" is supported.
        """
        super().__init__(name, required_channels=3, *args, **kwargs)

        self.parametrization = parametrization

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the Covariance2DAdaptor.

        With the "exp_tanh" parametrization the input channels (c1, c2, s) produce
            covariance     = [exp(c1), exp(c2), tanh(s) * exp((c1 + c2) / 2)]
            log_det        = c1 + c2 + log(1 - tanh(s)^2 + 1e-8)
            inv_covariance = [exp(-c1), exp(-c2), -tanh(s) * exp(-(c1 + c2) / 2)] / (1 - tanh(s)^2 + 1e-8)

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor, with 3 channels along dim 1.

        Returns:
            Covariance2DAdaptorOutput: Covariance, its log-determinant and its inverse.

        Raises:
            ValueError: If the configured parametrization is not supported.
        """
        features = adaptor_input.adaptor_feature

        if self.parametrization != "exp_tanh":
            raise ValueError(f"Invalid parametrization: {self.parametrization}")

        c1, c2, s = torch.split(features, 1, dim=1)
        half_sum = (c1 + c2) / 2
        rho = s.tanh()

        # 1 - rho^2, with a small epsilon keeping the log and the division finite as |rho| approaches 1
        one_minus_rho_sq = 1 - torch.square(rho) + 1e-8

        covariance = torch.cat([c1.exp(), c2.exp(), rho * torch.exp(half_sum)], dim=1)
        log_det = c1 + c2 + torch.log(one_minus_rho_sq)

        inv_scale = 1 / one_minus_rho_sq
        inv_covariance = inv_scale * torch.cat([torch.exp(-c1), torch.exp(-c2), -rho * torch.exp(-half_sum)], dim=1)

        return Covariance2DAdaptorOutput(covariance=covariance, log_det=log_det, inv_covariance=inv_covariance)
class MaskAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor.
        """
        super().__init__(name, required_channels=1, *args, **kwargs)

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the MaskAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor; its feature is treated as mask logits.

        Returns:
            MaskAdaptorOutput: The unchanged logits together with their sigmoid as the mask.
        """
        logits = adaptor_input.adaptor_feature
        return MaskAdaptorOutput(logits=logits, mask=torch.sigmoid(logits))
class ValueWithConfidenceAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        value_adaptor: UniCeptionAdaptorBase,
        confidence_adaptor: UniCeptionAdaptorBase,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Value with Confidence head in UniCeption.

        The required channel count is the sum of the channels required by the two wrapped adaptors.

        Args:
            name (str): Name of the adaptor.
            value_adaptor (UniCeptionAdaptorBase): Adaptor for the value.
            confidence_adaptor (UniCeptionAdaptorBase): Adaptor for the confidence.
        """
        total_channels = value_adaptor.required_channels + confidence_adaptor.required_channels
        super().__init__(name, required_channels=total_channels, *args, **kwargs)

        self.value_adaptor = value_adaptor
        self.confidence_adaptor = confidence_adaptor

    def forward(self, adaptor_input: AdaptorInput):
        """
        Forward pass for the ValueWithConfidenceAdaptor.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor; value channels come first along dim 1,
                followed by the confidence channels.

        Returns:
            RegressionWithConfidenceAdaptorOutput: Value and confidence produced by the wrapped adaptors.
        """
        shape_hw = adaptor_input.output_shape_hw
        channel_split = [self.value_adaptor.required_channels, self.confidence_adaptor.required_channels]
        value_feat, confidence_feat = torch.split(adaptor_input.adaptor_feature, channel_split, dim=1)

        value_out = self.value_adaptor(AdaptorInput(adaptor_feature=value_feat, output_shape_hw=shape_hw))
        confidence_out = self.confidence_adaptor(
            AdaptorInput(adaptor_feature=confidence_feat, output_shape_hw=shape_hw)
        )

        return RegressionWithConfidenceAdaptorOutput(value=value_out.value, confidence=confidence_out.value)
class FlowWithConfidenceAdaptor(ValueWithConfidenceAdaptor):
    def __init__(
        self,
        name: str,
        # Flow adaptor
        flow_mean: torch.Tensor,
        flow_std: torch.Tensor,
        base_shape: Tuple[int, int],
        scale_strategy: str,
        output_normalized_coordinate: bool,
        # Confidence adaptor
        confidence_type: str,
        vmin: float,
        vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Flow with Confidence head in UniCeption.

        Builds a FlowAdaptor (named ``name``) for the value channels and a ConfidenceAdaptor
        (named ``{name}_confidence``) for the confidence channel. The flow arguments are forwarded
        to the former; ``confidence_type``, ``vmin`` and ``vmax`` to the latter.
        """
        value_part = FlowAdaptor(
            name=f"{name}",
            flow_mean=flow_mean,
            flow_std=flow_std,
            base_shape=base_shape,
            scale_strategy=scale_strategy,
            output_normalized_coordinate=output_normalized_coordinate,
        )
        confidence_part = ConfidenceAdaptor(
            name=f"{name}_confidence",
            confidence_type=confidence_type,
            vmin=vmin,
            vmax=vmax,
        )

        super().__init__(name, *args, value_adaptor=value_part, confidence_adaptor=confidence_part, **kwargs)
class PointMapWithConfidenceAdaptor(ValueWithConfidenceAdaptor):
    def __init__(
        self,
        name: str,
        # Pointmap adaptor
        pointmap_mode: str,
        pointmap_vmin: float,
        pointmap_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the PointMap with Confidence head in UniCeption.

        Args:
            name (str): Name of the adaptor. The pointmap sub-adaptor reuses this name; the confidence
                sub-adaptor is named f"{name}_confidence".
            pointmap_mode (str): Forwarded to PointMapAdaptor as mode.
            pointmap_vmin (float): Forwarded to PointMapAdaptor as vmin.
            pointmap_vmax (float): Forwarded to PointMapAdaptor as vmax.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAdaptor.__init__.
        """
        # The value adaptor is constructed first, then the confidence adaptor, before the parent init runs.
        super().__init__(
            name,
            *args,
            value_adaptor=PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax),
            confidence_adaptor=ConfidenceAdaptor(
                name=f"{name}_confidence",
                confidence_type=confidence_type,
                vmin=confidence_vmin,
                vmax=confidence_vmax,
            ),
            **kwargs,
        )
class RayDirectionsPlusDepthwithConfidenceAdaptor(ValueWithConfidenceAdaptor):
    def __init__(
        self,
        name: str,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayDirections + Depth with Confidence head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                sub-adaptor is named f"{name}_confidence".
            ray_directions_* / depth_*: Forwarded unchanged, under the same keyword names, to
                RayDirectionsPlusDepthAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayDirectionsPlusDepthAdaptor(name=f"{name}", **ray_directions_cfg, **depth_cfg)
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        super().__init__(name, *args, value_adaptor=value_head, confidence_adaptor=confidence_head, **kwargs)
class RayMapPlusDepthwithConfidenceAdaptor(ValueWithConfidenceAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Confidence head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                sub-adaptor is named f"{name}_confidence".
            ray_origins_* / ray_directions_* / depth_*: Forwarded unchanged, under the same keyword
                names, to RayMapPlusDepthAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayMapPlusDepthAdaptor(name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg)
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        super().__init__(name, *args, value_adaptor=value_head, confidence_adaptor=confidence_head, **kwargs)
class RayMapPlusDepthPlusQuatswithConfidenceAdaptor(ValueWithConfidenceAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Quaternions adaptor
        quaternions_mode: str,
        quaternions_normalize: bool,
        quaternions_vmin: float,
        quaternions_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Confidence head
        in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                sub-adaptor is named f"{name}_confidence".
            ray_origins_* / ray_directions_* / depth_* / quaternions_*: Forwarded unchanged, under the
                same keyword names, to RayMapPlusDepthPlusQuatsAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)
        quaternions_cfg = dict(
            quaternions_mode=quaternions_mode,
            quaternions_normalize=quaternions_normalize,
            quaternions_vmin=quaternions_vmin,
            quaternions_vmax=quaternions_vmax,
        )

        value_head = RayMapPlusDepthPlusQuatsAdaptor(
            name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg, **quaternions_cfg
        )
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        super().__init__(name, *args, value_adaptor=value_head, confidence_adaptor=confidence_head, **kwargs)
class ValueWithMaskAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        value_adaptor: UniCeptionAdaptorBase,
        mask_adaptor: UniCeptionAdaptorBase,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Value with Mask head in UniCeption.

        The adaptor's required_channels is the sum of the two sub-adaptors' required_channels.

        Args:
            name (str): Name of the adaptor.
            value_adaptor (UniCeptionAdaptorBase): Adaptor for the value.
            mask_adaptor (UniCeptionAdaptorBase): Adaptor for the mask.
            *args, **kwargs: Forwarded to UniCeptionAdaptorBase.__init__.
        """
        total_channels = value_adaptor.required_channels + mask_adaptor.required_channels
        super().__init__(name, *args, required_channels=total_channels, **kwargs)

        self.value_adaptor = value_adaptor
        self.mask_adaptor = mask_adaptor

    def forward(self, adaptor_input: AdaptorInput):
        """
        Run the value and mask sub-adaptors on their respective slices of the input feature.

        Args:
            adaptor_input (AdaptorInput): Input whose adaptor_feature is split along dimension 1 into
                value_adaptor.required_channels followed by mask_adaptor.required_channels.
                Its output_shape_hw is passed on unchanged to both sub-adaptors.

        Returns:
            RegressionWithMaskAdaptorOutput: `value` from the value adaptor's output, `mask` and
                `logits` from the mask adaptor's output.
        """
        channel_sizes = [self.value_adaptor.required_channels, self.mask_adaptor.required_channels]
        value_feature, mask_feature = torch.split(adaptor_input.adaptor_feature, channel_sizes, dim=1)

        # Both sub-adaptor inputs are prepared before either sub-adaptor is invoked.
        value_in = AdaptorInput(adaptor_feature=value_feature, output_shape_hw=adaptor_input.output_shape_hw)
        mask_in = AdaptorInput(adaptor_feature=mask_feature, output_shape_hw=adaptor_input.output_shape_hw)

        value_result = self.value_adaptor(value_in)
        mask_result = self.mask_adaptor(mask_in)

        return RegressionWithMaskAdaptorOutput(
            value=value_result.value,
            mask=mask_result.mask,
            logits=mask_result.logits,
        )
class PointMapWithMaskAdaptor(ValueWithMaskAdaptor):
    def __init__(
        self,
        name: str,
        # Pointmap adaptor
        pointmap_mode: str,
        pointmap_vmin: float,
        pointmap_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the PointMap with Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor. The pointmap sub-adaptor reuses this name; the mask
                sub-adaptor is named f"{name}_mask".
            pointmap_mode (str): Forwarded to PointMapAdaptor as mode.
            pointmap_vmin (float): Forwarded to PointMapAdaptor as vmin.
            pointmap_vmax (float): Forwarded to PointMapAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithMaskAdaptor.__init__.
        """
        # The value adaptor is constructed first, then the mask adaptor, before the parent init runs.
        super().__init__(
            name,
            *args,
            value_adaptor=PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax),
            mask_adaptor=MaskAdaptor(name=f"{name}_mask"),
            **kwargs,
        )
class RayDirectionsPlusDepthwithMaskAdaptor(ValueWithMaskAdaptor):
    def __init__(
        self,
        name: str,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayDirections + Depth with Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the mask
                sub-adaptor is named f"{name}_mask".
            ray_directions_* / depth_*: Forwarded unchanged, under the same keyword names, to
                RayDirectionsPlusDepthAdaptor.
            *args, **kwargs: Forwarded to ValueWithMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayDirectionsPlusDepthAdaptor(name=f"{name}", **ray_directions_cfg, **depth_cfg)
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(name, *args, value_adaptor=value_head, mask_adaptor=mask_head, **kwargs)
class RayMapPlusDepthwithMaskAdaptor(ValueWithMaskAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the mask
                sub-adaptor is named f"{name}_mask".
            ray_origins_* / ray_directions_* / depth_*: Forwarded unchanged, under the same keyword
                names, to RayMapPlusDepthAdaptor.
            *args, **kwargs: Forwarded to ValueWithMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayMapPlusDepthAdaptor(name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg)
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(name, *args, value_adaptor=value_head, mask_adaptor=mask_head, **kwargs)
class RayMapPlusDepthPlusQuatswithMaskAdaptor(ValueWithMaskAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Quaternions adaptor
        quaternions_mode: str,
        quaternions_normalize: bool,
        quaternions_vmin: float,
        quaternions_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Mask head
        in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the mask
                sub-adaptor is named f"{name}_mask".
            ray_origins_* / ray_directions_* / depth_* / quaternions_*: Forwarded unchanged, under the
                same keyword names, to RayMapPlusDepthPlusQuatsAdaptor.
            *args, **kwargs: Forwarded to ValueWithMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)
        quaternions_cfg = dict(
            quaternions_mode=quaternions_mode,
            quaternions_normalize=quaternions_normalize,
            quaternions_vmin=quaternions_vmin,
            quaternions_vmax=quaternions_vmax,
        )

        value_head = RayMapPlusDepthPlusQuatsAdaptor(
            name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg, **quaternions_cfg
        )
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(name, *args, value_adaptor=value_head, mask_adaptor=mask_head, **kwargs)
class ValueWithConfidenceAndMaskAdaptor(UniCeptionAdaptorBase):
    def __init__(
        self,
        name: str,
        value_adaptor: UniCeptionAdaptorBase,
        confidence_adaptor: UniCeptionAdaptorBase,
        mask_adaptor: UniCeptionAdaptorBase,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the Value with Confidence & Mask head in UniCeption.

        The adaptor's required_channels is the sum of the three sub-adaptors' required_channels.

        Args:
            name (str): Name of the adaptor.
            value_adaptor (UniCeptionAdaptorBase): Adaptor for the value.
            confidence_adaptor (UniCeptionAdaptorBase): Adaptor for the confidence.
            mask_adaptor (UniCeptionAdaptorBase): Adaptor for the mask.
            *args, **kwargs: Forwarded to UniCeptionAdaptorBase.__init__.
        """
        total_channels = (
            value_adaptor.required_channels + confidence_adaptor.required_channels + mask_adaptor.required_channels
        )
        super().__init__(name, *args, required_channels=total_channels, **kwargs)

        self.value_adaptor = value_adaptor
        self.confidence_adaptor = confidence_adaptor
        self.mask_adaptor = mask_adaptor

    def forward(self, adaptor_input: AdaptorInput):
        """
        Run the value, confidence and mask sub-adaptors on their respective slices of the input feature.

        Args:
            adaptor_input (AdaptorInput): Input whose adaptor_feature is split along dimension 1 into
                the value, confidence and mask sub-adaptors' required_channels, in that order.
                Its output_shape_hw is passed on unchanged to every sub-adaptor.

        Returns:
            RegressionWithConfidenceAndMaskAdaptorOutput: `value` from the value adaptor's output,
                `confidence` from the `value` field of the confidence adaptor's output, and `mask` /
                `logits` from the mask adaptor's output.
        """
        heads = (self.value_adaptor, self.confidence_adaptor, self.mask_adaptor)
        value_feature, confidence_feature, mask_feature = torch.split(
            adaptor_input.adaptor_feature,
            [head.required_channels for head in heads],
            dim=1,
        )

        # All sub-adaptor inputs are prepared before any sub-adaptor is invoked.
        head_inputs = [
            AdaptorInput(adaptor_feature=feature, output_shape_hw=adaptor_input.output_shape_hw)
            for feature in (value_feature, confidence_feature, mask_feature)
        ]

        # Sub-adaptors run in value, confidence, mask order.
        value_result, confidence_result, mask_result = [
            head(head_input) for head, head_input in zip(heads, head_inputs)
        ]

        return RegressionWithConfidenceAndMaskAdaptorOutput(
            value=value_result.value,
            confidence=confidence_result.value,
            mask=mask_result.mask,
            logits=mask_result.logits,
        )
class PointMapWithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor):
    def __init__(
        self,
        name: str,
        # PointMap adaptor
        pointmap_mode: str,
        pointmap_vmin: float,
        pointmap_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the PointMap with Confidence & Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor. The pointmap sub-adaptor reuses this name; the confidence
                and mask sub-adaptors are named f"{name}_confidence" and f"{name}_mask".
            pointmap_mode (str): Forwarded to PointMapAdaptor as mode.
            pointmap_vmin (float): Forwarded to PointMapAdaptor as vmin.
            pointmap_vmax (float): Forwarded to PointMapAdaptor as vmax.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAndMaskAdaptor.__init__.
        """
        # Sub-adaptors are constructed in value, confidence, mask order before the parent init runs.
        super().__init__(
            name,
            *args,
            value_adaptor=PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax),
            confidence_adaptor=ConfidenceAdaptor(
                name=f"{name}_confidence",
                confidence_type=confidence_type,
                vmin=confidence_vmin,
                vmax=confidence_vmax,
            ),
            mask_adaptor=MaskAdaptor(name=f"{name}_mask"),
            **kwargs,
        )
class RayDirectionsPlusDepthwithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor):
    def __init__(
        self,
        name: str,
        # Ray directions adaptor
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayDirections + Depth with Confidence & Mask head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                and mask sub-adaptors are named f"{name}_confidence" and f"{name}_mask".
            ray_directions_* / depth_*: Forwarded unchanged, under the same keyword names, to
                RayDirectionsPlusDepthAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAndMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayDirectionsPlusDepthAdaptor(name=f"{name}", **ray_directions_cfg, **depth_cfg)
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(
            name,
            *args,
            value_adaptor=value_head,
            confidence_adaptor=confidence_head,
            mask_adaptor=mask_head,
            **kwargs,
        )
class RayMapPlusDepthwithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Confidence & Mask head
        in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                and mask sub-adaptors are named f"{name}_confidence" and f"{name}_mask".
            ray_origins_* / ray_directions_* / depth_*: Forwarded unchanged, under the same keyword
                names, to RayMapPlusDepthAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAndMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)

        value_head = RayMapPlusDepthAdaptor(name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg)
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(
            name,
            *args,
            value_adaptor=value_head,
            confidence_adaptor=confidence_head,
            mask_adaptor=mask_head,
            **kwargs,
        )
class RayMapPlusDepthPlusQuatswithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor):
    def __init__(
        self,
        name: str,
        # RayMap adaptor
        ray_origins_mode: str,
        ray_origins_vmin: float,
        ray_origins_vmax: float,
        ray_directions_mode: str,
        ray_directions_normalize_to_unit_sphere: bool,
        ray_directions_normalize_to_unit_image_plane: bool,
        ray_directions_vmin: float,
        ray_directions_vmax: float,
        ray_directions_clamp_min_of_z_dir: bool,
        ray_directions_z_dir_min: float,
        # Depth adaptor
        depth_mode: str,
        depth_vmin: float,
        depth_vmax: float,
        # Quaternions adaptor
        quaternions_mode: str,
        quaternions_normalize: bool,
        quaternions_vmin: float,
        quaternions_vmax: float,
        # Confidence adaptor
        confidence_type: str,
        confidence_vmin: float,
        confidence_vmax: float,
        *args,
        **kwargs,
    ):
        """
        Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Confidence & Mask
        head in UniCeption.

        Args:
            name (str): Name of the adaptor. The value sub-adaptor reuses this name; the confidence
                and mask sub-adaptors are named f"{name}_confidence" and f"{name}_mask".
            ray_origins_* / ray_directions_* / depth_* / quaternions_*: Forwarded unchanged, under the
                same keyword names, to RayMapPlusDepthPlusQuatsAdaptor.
            confidence_type (str): Forwarded to ConfidenceAdaptor as confidence_type.
            confidence_vmin (float): Forwarded to ConfidenceAdaptor as vmin.
            confidence_vmax (float): Forwarded to ConfidenceAdaptor as vmax.
            *args, **kwargs: Forwarded to ValueWithConfidenceAndMaskAdaptor.__init__.
        """
        # Group the forwarded keywords by the part of the value adaptor they configure.
        ray_origins_cfg = dict(
            ray_origins_mode=ray_origins_mode,
            ray_origins_vmin=ray_origins_vmin,
            ray_origins_vmax=ray_origins_vmax,
        )
        ray_directions_cfg = dict(
            ray_directions_mode=ray_directions_mode,
            ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere,
            ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane,
            ray_directions_vmin=ray_directions_vmin,
            ray_directions_vmax=ray_directions_vmax,
            ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir,
            ray_directions_z_dir_min=ray_directions_z_dir_min,
        )
        depth_cfg = dict(depth_mode=depth_mode, depth_vmin=depth_vmin, depth_vmax=depth_vmax)
        quaternions_cfg = dict(
            quaternions_mode=quaternions_mode,
            quaternions_normalize=quaternions_normalize,
            quaternions_vmin=quaternions_vmin,
            quaternions_vmax=quaternions_vmax,
        )

        value_head = RayMapPlusDepthPlusQuatsAdaptor(
            name=f"{name}", **ray_origins_cfg, **ray_directions_cfg, **depth_cfg, **quaternions_cfg
        )
        confidence_head = ConfidenceAdaptor(
            name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax
        )
        mask_head = MaskAdaptor(name=f"{name}_mask")
        super().__init__(
            name,
            *args,
            value_adaptor=value_head,
            confidence_adaptor=confidence_head,
            mask_adaptor=mask_head,
            **kwargs,
        )