import os from typing import Any, Dict, List, Optional, Tuple import torch from huggingface_hub import PyTorchModelHubMixin from torch import nn # Only enable flash attention backend from uniception.models.encoders import ViTEncoderInput, feature_returner_encoder_factory from uniception.models.info_sharing import INFO_SHARING_CLASSES, MultiViewTransformerInput from uniception.models.prediction_heads.adaptors import ( ConfidenceAdaptor, Covariance2DAdaptor, FlowAdaptor, FlowWithConfidenceAdaptor, MaskAdaptor, ) from uniception.models.prediction_heads.base import AdaptorMap, PredictionHeadInput, PredictionHeadLayeredInput from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor from uniception.models.prediction_heads.mlp_feature import MLPFeature from uniception.models.prediction_heads.moge_conv import MoGeConvFeature from uniflowmatch.models.base import ( UFMClassificationRefinementOutput, UFMFlowFieldOutput, UFMMaskFieldOutput, UFMOutputInterface, UniFlowMatchModelsBase, ) from uniflowmatch.models.unet_encoder import UNet from uniflowmatch.models.utils import get_meshgrid_torch CLASSNAME_TO_ADAPTOR_CLASS = { "FlowWithConfidenceAdaptor": FlowWithConfidenceAdaptor, "FlowAdaptor": FlowAdaptor, "MaskAdaptor": MaskAdaptor, "Covariance2DAdaptor": Covariance2DAdaptor, "ConfidenceAdaptor": ConfidenceAdaptor, } # dust3r data structure for reducing passing duplicate images through the encoder def is_symmetrized(gt1, gt2): "Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input" x = gt1["instance"] y = gt2["instance"] if len(x) == len(y) and len(x) == 1: return False # special case of batchsize 1 ok = True for i in range(0, len(x), 2): ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) return ok def interleave(tensor1, tensor2): "Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)" res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) return res1, res2 def modify_state_dict(original_state_dict, mappings): """ Modify state dict keys according to replacement mappings Args: original_state_dict: Loaded checkpoint state dict mappings: Dictionary of {old_key_substr: new_key_substr_or_None} Returns: Modified state dictionary with updated keys """ new_state_dict = {} for k, v in original_state_dict.items(): new_key = None skip = False # Check for all possible replacements for replace_key, replace_value in mappings.items(): if replace_key in k: if replace_value is None: skip = True break # Skip this key entirely else: new_key = k.replace(replace_key, replace_value) break # Only apply first matching replacement if skip: continue new_state_dict[new_key if new_key is not None else k] = v return new_state_dict class UniFlowMatch(UniFlowMatchModelsBase, PyTorchModelHubMixin): """ UniFlowMatch model. """ def __init__( self, # Encoder configurations encoder_str: str, encoder_kwargs: Dict[str, Any], # Info sharing & output head structure configurations info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported # Information sharing configurations info_sharing_str: str = "global_attention", info_sharing_kwargs: Dict[str, Any] = {}, # skip-connections between encoder and info-sharing encoder_skip_connection: Optional[List[int]] = None, info_sharing_skip_connection: Optional[List[int]] = None, # Prediction Heads & Adaptors head_type: str = "dpt", feature_head_kwargs: Dict[str, Any] = {}, adaptors_kwargs: Dict[str, Any] = {}, # Load Pretrained Weights pretrained_checkpoint_path: Optional[str] = None, # Inference Settings inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH *args, **kwargs, ): """ Initialize the UniFlowMarch Model - encoder_str (str): Encoder string - encoder_kwargs (Dict[str, Any]): Encoder configurations - info_sharing_and_head_structure (str): Info sharing and head structure configurations - "dual+single": Dual view info sharing and single view prediction head - info_sharing_str (str): Info sharing method - "global_attention_transformer": Global attention transformer - info_sharing_kwargs (Dict[str, Any]): Info sharing configurations """ UniFlowMatchModelsBase.__init__(self, inference_resolution=inference_resolution, *args, **kwargs) PyTorchModelHubMixin.__init__(self) # assertion on architectures assert info_sharing_and_head_structure == "dual+single", "Only dual+single is supported now" # initialize the skip-connections self.encoder_skip_connection = encoder_skip_connection self.info_sharing_skip_connection = info_sharing_skip_connection # initialize encoder self.encoder: nn.Module = feature_returner_encoder_factory(encoder_str, **encoder_kwargs) # initialize info-sharing module assert head_type != "linear", "Linear head is not supported, because it have major disadvantage to DPTs" self.head_type = head_type self.info_sharing: nn.Module = INFO_SHARING_CLASSES[info_sharing_str][1](**info_sharing_kwargs) self.head1: nn.Module = self._initialize_prediction_heads(head_type, feature_head_kwargs, adaptors_kwargs) # load pretrained weights if pretrained_checkpoint_path is not None: ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu") if "state_dict" in ckpt: # we are loading from training checkpoint directly. model_state_dict = ckpt["state_dict"] model_state_dict = { k[6:]: v for k, v in model_state_dict.items() if k.startswith("model.") } # remove "model." prefix model_state_dict = modify_state_dict( model_state_dict, {"feature_matching_proj": None, "encoder.model.mask_token": None} ) self.load_state_dict(model_state_dict, strict=True) else: model_state_dict = ckpt["model"] load_result = self.load_state_dict(model_state_dict, strict=False) assert len(load_result.missing_keys) == 0, f"Missing keys: {load_result.missing_keys}" @classmethod def from_pretrained_ckpt(cls, pretrained_model_name_or_path, strict=True, **kw): if os.path.isfile(pretrained_model_name_or_path): ckpt = torch.load(pretrained_model_name_or_path, map_location="cpu") # remove base_pretrained_checkpoint_path from the model args if "base_pretrained_checkpoint_path" in ckpt["model_args"]: ckpt["model_args"].pop("base_pretrained_checkpoint_path") # convert old model args into new definition if "img_size" in ckpt["model_args"]: # we are loading from a old benchmark checkpoint print("Converting from a old benchmark checkpoint") model_args = { # Encoder args "encoder_str": ckpt["model_args"]["encoder_str"], "encoder_kwargs": ckpt["model_args"]["encoder_kwargs"], # Info-sharing args "info_sharing_and_head_structure": "dual+single", "info_sharing_str": ckpt["model_args"]["info_sharing_type"], "info_sharing_kwargs": { "name": "info_sharing", "input_embed_dim": ckpt["model_args"]["input_embed_dim"], "num_views": 2, "use_rand_idx_pe_for_non_reference_views": False, "depth": ckpt["model_args"]["num_layers"], "dim": ckpt["model_args"]["transformer_dim"], "num_heads": ckpt["model_args"]["num_heads"], "mlp_ratio": ckpt["model_args"]["mlp_ratio"], "qkv_bias": ckpt["model_args"]["qkv_bias"], "qk_norm": ckpt["model_args"]["qk_norm"], "custom_positional_encoding": ckpt["model_args"]["position_encoding"], "norm_intermediate": ckpt["model_args"]["normalize_intermediate"], "indices": ckpt["model_args"]["returned_intermediate_layers"], }, # flow head args "head_type": "dpt", "feature_head_kwargs": ckpt["model_args"]["feature_head_kwargs"], "adaptors_kwargs": ckpt["model_args"]["adaptors_kwargs"], } if "covocc_feature_head_kwargs" in ckpt["model_args"]: # if the model has a covocc head, we need to convert it to the new format model_args["uncertainty_head_type"] = "dpt" model_args["uncertainty_head_kwargs"] = { "dpt_feature": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_feature"], "dpt_processor": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_regr_processor"], } model_args["uncertainty_adaptors_kwargs"] = { "flow_cov": ckpt["model_args"]["covocc_adaptors_kwargs"]["flow_cov"] } ckpt["model_args"] = model_args # Update the old weights into the current format ckpt["model"] = modify_state_dict( ckpt["model"], { "covocc_head.dpt_feature": "uncertainty_head.0.0", "covocc_head.dpt_regr_processor": "uncertainty_head.0.1", "covocc_head.dpt_segm_processor": None, "feature_matching_proj": None, "encoder.model.mask_token": None, }, ) # remove the ket "pretrained_backbone_checkpoint_path" from the model args if "pretrained_backbone_checkpoint_path" in ckpt["model_args"]: ckpt["model_args"].pop("pretrained_backbone_checkpoint_path") model = cls(**ckpt["model_args"]) model.load_state_dict(ckpt["model"], strict=strict) return model else: raise ValueError(f"Pretrained model {pretrained_model_name_or_path} not found.") def _initialize_prediction_heads( self, head_type: str, feature_head_kwargs: Dict[str, Any], adaptors_kwargs: Dict[str, Any] ): """ Initialize prediction heads and adaptors Args: - head_type (str): Head type, either "dpt" or "linear" - feature_head_kwargs (Dict[str, Any]): Feature head configurations - adaptors_kwargs (Dict[str, Any]): Adaptors configurations Returns: - nn.Module: output head + adaptors """ feature_processor: nn.Module if head_type == "dpt": feature_processor = nn.Sequential( DPTFeature(**feature_head_kwargs["dpt_feature"]), DPTRegressionProcessor(**feature_head_kwargs["dpt_processor"]), ) elif head_type == "moge_conv": feature_processor = MoGeConvFeature(**feature_head_kwargs) else: raise ValueError(f"Head type {head_type} not supported.") adaptors = self._initialize_adaptors(adaptors_kwargs) return nn.Sequential(feature_processor, AdaptorMap(*adaptors.values())) def _initialize_adaptors(self, adaptors_kwargs: Dict[str, Any]): """ Initialize a dict of adaptors Args: - adaptors_kwargs (Dict[str, Any]): Adaptors configurations Returns: - Dict[str, nn.Module]: dict of adaptors, from adaptor's name to the adaptor """ return { name: CLASSNAME_TO_ADAPTOR_CLASS[configs["class"]](**configs["kwargs"]) for name, configs in adaptors_kwargs.items() } def _encode_image_pairs(self, img1, img2, data_norm_type): "Encode two different batches of images (each batch can have different image shape)" if img1.shape[-2:] == img2.shape[-2:]: encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type) encoder_output = self.encoder(encoder_input) out_list, out2_list = [], [] for encoder_output_ in encoder_output: out, out2 = encoder_output_.features.chunk(2, dim=0) out_list.append(out) out2_list.append(out2) else: raise NotImplementedError("Unequal Image sizes are not supported now") return out_list, out2_list def _encode_symmetrized(self, view1, view2, symmetrized=False): "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input" img1 = view1["img"] img2 = view2["img"] feat1_list, feat2_list = [], [] if symmetrized: # Computing half of forward pass' # modified in conjunction with UFM for not copying the images again. # used to be: feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"]) # be very carefult with this!!! feat1_list_, feat2_list_ = self._encode_image_pairs( img1[::2], img2[::2], data_norm_type=view1["data_norm_type"] ) for feat1, feat2 in zip(feat1_list_, feat2_list_): feat1, feat2 = interleave(feat1, feat2) feat1_list.append(feat1) feat2_list.append(feat2) else: feat1_list, feat2_list = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"]) return feat1_list, feat2_list def forward(self, view1, view2) -> UFMOutputInterface: """ Forward interface of correspondence prediction networks. Args: - view1 (Dict[str, Any]): Input view 1 - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type - instance (List[int]): List of instance indices, or id of the input image - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT - view2 (Dict[str, Any]): Input view 2 - (same structure as view1) Returns: - Dict[str, Any]: Output results - flow [Required] (Dict[str, torch.Tensor]): Flow output - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW - [Optional] flow_covariance - [Optional] flow_covariance_inv - [Optional] flow_covariance_log_det - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output - [Optional] mask - [Optional] logits """ # Get input shapes _, _, height1, width1 = view1["img"].shape _, _, height2, width2 = view2["img"].shape shape1 = (int(height1), int(width1)) shape2 = (int(height2), int(width2)) # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) # Pass the features through the info_sharing info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( info_sharing_input ) info_sharing_outputs = { "1": [ feat1_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), final_info_sharing_multi_view_feat.features[0].float().contiguous(), ], "2": [ feat2_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), final_info_sharing_multi_view_feat.features[1].float().contiguous(), ], } result = UFMOutputInterface() # The prediction need precision, so we disable any autocasting here with torch.autocast("cuda", torch.float32): # run the collected info_sharing features through the prediction heads head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) if "flow" in head_output1: # output is flow only result.flow = UFMFlowFieldOutput(flow_output=head_output1["flow"].value) if "flow_cov" in head_output1: result.flow.flow_covariance = head_output1["flow_cov"].covariance result.flow.flow_covariance_inv = head_output1["flow_cov"].inv_covariance result.flow.flow_covariance_log_det = head_output1["flow_cov"].log_det if "non_occluded_mask" in head_output1: result.covisibility = UFMMaskFieldOutput( mask=head_output1["non_occluded_mask"].mask, logits=head_output1["non_occluded_mask"].logits, ) return result def _downstream_head(self, head_num, decout, img_shape): "Run the respective prediction heads" # if self.info_sharing_and_head_structure == "dual+single": head = getattr(self, f"head{head_num}") if self.head_type == "linear": head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) elif self.head_type in ["dpt", "moge_conv"]: head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) return head(head_input) def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: """ Get parameter groups for optimizer. This methods guides the optimizer to apply correct learning rate to different parts of the model. Returns: - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer """ return { "encoder": torch.nn.ParameterList(self.encoder.parameters()), "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), "output_head": torch.nn.ParameterList(self.head1.parameters()), } class UniFlowMatchConfidence(UniFlowMatch, PyTorchModelHubMixin): """ UniFlowMatch model with uncertainty estimation. """ def __init__( self, # Encoder configurations encoder_str: str, encoder_kwargs: Dict[str, Any], # Info sharing & output head structure configurations info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported # Information sharing configurations info_sharing_str: str = "global_attention", info_sharing_kwargs: Dict[str, Any] = {}, # Prediction Heads & Adaptors head_type: str = "dpt", feature_head_kwargs: Dict[str, Any] = {}, adaptors_kwargs: Dict[str, Any] = {}, # Uncertainty Heads & Adaptors detach_uncertainty_head: bool = True, uncertainty_head_type: str = "dpt", uncertainty_head_kwargs: Dict[str, Any] = {}, uncertainty_adaptors_kwargs: Dict[str, Any] = {}, # Load Pretrained Weights pretrained_backbone_checkpoint_path: Optional[str] = None, pretrained_checkpoint_path: Optional[str] = None, # Inference Settings inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH *args, **kwargs, ): UniFlowMatch.__init__( self, encoder_str=encoder_str, encoder_kwargs=encoder_kwargs, info_sharing_and_head_structure=info_sharing_and_head_structure, info_sharing_str=info_sharing_str, info_sharing_kwargs=info_sharing_kwargs, head_type=head_type, feature_head_kwargs=feature_head_kwargs, adaptors_kwargs=adaptors_kwargs, pretrained_checkpoint_path=pretrained_backbone_checkpoint_path, inference_resolution=inference_resolution, *args, **kwargs, ) PyTorchModelHubMixin.__init__(self) # initialize uncertainty heads assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now" self.uncertainty_head = self._initialize_prediction_heads( uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs ) self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs) assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" self.detach_uncertainty_head = detach_uncertainty_head def forward(self, view1, view2) -> UFMOutputInterface: """ Forward interface of correspondence prediction networks. Args: - view1 (Dict[str, Any]): Input view 1 - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type - instance (List[int]): List of instance indices, or id of the input image - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT - view2 (Dict[str, Any]): Input view 2 - (same structure as view1) Returns: - Dict[str, Any]: Output results - flow [Required] (Dict[str, torch.Tensor]): Flow output - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW - [Optional] flow_covariance - [Optional] flow_covariance_inv - [Optional] flow_covariance_log_det - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibiltiy output - [Optional] mask - [Optional] logits """ # Get input shapes _, _, height1, width1 = view1["img"].shape _, _, height2, width2 = view2["img"].shape shape1 = (int(height1), int(width1)) shape2 = (int(height2), int(width2)) # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) # Pass the features through the info_sharing info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( info_sharing_input ) info_sharing_outputs = { "1": [ feat1_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), final_info_sharing_multi_view_feat.features[0].float().contiguous(), ], "2": [ feat2_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), final_info_sharing_multi_view_feat.features[1].float().contiguous(), ], } info_sharing_outputs_detached = { "1": [ feat1_list[-1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(), final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(), ], "2": [ feat2_list[-1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(), final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(), ], } result = UFMOutputInterface() # The prediction need precision, so we disable any autocasting here with torch.autocast("cuda", torch.float32): # run the collected info_sharing features through the prediction heads head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) head_output_uncertainty = self._downstream_head( "uncertainty", info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs, shape1, ) result.flow = UFMFlowFieldOutput( flow_output=head_output1["flow"].value, ) if "flow_cov" in head_output_uncertainty: result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det if "keypoint_confidence" in head_output_uncertainty: result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1) if "non_occluded_mask" in head_output_uncertainty: result.covisibility = UFMMaskFieldOutput( mask=head_output_uncertainty["non_occluded_mask"].mask, logits=head_output_uncertainty["non_occluded_mask"].logits, ) return result def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: """ Get parameter groups for optimizer. This methods guides the optimizer to apply correct learning rate to different parts of the model. Returns: - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer """ return { "encoder": torch.nn.ParameterList(self.encoder.parameters()), "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), "output_head": torch.nn.ParameterList(self.head1.parameters()), "uncertainty_head": torch.nn.ParameterList(self.uncertainty_head.parameters()), } def _downstream_head(self, head_num, decout, img_shape): "Run the respective prediction heads" # if self.info_sharing_and_head_structure == "dual+single": head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1 if self.head_type == "linear": head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) elif self.head_type in ["dpt", "moge_conv"]: head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) return head(head_input) class UniFlowMatchClassificationRefinement(UniFlowMatch, PyTorchModelHubMixin): """ The variant of UniFlowMatch with local classification for refinement. """ def __init__( self, # Encoder configurations encoder_str: str, encoder_kwargs: Dict[str, Any], # Info sharing & output head structure configurations info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported # Information sharing configurations info_sharing_str: str = "global_attention", info_sharing_kwargs: Dict[str, Any] = {}, # Prediction Heads & Adaptors head_type: str = "dpt", feature_head_kwargs: Dict[str, Any] = {}, adaptors_kwargs: Dict[str, Any] = {}, # Uncertainty Heads & Adaptors detach_uncertainty_head: bool = True, uncertainty_head_type: str = "dpt", uncertainty_head_kwargs: Dict[str, Any] = {}, uncertainty_adaptors_kwargs: Dict[str, Any] = {}, # Classification Heads & Adaptors temperature: float = 4, use_unet_feature: bool = False, classification_head_type: str = "patch_mlp", classification_head_kwargs: Dict[str, Any] = {}, feature_combine_method: str = "conv", # Refinement Range refinement_range: int = 5, # Load Pretrained Weights pretrained_backbone_checkpoint_path: Optional[str] = None, pretrained_checkpoint_path: Optional[str] = None, # Inference Settings inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH *args, **kwargs, ): UniFlowMatch.__init__( self, encoder_str=encoder_str, encoder_kwargs=encoder_kwargs, info_sharing_and_head_structure=info_sharing_and_head_structure, info_sharing_str=info_sharing_str, info_sharing_kwargs=info_sharing_kwargs, head_type=head_type, feature_head_kwargs=feature_head_kwargs, adaptors_kwargs=adaptors_kwargs, pretrained_checkpoint_path=pretrained_backbone_checkpoint_path, inference_resolution=inference_resolution, *args, **kwargs, ) PyTorchModelHubMixin.__init__(self) # initialize uncertainty heads assert classification_head_type == "patch_mlp", "Only DPT is supported for uncertainty head now" self.classification_head_type = classification_head_type self.classification_head = self._initialize_classification_head(classification_head_kwargs) self.refinement_range = refinement_range self.temperature = temperature assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" self.use_unet_feature = use_unet_feature self.feature_combine_method = feature_combine_method # Unet experiment if self.use_unet_feature: self.unet_feature = UNet(in_channels=3, out_channels=16, features=[64, 128, 256, 512]) self.conv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0) if self.feature_combine_method == "conv": self.conv2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0) elif self.feature_combine_method == "modulate": self.conv2 = nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0) default_attention_bias = torch.zeros(self.refinement_range * self.refinement_range) self.classification_bias = nn.Parameter(default_attention_bias) # initialize uncertainty heads if len(uncertainty_head_kwargs) > 0: assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now" self.uncertainty_head = self._initialize_prediction_heads( uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs ) self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs) assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" self.detach_uncertainty_head = detach_uncertainty_head def forward(self, view1, view2) -> UFMOutputInterface: """ Forward interface of correspondence prediction networks. Args: - view1 (Dict[str, Any]): Input view 1 - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type - instance (List[int]): List of instance indices, or id of the input image - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT - view2 (Dict[str, Any]): Input view 2 - (same structure as view1) Returns: - Dict[str, Any]: Output results - flow [Required] (Dict[str, torch.Tensor]): Flow output - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW - [Optional] flow_covariance - [Optional] flow_covariance_inv - [Optional] flow_covariance_log_det - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output - [Optional] mask - [Optional] logits - classification [Optional]: Probability and targets of the classification head """ # Get input shapes _, _, height1, width1 = view1["img"].shape _, _, height2, width2 = view2["img"].shape shape1 = (int(height1), int(width1)) shape2 = (int(height2), int(width2)) # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) # Pass the features through the info_sharing info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( info_sharing_input ) info_sharing_outputs = { "1": [ feat1_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), final_info_sharing_multi_view_feat.features[0].float().contiguous(), ], "2": [ feat2_list[-1].float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), final_info_sharing_multi_view_feat.features[1].float().contiguous(), ], } info_sharing_outputs_detached = { "1": [ feat1_list[-1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(), final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(), ], "2": [ feat2_list[-1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(), intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(), final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(), ], } # optionally inference for U-Net Features if self.use_unet_feature: unet_feat1 = self.unet_feature(view1["img"]) unet_feat2 = self.unet_feature(view2["img"]) result = UFMOutputInterface() # The prediction need precision, so we disable any autocasting here with torch.autocast("cuda", torch.float32): # run the collected info_sharing features through the prediction heads head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) flow_prediction = head_output1["flow"].value if hasattr(self, "uncertainty_head"): # run the uncertainty head head_output_uncertainty = self._downstream_head( "uncertainty", info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs, shape1, ) if "flow_cov" in head_output_uncertainty: result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det if "keypoint_confidence" in head_output_uncertainty: result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1) if "non_occluded_mask" in head_output_uncertainty: result.covisibility = UFMMaskFieldOutput( mask=head_output_uncertainty["non_occluded_mask"].mask, logits=head_output_uncertainty["non_occluded_mask"].logits, ) # we run the classification head in the autocast environment bacause it is not regression if self.classification_head_type == "patch_mlp": # concatenate the last encoder feature with final info_sharing feature # use the first encoder feature, because it captures more low-level information, which is needed # for refinement of the regressed flow. classification_feat_1 = torch.cat( [feat1_list[0].float().contiguous(), info_sharing_outputs["1"][-1]], dim=1 ) classification_feat_2 = torch.cat( [feat2_list[0].float().contiguous(), info_sharing_outputs["2"][-1]], dim=1 ) classification_input = PredictionHeadInput( torch.cat([classification_feat_1, classification_feat_2], dim=0) ) classification_features = self.classification_head(classification_input).decoded_channels if self.use_unet_feature: if self.feature_combine_method == "conv": combined_features = torch.cat( [classification_features, torch.cat([unet_feat1, unet_feat2], dim=0)], dim=1 ) combined_features = self.conv1(combined_features) combined_features = nn.functional.relu(combined_features) combined_features = self.conv2(combined_features) elif self.feature_combine_method == "modulate": combined_features = classification_features * torch.tanh( torch.cat([unet_feat1, unet_feat2], dim=0) ) combined_features = self.conv2(combined_features) classification_features = combined_features classification_features0, classification_features1 = classification_features.chunk(2, dim=0) # refine the flow prediction with features from the classification head for i in range(1): residual, log_softmax_attention = self.classification_refinement( flow_prediction, classification_features ) flow_prediction = flow_prediction + residual # Fill in the result # WARNING: based on how the residual is computed, flow_prediction will have gradient cancelled by mathematics, # so there will be no supervision to the flow prediction at all. We need to use specialized loss function to # supervise the regression_flow_output. result.flow = UFMFlowFieldOutput( flow_output=flow_prediction, ) result.classification_refinement = UFMClassificationRefinementOutput( regression_flow_output=flow_prediction, residual=residual, log_softmax=log_softmax_attention, feature_map_0=classification_features0, feature_map_1=classification_features1, ) return result # @torch.compile() def classification_refinement(self, flow_prediction, classification_features) -> Dict[str, Any]: """ Use correlation between self feature and features around a local patch of the initial flow prediction to refine the flow prediction. """ classification_features1, classification_features2 = classification_features.chunk(2, dim=0) neighborhood_features, neighborhood_flow_residual = self.obtain_neighborhood_features( flow_estimation=flow_prediction, other_features=classification_features2, local_patch=self.refinement_range ) residual, log_softmax_attention = self.compute_refinement_attention( classification_features1, neighborhood_features, neighborhood_flow_residual ) return residual, log_softmax_attention def compute_refinement_attention(self, classification_features1, neighborhood_features, neighborhood_flow_residual): """ Compute the attention for the refinement, with special processing to fit """ B, C, H, W = classification_features1.shape P = self.refinement_range # reshape Q to B, H, W, 1, 1, C classification_features1 = classification_features1.permute(0, 2, 3, 1).reshape(B * H * W, 1, C) # reshape K to B, H, W, 1, P^2, C assert neighborhood_features.shape[0] == B assert neighborhood_features.shape[1] == H assert neighborhood_features.shape[2] == W assert neighborhood_features.shape[3] == P assert neighborhood_features.shape[4] == P assert neighborhood_features.shape[5] == C neighborhood_features = neighborhood_features.reshape(B * H * W, P * P, C) # reshape V to B, H, W, 1, P^2, 2 neighborhood_flow_residual = neighborhood_flow_residual.reshape(-1, P * P, 2) # compute the attention attention_score = ( torch.matmul(classification_features1, neighborhood_features.permute(0, 2, 1)) / self.temperature ) attention_score = attention_score + self.classification_bias attention = torch.nn.functional.softmax(attention_score, dim=-1) log_softmax_attention = torch.nn.functional.log_softmax(attention_score, dim=-1) # compute the weighted sum residual = torch.matmul(attention, neighborhood_flow_residual) # reshape the residual to B, H, W, 2, then B, 2, H, W residual = residual.reshape(B, H, W, 2).permute(0, 3, 1, 2) return residual, log_softmax_attention.reshape(B, H, W, P, P) def _downstream_head(self, head_num, decout, img_shape): "Run the respective prediction heads" # if self.info_sharing_and_head_structure == "dual+single": head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1 if self.head_type == "linear": head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) elif self.head_type in ["dpt", "moge_conv"]: head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) return head(head_input) def obtain_neighborhood_features( self, flow_estimation: torch.Tensor, other_features: torch.Tensor, local_patch: int = 5 ) -> Tuple[torch.Tensor, torch.Tensor]: """ Query the other features according to flow estimation. """ assert local_patch % 2 == 1, "local_patch should be odd number" P = local_patch R = (P - 1) // 2 B, C, H, W = other_features.shape device = other_features.device # expected_output = torch.zeros(B, H, W, P, P, C, device=other_features.device, dtype=torch.float32) neighborhood_grid_ij: torch.Tensor i_local, j_local = torch.meshgrid( torch.arange(-R, R + 1, device=device), torch.arange(-R, R + 1, device=device), indexing="ij" ) ij_local = torch.stack((i_local, j_local), dim=0) # 2, P, P tensor # compute the indices of the fetch base_grid_xy = get_meshgrid_torch(W=W, H=H, device=device).permute(2, 0, 1).reshape(1, 2, H, W) target_coordinate_xy_float = flow_estimation + base_grid_xy target_coordinate_xy = target_coordinate_xy_float.view(B, 2, H, W, 1, 1) target_coordinate_ij = target_coordinate_xy[:, [1, 0], ...] # compute the neighborhood grid neighborhood_grid_ij = target_coordinate_ij + ij_local.view(1, 2, 1, 1, P, P) grid_for_sample = neighborhood_grid_ij[:, [1, 0], ...].permute(0, 2, 3, 4, 5, 1).reshape(B, H, W * P * P, 2) grid_for_sample = (grid_for_sample + 0.5) / torch.tensor([W, H], device=device).view(1, 1, 1, 2) grid_for_sample = grid_for_sample * 2 - 1 expected_output = torch.nn.functional.grid_sample( other_features, grid=grid_for_sample, mode="bicubic", padding_mode="zeros", align_corners=False ).view(B, C, H, W, P, P) # transform BCHWPP to BHWPPC expected_output = expected_output.permute(0, 2, 3, 4, 5, 1) neighborhood_grid_xy_residual = ij_local[[1, 0], ...].view(1, 2, 1, 1, P, P).to(device).float() neighborhood_grid_xy_residual = neighborhood_grid_xy_residual.permute(0, 2, 3, 4, 5, 1).float() return expected_output, neighborhood_grid_xy_residual def _initialize_classification_head(self, classification_head_kwargs: Dict[str, Any]): """ Initialize classification head Args: - classification_head_kwargs (Dict[str, Any]): Classification head configurations Returns: - nn.Module: Classification head """ if self.classification_head_type == "patch_mlp": return MLPFeature(**classification_head_kwargs) else: raise ValueError(f"Classification head type {self.classification_head_type} not supported.") def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: """ Get parameter groups for optimizer. This methods guides the optimizer to apply correct learning rate to different parts of the model. Returns: - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer """ if self.use_unet_feature: params_dict = { "encoder": torch.nn.ParameterList(self.encoder.parameters()), "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), "output_head": torch.nn.ParameterList(self.head1.parameters()), "classification_head": torch.nn.ParameterList(self.classification_head.parameters()), "unet_feature": torch.nn.ParameterList( list(self.unet_feature.parameters()) + list(self.conv1.parameters()) + list(self.conv2.parameters()) + [self.classification_bias] ), } else: params_dict = { "encoder": torch.nn.ParameterList(self.encoder.parameters()), "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), "output_head": torch.nn.ParameterList(self.head1.parameters()), "classification_head": torch.nn.ParameterList(self.classification_head.parameters()), } if hasattr(self, "uncertainty_head"): params_dict["uncertainty_head"] = torch.nn.ParameterList(self.uncertainty_head.parameters()) return params_dict if __name__ == "__main__": import cv2 import flow_vis import matplotlib.pyplot as plt import numpy as np import torch from uniflowmatch.utils.geometry import get_meshgrid_torch from uniflowmatch.utils.viz import warp_image_with_flow USE_REFINEMENT_MODEL = False if USE_REFINEMENT_MODEL: model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine") else: model = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base") # === Load and Prepare Images === source_path = "examples/image_pairs/fire_academy_0.png" target_path = "examples/image_pairs/fire_academy_1.png" source_image = cv2.imread(source_path) target_image = cv2.imread(target_path) source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB) target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB) # === Predict Correspondences === result = model.predict_correspondences_batched( source_image=torch.from_numpy(source_image), target_image=torch.from_numpy(target_image), ) flow_output = result.flow.flow_output[0].cpu().numpy() covisibility = result.covisibility.mask[0].cpu().numpy() # === Visualize Results === fig, axs = plt.subplots(2, 3, figsize=(15, 5)) axs[0, 0].imshow(source_image) axs[0, 0].set_title("Source Image") axs[0, 1].imshow(target_image) axs[0, 1].set_title("Target Image") # Warp the image using flow warped_image = warp_image_with_flow(source_image, None, target_image, flow_output.transpose(1, 2, 0)) warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like( warped_image ) warped_image /= 255.0 axs[0, 2].imshow(warped_image) axs[0, 2].set_title("Warped Image") # Flow visualization flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0)) axs[1, 0].imshow(flow_vis_image) axs[1, 0].set_title("Flow Output (Valid at covisible region)") # Covisibility mask axs[1, 1].imshow(covisibility > 0.5, cmap="gray", vmin=0, vmax=1) axs[1, 1].set_title("Covisibility Mask (Thresholded by 0.5)") heatmap = axs[1, 2].imshow(covisibility, cmap="gray", vmin=0, vmax=1) axs[1, 2].set_title("Covisibility Mask") plt.colorbar(heatmap, ax=axs[1, 2]) plt.tight_layout() plt.savefig("ufm_output.png") plt.show() print("Saved ufm_output.png")