Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Base class of the UniFlowMatch training system. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| class UFMFlowFieldOutput: | |
| """ | |
| Output interface of the flow field prediction network. | |
| """ | |
| flow_output: torch.Tensor | |
| flow_covariance: Optional[torch.Tensor] = None | |
| flow_covariance_inv: Optional[torch.Tensor] = None | |
| flow_covariance_log_det: Optional[torch.Tensor] = None | |
| class UFMMaskFieldOutput: | |
| """ | |
| Output interface of the mask prediction network. | |
| """ | |
| mask: torch.Tensor | |
| logits: torch.Tensor | |
| class UFMClassificationRefinementOutput: | |
| """ | |
| Output interface of the classification refinement network. | |
| """ | |
| # the flow output of the regression step, with shape [B, 2, H, W]. | |
| # it is the initial flow output, which is used to get the first local feature maps for the residual. | |
| regression_flow_output: torch.Tensor | |
| # residual is the output of the refinement step, with shape [B, 2, H, W]. | |
| # it is added to the initial flow output to get the final flow output. | |
| residual: torch.Tensor | |
| # log_softmax is | |
| # the logarithm of | |
| # the softmax of | |
| # similarity of the pixel's feature | |
| # to that of its neighborhood of the flow prediction | |
| # in the other image. | |
| # it have shape [B, H, W, P, P], the similarity of pixel at [b, h, w] to its neighborhood [P, P] centered at regression_flow_output[b, h, w] | |
| log_softmax: torch.Tensor | |
| feature_map_0: torch.Tensor | |
| feature_map_1: torch.Tensor | |
| class UFMOutputInterface: | |
| """ | |
| Output interface of the UniFlowMatch training system. | |
| """ | |
| flow: Optional[UFMFlowFieldOutput] = None | |
| # Refinement output (for training and visualization) | |
| classification_refinement: Optional[UFMClassificationRefinementOutput] = None | |
| # auxiliary ouputs | |
| covisibility: Optional[UFMMaskFieldOutput] = None | |
| from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT | |
| from uniflowmatch.utils.flow_resizing import ( | |
| AutomaticShapeSelection, | |
| ResizeToFixedManipulation, | |
| unmap_predicted_channels, | |
| unmap_predicted_flow, | |
| ) | |
| class UniFlowMatchModelsBase(torch.nn.Module): | |
| def __init__(self, inference_resolution: Optional[Union[List[Tuple[int, int]], Tuple[int, int]]] = None): | |
| super().__init__() | |
| if inference_resolution is None: | |
| inference_resolution = [(560, 420)] | |
| if isinstance(inference_resolution[0], int): # Handle the case for single resolution | |
| inference_resolution = [inference_resolution] | |
| self.inference_resolution = inference_resolution | |
| self.image_scaler = AutomaticShapeSelection( | |
| *[ResizeToFixedManipulation((resolution[1], resolution[0])) for resolution in inference_resolution], | |
| strategy="closest_aspect", # will inference on the trained aspect ratio that is closest to the input image 1 | |
| ) | |
| def forward(self, view1, view2) -> UFMOutputInterface: | |
| """ | |
| Forward interface of correspondence prediction networks. | |
| Args: | |
| - view1 (Dict[str, Any]): Input view 1 | |
| - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type | |
| - instance (List[int]): List of instance indices, or id of the input image | |
| - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT | |
| - view2 (Dict[str, Any]): Input view 2 | |
| - (same structure as view1) | |
| Returns: | |
| - Dict[str, Any]: Output results | |
| - flow [Required] (Dict[str, torch.Tensor]): Flow output | |
| - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW | |
| - [Optional] flow_covariance | |
| - [Optional] flow_covariance_inv | |
| - [Optional] flow_covariance_log_det | |
| - occlusion [Optional] (Dict[str, torch.Tensor]): Occlusion output | |
| - [Optional] mask | |
| - [Optional] logits | |
| """ | |
| raise NotImplementedError("Implement this method in derived classes") | |
| def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: | |
| """ | |
| Get parameter groups for optimizer. This methods guides the optimizer | |
| to apply correct learning rate to different parts of the model. | |
| Returns: | |
| - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer | |
| """ | |
| raise NotImplementedError("Implement this method in derived classes") | |
| def predict_correspondences_batched( | |
| self, | |
| source_image: torch.Tensor, | |
| target_image: torch.Tensor, | |
| data_norm_type: Optional[str] = None, | |
| ) -> UFMOutputInterface: | |
| """ | |
| Predict correspondences between source and target images. | |
| This method generates random correspondences for demonstration purposes. | |
| Args: | |
| source_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The source image tensor. | |
| target_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The target image tensor. | |
| Returns: | |
| UFMOutputInterface: | |
| - flow | |
| - flow_output (torch.Tensor): Tensor of shape (B, 2, H, W) representing the flow output in the original image space. | |
| - occlusion | |
| - mask (torch.Tensor): Tensor of shape (B, H, W) representing the covisibility in range [0, 1]. 1 = fully covisible, 0 = fully occluded or out of range. | |
| """ | |
| assert isinstance(source_image, torch.Tensor) and isinstance( | |
| target_image, torch.Tensor | |
| ), "source_image and target_image must be torch.Tensors" | |
| assert source_image.dim() in [3, 4], "source_image must have dimensions 3 or 4" | |
| assert target_image.dim() in [3, 4], "target_image must have dimensions 3 or 4" | |
| batched = source_image.dim() == 4 | |
| if not batched: | |
| # add batch dimension | |
| source_image = source_image.unsqueeze(0) | |
| target_image = target_image.unsqueeze(0) | |
| # check the channel | |
| if source_image.shape[1] == 3 and target_image.shape[1] == 3: | |
| pass # do nothing because the image is in BCHW format | |
| elif source_image.shape[-1] == 3 and target_image.shape[-1] == 3: | |
| # convert to BCHW | |
| source_image = source_image.permute(0, 3, 1, 2) | |
| target_image = target_image.permute(0, 3, 1, 2) | |
| else: | |
| raise ValueError("source_image and target_image must have 3 channels in either BCHW or BHWC format") | |
| required_data_norm_type = self.encoder.data_norm_type | |
| image_device = source_image.device | |
| if source_image.dtype == torch.float32: | |
| assert data_norm_type is not None, "data_norm_type must be provided for float32 images" | |
| assert ( | |
| data_norm_type in IMAGE_NORMALIZATION_DICT | |
| ), f"data_norm_type must be one of {list(IMAGE_NORMALIZATION_DICT.keys())}" | |
| if data_norm_type != required_data_norm_type: | |
| # apply transformation to the correct from the old normalization | |
| prev_mean = ( | |
| IMAGE_NORMALIZATION_DICT[data_norm_type].mean.view(1, 3, 1, 1).to(image_device, non_blocking=True) | |
| ) | |
| prev_std = ( | |
| IMAGE_NORMALIZATION_DICT[data_norm_type].std.view(1, 3, 1, 1).to(image_device, non_blocking=True) | |
| ) | |
| mean = ( | |
| IMAGE_NORMALIZATION_DICT[required_data_norm_type] | |
| .mean.view(1, 3, 1, 1) | |
| .to(image_device, non_blocking=True) | |
| ) | |
| std = ( | |
| IMAGE_NORMALIZATION_DICT[required_data_norm_type] | |
| .std.view(1, 3, 1, 1) | |
| .to(image_device, non_blocking=True) | |
| ) | |
| source_image = source_image * (prev_std / std) + (prev_mean - mean) / std | |
| target_image = target_image * (prev_std / std) + (prev_mean - mean) / std | |
| elif source_image.dtype == torch.uint8: | |
| # convert into float32 and apply normalization | |
| mean = ( | |
| IMAGE_NORMALIZATION_DICT[required_data_norm_type] | |
| .mean.view(1, 3, 1, 1) | |
| .to(image_device, non_blocking=True) | |
| ) | |
| std = ( | |
| IMAGE_NORMALIZATION_DICT[required_data_norm_type] | |
| .std.view(1, 3, 1, 1) | |
| .to(image_device, non_blocking=True) | |
| ) | |
| source_image = (source_image.float() / 255.0 - mean) / std | |
| target_image = (target_image.float() / 255.0 - mean) / std | |
| else: | |
| raise ValueError("source_image and target_image must be of type torch.float32 or torch.uint8") | |
| # Now all the inputs are normalized according to the model's encoder and organized in BCHW format | |
| return self._predict_correspondences_batched(source_image, target_image) | |
| def _predict_correspondences_batched( | |
| self, | |
| source_image: torch.Tensor, | |
| target_image: torch.Tensor, | |
| ) -> UFMOutputInterface: | |
| assert isinstance(source_image, torch.Tensor), "source_image must be a torch.Tensor" | |
| assert isinstance(target_image, torch.Tensor), "target_image must be a torch.Tensor" | |
| assert source_image.dim() == 4, "source_image must be of shape (B, 3, H, W)" | |
| assert target_image.dim() == 4, "target_image must be of shape (B, 3, H, W)" | |
| assert source_image.shape[1] == 3, "source_image must be of shape (B, 3, H, W)" | |
| assert target_image.shape[1] == 3, "target_image must be of shape (B, 3, H, W)" | |
| assert source_image.dtype == torch.float32, "source_image must be of dtype torch.float32" | |
| assert target_image.dtype == torch.float32, "target_image must be of dtype torch.float32" | |
| source_shape_hw = source_image.shape[2:] | |
| target_shape_hw = target_image.shape[2:] | |
| # Scale images to one of the model's trained resolution. | |
| ( | |
| scaled_img0, # The scaled source image | |
| scaled_img1, # The scaled target image | |
| img0_region_source, # Where in the source image is captured in the scaled image | |
| img1_region_source, # Where in the target image is captured in the scaled image | |
| img0_region_representation, # Region in the source image is captured in this region in the scaled image | |
| img1_region_representation, # same as above, but for the target image | |
| ) = self.image_scaler(source_image.permute(0, 2, 3, 1), target_image.permute(0, 2, 3, 1)) | |
| scaled_img0 = scaled_img0.permute(0, 3, 1, 2) | |
| scaled_img1 = scaled_img1.permute(0, 3, 1, 2) | |
| # Run a forward pass | |
| view1 = {"img": scaled_img0, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type} | |
| view2 = {"img": scaled_img1, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type} | |
| with torch.no_grad(): | |
| with torch.autocast("cuda", torch.bfloat16): | |
| result = self(view1, view2) | |
| rescaled_ufm_result = UFMOutputInterface() | |
| # rescale flow | |
| flow_output = result.flow.flow_output | |
| flow_unmapped, flow_unmap_validity = unmap_predicted_flow( | |
| flow_output, | |
| img0_region_representation, | |
| img1_region_representation, | |
| img0_region_source, | |
| img1_region_source, | |
| source_shape_hw, | |
| target_shape_hw, | |
| ) | |
| rescaled_ufm_result.flow = UFMFlowFieldOutput( | |
| flow_output=flow_unmapped, | |
| ) | |
| # rescale covariance if it exists | |
| if result.flow.flow_covariance is not None: | |
| flow_covariance = result.flow.flow_covariance | |
| flow_covariance_unmapped, _ = unmap_predicted_channels( | |
| flow_covariance, | |
| img0_region_representation, | |
| img0_region_source, | |
| source_shape_hw, | |
| ) | |
| # scale covariance in the correct way | |
| w_pred = scaled_img0.shape[3] | |
| h_pred = scaled_img0.shape[2] | |
| w_final = source_shape_hw[1] | |
| h_final = source_shape_hw[0] | |
| w_ratio, h_ratio = w_final / w_pred, h_final / h_pred | |
| flow_covariance_unmapped *= ( | |
| torch.tensor([w_ratio**2, h_ratio**2, w_ratio * h_ratio]) | |
| .view(1, 3, 1, 1) | |
| .to(flow_covariance_unmapped.device) | |
| ) | |
| rescaled_ufm_result.flow.flow_covariance = flow_covariance_unmapped | |
| # rescale occlusion if it exists | |
| if result.covisibility is not None: | |
| occlusion_mask = result.covisibility.mask | |
| covisibility_unmapped, _ = unmap_predicted_channels( | |
| occlusion_mask, | |
| img0_region_representation, | |
| img0_region_source, | |
| source_shape_hw, | |
| ) | |
| covisibility_unmapped = covisibility_unmapped.squeeze(1) | |
| rescaled_ufm_result.covisibility = UFMMaskFieldOutput(mask=covisibility_unmapped, logits=None) | |
| return rescaled_ufm_result | |