"""
Base class of the UniFlowMatch training system.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch

from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
from uniflowmatch.utils.flow_resizing import (
    AutomaticShapeSelection,
    ResizeToFixedManipulation,
    unmap_predicted_channels,
    unmap_predicted_flow,
)

@dataclass
class UFMFlowFieldOutput:
"""
Output interface of the flow field prediction network.
"""
flow_output: torch.Tensor
flow_covariance: Optional[torch.Tensor] = None
flow_covariance_inv: Optional[torch.Tensor] = None
flow_covariance_log_det: Optional[torch.Tensor] = None
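
# Judging from the covariance rescaling in `_predict_correspondences_batched`
# below, `flow_covariance` packs the per-pixel 2x2 symmetric covariance into
# three channels ordered [var_x, var_y, cov_xy]. A minimal sketch of expanding
# that packed form (hypothetical helper, not part of this module's API):
def _unpack_flow_covariance(packed: torch.Tensor) -> torch.Tensor:
    """Expand [B, 3, H, W] packed as [var_x, var_y, cov_xy] into [B, H, W, 2, 2]."""
    var_x, var_y, cov_xy = packed.unbind(dim=1)
    row0 = torch.stack([var_x, cov_xy], dim=-1)
    row1 = torch.stack([cov_xy, var_y], dim=-1)
    return torch.stack([row0, row1], dim=-2)
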
@dataclass
class UFMMaskFieldOutput:
"""
Output interface of the mask prediction network.
"""
mask: torch.Tensor
    # Logits may be unavailable (e.g. the rescaled output in
    # `_predict_correspondences_batched` sets them to None), so the field is optional.
    logits: Optional[torch.Tensor] = None

@dataclass
class UFMClassificationRefinementOutput:
"""
Output interface of the classification refinement network.
"""
    # Flow output of the regression step, with shape [B, 2, H, W]. This initial
    # flow prediction is used to sample the first local feature maps for the residual.
    regression_flow_output: torch.Tensor

    # Residual predicted by the refinement step, with shape [B, 2, H, W]. It is
    # added to the initial flow output to obtain the final flow output.
    residual: torch.Tensor

    # Log-softmax of the similarity between each pixel's feature and the features
    # in the neighborhood of its flow prediction in the other image. It has shape
    # [B, H, W, P, P]: the similarity of the pixel at [b, h, w] to the [P, P]
    # neighborhood centered at regression_flow_output[b, h, w].
    log_softmax: torch.Tensor

    # Feature maps of the two input views used by the refinement step.
    feature_map_0: torch.Tensor
    feature_map_1: torch.Tensor
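
# Per the field comments above, the refinement's final flow is the regression
# output plus the predicted residual. A minimal sketch (hypothetical helper,
# not part of this module's API):
def _final_refined_flow(out: "UFMClassificationRefinementOutput") -> torch.Tensor:
    # Both tensors are [B, 2, H, W]; the residual corrects the initial regression.
    return out.regression_flow_output + out.residual
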
@dataclass
class UFMOutputInterface:
"""
Output interface of the UniFlowMatch training system.
"""
flow: Optional[UFMFlowFieldOutput] = None
# Refinement output (for training and visualization)
classification_refinement: Optional[UFMClassificationRefinementOutput] = None
    # Auxiliary outputs
covisibility: Optional[UFMMaskFieldOutput] = None
class UniFlowMatchModelsBase(torch.nn.Module):
def __init__(self, inference_resolution: Optional[Union[List[Tuple[int, int]], Tuple[int, int]]] = None):
super().__init__()
if inference_resolution is None:
inference_resolution = [(560, 420)]
if isinstance(inference_resolution[0], int): # Handle the case for single resolution
inference_resolution = [inference_resolution]
self.inference_resolution = inference_resolution
self.image_scaler = AutomaticShapeSelection(
*[ResizeToFixedManipulation((resolution[1], resolution[0])) for resolution in inference_resolution],
            strategy="closest_aspect",  # run inference at whichever trained aspect ratio is closest to input image 1
)
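
        # Illustrative examples of accepted `inference_resolution` values:
        #   UniFlowMatchModelsBase(inference_resolution=(560, 420))               # single resolution
        #   UniFlowMatchModelsBase(inference_resolution=[(560, 420), (280, 210)]) # multiple resolutions
        # Each pair appears to be (width, height), since it is swapped to
        # (height, width) when constructing ResizeToFixedManipulation above.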
def forward(self, view1, view2) -> UFMOutputInterface:
"""
Forward interface of correspondence prediction networks.
Args:
- view1 (Dict[str, Any]): Input view 1
- img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type
- instance (List[int]): List of instance indices, or id of the input image
- data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
- view2 (Dict[str, Any]): Input view 2
- (same structure as view1)
        Returns:
            - UFMOutputInterface: Output results
                - flow [Required] (UFMFlowFieldOutput): Flow output
                    - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
                    - [Optional] flow_covariance
                    - [Optional] flow_covariance_inv
                    - [Optional] flow_covariance_log_det
                - covisibility [Optional] (UFMMaskFieldOutput): Covisibility output
                    - [Optional] mask
                    - [Optional] logits
        """
raise NotImplementedError("Implement this method in derived classes")
def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
"""
        Get parameter groups for the optimizer. This method guides the optimizer
        in applying the correct learning rate to each part of the model.
Returns:
- Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer
"""
raise NotImplementedError("Implement this method in derived classes")
def predict_correspondences_batched(
self,
source_image: torch.Tensor,
target_image: torch.Tensor,
data_norm_type: Optional[str] = None,
) -> UFMOutputInterface:
"""
Predict correspondences between source and target images.
This method generates random correspondences for demonstration purposes.
Args:
source_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The source image tensor.
target_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The target image tensor.
Returns:
UFMOutputInterface:
- flow
- flow_output (torch.Tensor): Tensor of shape (B, 2, H, W) representing the flow output in the original image space.
- occlusion
- mask (torch.Tensor): Tensor of shape (B, H, W) representing the covisibility in range [0, 1]. 1 = fully covisible, 0 = fully occluded or out of range.
"""
assert isinstance(source_image, torch.Tensor) and isinstance(
target_image, torch.Tensor
), "source_image and target_image must be torch.Tensors"
assert source_image.dim() in [3, 4], "source_image must have dimensions 3 or 4"
assert target_image.dim() in [3, 4], "target_image must have dimensions 3 or 4"
batched = source_image.dim() == 4
if not batched:
# add batch dimension
source_image = source_image.unsqueeze(0)
target_image = target_image.unsqueeze(0)
# check the channel
if source_image.shape[1] == 3 and target_image.shape[1] == 3:
pass # do nothing because the image is in BCHW format
elif source_image.shape[-1] == 3 and target_image.shape[-1] == 3:
# convert to BCHW
source_image = source_image.permute(0, 3, 1, 2)
target_image = target_image.permute(0, 3, 1, 2)
else:
raise ValueError("source_image and target_image must have 3 channels in either BCHW or BHWC format")
required_data_norm_type = self.encoder.data_norm_type
image_device = source_image.device
if source_image.dtype == torch.float32:
assert data_norm_type is not None, "data_norm_type must be provided for float32 images"
assert (
data_norm_type in IMAGE_NORMALIZATION_DICT
), f"data_norm_type must be one of {list(IMAGE_NORMALIZATION_DICT.keys())}"
if data_norm_type != required_data_norm_type:
                # convert from the provided normalization to the encoder's required normalization
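                # Derivation: if x_old = (raw - prev_mean) / prev_std and we want
                # x_new = (raw - mean) / std, substituting raw = x_old * prev_std + prev_mean
                # gives x_new = x_old * (prev_std / std) + (prev_mean - mean) / std,
                # which is exactly the update applied below.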
prev_mean = (
IMAGE_NORMALIZATION_DICT[data_norm_type].mean.view(1, 3, 1, 1).to(image_device, non_blocking=True)
)
prev_std = (
IMAGE_NORMALIZATION_DICT[data_norm_type].std.view(1, 3, 1, 1).to(image_device, non_blocking=True)
)
mean = (
IMAGE_NORMALIZATION_DICT[required_data_norm_type]
.mean.view(1, 3, 1, 1)
.to(image_device, non_blocking=True)
)
std = (
IMAGE_NORMALIZATION_DICT[required_data_norm_type]
.std.view(1, 3, 1, 1)
.to(image_device, non_blocking=True)
)
source_image = source_image * (prev_std / std) + (prev_mean - mean) / std
target_image = target_image * (prev_std / std) + (prev_mean - mean) / std
elif source_image.dtype == torch.uint8:
# convert into float32 and apply normalization
mean = (
IMAGE_NORMALIZATION_DICT[required_data_norm_type]
.mean.view(1, 3, 1, 1)
.to(image_device, non_blocking=True)
)
std = (
IMAGE_NORMALIZATION_DICT[required_data_norm_type]
.std.view(1, 3, 1, 1)
.to(image_device, non_blocking=True)
)
source_image = (source_image.float() / 255.0 - mean) / std
target_image = (target_image.float() / 255.0 - mean) / std
else:
raise ValueError("source_image and target_image must be of type torch.float32 or torch.uint8")
# Now all the inputs are normalized according to the model's encoder and organized in BCHW format
return self._predict_correspondences_batched(source_image, target_image)
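
    # Example usage (illustrative; `model` is an instance of a concrete subclass):
    #
    #     source = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # HWC uint8
    #     target = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
    #     result = model.predict_correspondences_batched(source, target)
    #     flow = result.flow.flow_output    # [1, 2, 480, 640]
    #     covis = result.covisibility.mask  # [1, 480, 640], when the model predicts covisibility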
def _predict_correspondences_batched(
self,
source_image: torch.Tensor,
target_image: torch.Tensor,
) -> UFMOutputInterface:
assert isinstance(source_image, torch.Tensor), "source_image must be a torch.Tensor"
assert isinstance(target_image, torch.Tensor), "target_image must be a torch.Tensor"
assert source_image.dim() == 4, "source_image must be of shape (B, 3, H, W)"
assert target_image.dim() == 4, "target_image must be of shape (B, 3, H, W)"
assert source_image.shape[1] == 3, "source_image must be of shape (B, 3, H, W)"
assert target_image.shape[1] == 3, "target_image must be of shape (B, 3, H, W)"
assert source_image.dtype == torch.float32, "source_image must be of dtype torch.float32"
assert target_image.dtype == torch.float32, "target_image must be of dtype torch.float32"
source_shape_hw = source_image.shape[2:]
target_shape_hw = target_image.shape[2:]
        # Scale images to one of the model's trained resolutions.
        (
            scaled_img0,  # the scaled source image
            scaled_img1,  # the scaled target image
            img0_region_source,  # region of the source image that is captured in the scaled image
            img1_region_source,  # region of the target image that is captured in the scaled image
            img0_region_representation,  # region of the scaled image that represents the source region above
            img1_region_representation,  # same as above, but for the target image
        ) = self.image_scaler(source_image.permute(0, 2, 3, 1), target_image.permute(0, 2, 3, 1))
scaled_img0 = scaled_img0.permute(0, 3, 1, 2)
scaled_img1 = scaled_img1.permute(0, 3, 1, 2)
# Run a forward pass
view1 = {"img": scaled_img0, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type}
view2 = {"img": scaled_img1, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type}
with torch.no_grad():
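            # bfloat16 autocast; this assumes inference runs on a CUDA device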
with torch.autocast("cuda", torch.bfloat16):
result = self(view1, view2)
rescaled_ufm_result = UFMOutputInterface()
# rescale flow
flow_output = result.flow.flow_output
flow_unmapped, flow_unmap_validity = unmap_predicted_flow(
flow_output,
img0_region_representation,
img1_region_representation,
img0_region_source,
img1_region_source,
source_shape_hw,
target_shape_hw,
)
rescaled_ufm_result.flow = UFMFlowFieldOutput(
flow_output=flow_unmapped,
)
# rescale covariance if it exists
if result.flow.flow_covariance is not None:
flow_covariance = result.flow.flow_covariance
flow_covariance_unmapped, _ = unmap_predicted_channels(
flow_covariance,
img0_region_representation,
img0_region_source,
source_shape_hw,
)
            # Rescale the covariance to account for the flow rescaling back to the original resolution.
w_pred = scaled_img0.shape[3]
h_pred = scaled_img0.shape[2]
w_final = source_shape_hw[1]
h_final = source_shape_hw[0]
w_ratio, h_ratio = w_final / w_pred, h_final / h_pred
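            # Rescaling the flow by S = diag(w_ratio, h_ratio) transforms the covariance
            # as Sigma' = S @ Sigma @ S.T, i.e. var_x *= w_ratio**2, var_y *= h_ratio**2,
            # and cov_xy *= w_ratio * h_ratio, matching the per-channel factors below
            # (channels packed as [var_x, var_y, cov_xy]).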
flow_covariance_unmapped *= (
torch.tensor([w_ratio**2, h_ratio**2, w_ratio * h_ratio])
.view(1, 3, 1, 1)
.to(flow_covariance_unmapped.device)
)
rescaled_ufm_result.flow.flow_covariance = flow_covariance_unmapped
        # rescale covisibility if it exists
        if result.covisibility is not None:
            covisibility_mask = result.covisibility.mask
            covisibility_unmapped, _ = unmap_predicted_channels(
                covisibility_mask,
img0_region_representation,
img0_region_source,
source_shape_hw,
)
covisibility_unmapped = covisibility_unmapped.squeeze(1)
rescaled_ufm_result.covisibility = UFMMaskFieldOutput(mask=covisibility_unmapped, logits=None)
return rescaled_ufm_result
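
# A minimal subclass sketch (illustrative; everything except the base-class API
# is hypothetical), showing what a concrete model must provide: an `encoder`
# exposing `data_norm_type`, a `forward` returning UFMOutputInterface, and
# parameter groups for the optimizer:
#
#     class MyUFMModel(UniFlowMatchModelsBase):
#         def __init__(self, encoder):
#             super().__init__(inference_resolution=(560, 420))
#             self.encoder = encoder  # must expose .data_norm_type
#
#         def forward(self, view1, view2) -> UFMOutputInterface:
#             flow = self.predict_flow(view1["img"], view2["img"])  # hypothetical predictor
#             return UFMOutputInterface(flow=UFMFlowFieldOutput(flow_output=flow))
#
#         def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
#             return {"encoder": torch.nn.ParameterList(self.encoder.parameters())}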