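"""
UniFlowMatch (UFM) model definitions: the base UniFlowMatch dense-matching model,
a variant with uncertainty estimation (UniFlowMatchConfidence), and a variant with
local-classification refinement (UniFlowMatchClassificationRefinement).
"""
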
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

# Only enable flash attention backend
from uniception.models.encoders import ViTEncoderInput, feature_returner_encoder_factory
from uniception.models.info_sharing import INFO_SHARING_CLASSES, MultiViewTransformerInput
from uniception.models.prediction_heads.adaptors import (
    ConfidenceAdaptor,
    Covariance2DAdaptor,
    FlowAdaptor,
    FlowWithConfidenceAdaptor,
    MaskAdaptor,
)
from uniception.models.prediction_heads.base import AdaptorMap, PredictionHeadInput, PredictionHeadLayeredInput
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
from uniception.models.prediction_heads.mlp_feature import MLPFeature
from uniception.models.prediction_heads.moge_conv import MoGeConvFeature

from uniflowmatch.models.base import (
    UFMClassificationRefinementOutput,
    UFMFlowFieldOutput,
    UFMMaskFieldOutput,
    UFMOutputInterface,
    UniFlowMatchModelsBase,
)
from uniflowmatch.models.unet_encoder import UNet
from uniflowmatch.models.utils import get_meshgrid_torch

CLASSNAME_TO_ADAPTOR_CLASS = {
    "FlowWithConfidenceAdaptor": FlowWithConfidenceAdaptor,
    "FlowAdaptor": FlowAdaptor,
    "MaskAdaptor": MaskAdaptor,
    "Covariance2DAdaptor": Covariance2DAdaptor,
    "ConfidenceAdaptor": ConfidenceAdaptor,
}


# DUSt3R-style helpers for avoiding redundant encoder passes on symmetrized pairs
def is_symmetrized(gt1, gt2):
    "Check whether input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
    x = gt1["instance"]
    y = gt2["instance"]

    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batch size 1

    ok = True
    for i in range(0, len(x), 2):
        ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])

    return ok


def interleave(tensor1, tensor2):
    "Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
    return res1, res2
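

# Illustrative note (the batch layout is an assumption inferred from the code above):
# for a symmetrized batch where img1 = [a, b, ...] and img2 = [b, a, ...], encoding
# only img1[::2] and img2[::2] halves the encoder work; interleave() then restores
# the full batch order:
#   f1, f2 = interleave(enc_a, enc_b)  # f1 = [enc_a, enc_b, ...], f2 = [enc_b, enc_a, ...]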


def modify_state_dict(original_state_dict, mappings):
    """
    Modify state dict keys according to replacement mappings.

    Args:
        original_state_dict: Loaded checkpoint state dict
        mappings: Dictionary of {old_key_substr: new_key_substr_or_None};
            a value of None drops matching keys entirely.

    Returns:
        Modified state dictionary with updated keys
    """
    new_state_dict = {}
    for k, v in original_state_dict.items():
        new_key = None
        skip = False

        # Check for all possible replacements
        for replace_key, replace_value in mappings.items():
            if replace_key in k:
                if replace_value is None:
                    skip = True
                    break  # Skip this key entirely
                else:
                    new_key = k.replace(replace_key, replace_value)
                    break  # Only apply first matching replacement

        if skip:
            continue

        new_state_dict[new_key if new_key is not None else k] = v

    return new_state_dict
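

# Example usage (illustrative, mirroring the mappings used later in this file):
# rename a head prefix and drop projection weights from an old checkpoint.
#   sd = modify_state_dict(sd, {
#       "covocc_head.dpt_feature": "uncertainty_head.0.0",  # rename
#       "feature_matching_proj": None,                      # drop
#   })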


class UniFlowMatch(UniFlowMatchModelsBase, PyTorchModelHubMixin):
    """
    UniFlowMatch model.
    """

    def __init__(
        self,
        # Encoder configurations
        encoder_str: str,
        encoder_kwargs: Dict[str, Any],
        # Info sharing & output head structure configurations
        info_sharing_and_head_structure: str = "dual+single",  # only dual+single is supported
        # Information sharing configurations
        info_sharing_str: str = "global_attention",
        info_sharing_kwargs: Dict[str, Any] = {},
        # Skip-connections between encoder and info-sharing
        encoder_skip_connection: Optional[List[int]] = None,
        info_sharing_skip_connection: Optional[List[int]] = None,
        # Prediction heads & adaptors
        head_type: str = "dpt",
        feature_head_kwargs: Dict[str, Any] = {},
        adaptors_kwargs: Dict[str, Any] = {},
        # Load pretrained weights
        pretrained_checkpoint_path: Optional[str] = None,
        # Inference settings
        inference_resolution: Optional[Tuple[int, int]] = (560, 420),  # WH
        *args,
        **kwargs,
    ):
        """
        Initialize the UniFlowMatch model.

        Args:
        - encoder_str (str): Encoder string
        - encoder_kwargs (Dict[str, Any]): Encoder configurations
        - info_sharing_and_head_structure (str): Info sharing and head structure configuration
            - "dual+single": Dual-view info sharing and single-view prediction head
        - info_sharing_str (str): Info sharing method
            - "global_attention_transformer": Global attention transformer
        - info_sharing_kwargs (Dict[str, Any]): Info sharing configurations
        - encoder_skip_connection / info_sharing_skip_connection (Optional[List[int]]): Skip-connection indices
        - head_type (str): Prediction head type, "dpt" or "moge_conv"
        - feature_head_kwargs (Dict[str, Any]): Feature head configurations
        - adaptors_kwargs (Dict[str, Any]): Adaptor configurations
        - pretrained_checkpoint_path (Optional[str]): Path to pretrained weights
        - inference_resolution (Optional[Tuple[int, int]]): Inference resolution (width, height)
        """
        UniFlowMatchModelsBase.__init__(self, inference_resolution=inference_resolution, *args, **kwargs)
        PyTorchModelHubMixin.__init__(self)

        # Assert the supported architecture
        assert info_sharing_and_head_structure == "dual+single", "Only dual+single is supported now"

        # Initialize the skip-connections
        self.encoder_skip_connection = encoder_skip_connection
        self.info_sharing_skip_connection = info_sharing_skip_connection

        # Initialize encoder
        self.encoder: nn.Module = feature_returner_encoder_factory(encoder_str, **encoder_kwargs)

        # Initialize info-sharing module and prediction head
        assert head_type != "linear", "Linear heads are not supported; they are at a major disadvantage compared to DPT heads"
        self.head_type = head_type

        self.info_sharing: nn.Module = INFO_SHARING_CLASSES[info_sharing_str][1](**info_sharing_kwargs)
        self.head1: nn.Module = self._initialize_prediction_heads(head_type, feature_head_kwargs, adaptors_kwargs)

        # Load pretrained weights
        if pretrained_checkpoint_path is not None:
            ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu")

            if "state_dict" in ckpt:
                # We are loading from a training checkpoint directly.
                model_state_dict = ckpt["state_dict"]
                model_state_dict = {
                    k[6:]: v for k, v in model_state_dict.items() if k.startswith("model.")
                }  # remove "model." prefix
                model_state_dict = modify_state_dict(
                    model_state_dict, {"feature_matching_proj": None, "encoder.model.mask_token": None}
                )

                self.load_state_dict(model_state_dict, strict=True)
            else:
                model_state_dict = ckpt["model"]
                load_result = self.load_state_dict(model_state_dict, strict=False)
                assert len(load_result.missing_keys) == 0, f"Missing keys: {load_result.missing_keys}"

    @classmethod
    def from_pretrained_ckpt(cls, pretrained_model_name_or_path, strict=True, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            ckpt = torch.load(pretrained_model_name_or_path, map_location="cpu")

            # Remove base_pretrained_checkpoint_path from the model args
            if "base_pretrained_checkpoint_path" in ckpt["model_args"]:
                ckpt["model_args"].pop("base_pretrained_checkpoint_path")

            # Convert old model args into the new definition
            if "img_size" in ckpt["model_args"]:
                # We are loading from an old benchmark checkpoint
                print("Converting from an old benchmark checkpoint")

                model_args = {
                    # Encoder args
                    "encoder_str": ckpt["model_args"]["encoder_str"],
                    "encoder_kwargs": ckpt["model_args"]["encoder_kwargs"],
                    # Info-sharing args
                    "info_sharing_and_head_structure": "dual+single",
                    "info_sharing_str": ckpt["model_args"]["info_sharing_type"],
                    "info_sharing_kwargs": {
                        "name": "info_sharing",
                        "input_embed_dim": ckpt["model_args"]["input_embed_dim"],
                        "num_views": 2,
                        "use_rand_idx_pe_for_non_reference_views": False,
                        "depth": ckpt["model_args"]["num_layers"],
                        "dim": ckpt["model_args"]["transformer_dim"],
                        "num_heads": ckpt["model_args"]["num_heads"],
                        "mlp_ratio": ckpt["model_args"]["mlp_ratio"],
                        "qkv_bias": ckpt["model_args"]["qkv_bias"],
                        "qk_norm": ckpt["model_args"]["qk_norm"],
                        "custom_positional_encoding": ckpt["model_args"]["position_encoding"],
                        "norm_intermediate": ckpt["model_args"]["normalize_intermediate"],
                        "indices": ckpt["model_args"]["returned_intermediate_layers"],
                    },
                    # Flow head args
                    "head_type": "dpt",
                    "feature_head_kwargs": ckpt["model_args"]["feature_head_kwargs"],
                    "adaptors_kwargs": ckpt["model_args"]["adaptors_kwargs"],
                }

                if "covocc_feature_head_kwargs" in ckpt["model_args"]:
                    # If the model has a covocc head, convert it to the new format
                    model_args["uncertainty_head_type"] = "dpt"
                    model_args["uncertainty_head_kwargs"] = {
                        "dpt_feature": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_feature"],
                        "dpt_processor": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_regr_processor"],
                    }
                    model_args["uncertainty_adaptors_kwargs"] = {
                        "flow_cov": ckpt["model_args"]["covocc_adaptors_kwargs"]["flow_cov"]
                    }

                ckpt["model_args"] = model_args

                # Update the old weights into the current format
                ckpt["model"] = modify_state_dict(
                    ckpt["model"],
                    {
                        "covocc_head.dpt_feature": "uncertainty_head.0.0",
                        "covocc_head.dpt_regr_processor": "uncertainty_head.0.1",
                        "covocc_head.dpt_segm_processor": None,
                        "feature_matching_proj": None,
                        "encoder.model.mask_token": None,
                    },
                )

            # Remove the key "pretrained_backbone_checkpoint_path" from the model args
            if "pretrained_backbone_checkpoint_path" in ckpt["model_args"]:
                ckpt["model_args"].pop("pretrained_backbone_checkpoint_path")

            model = cls(**ckpt["model_args"])
            model.load_state_dict(ckpt["model"], strict=strict)

            return model
        else:
            raise ValueError(f"Pretrained model {pretrained_model_name_or_path} not found.")

    def _initialize_prediction_heads(
        self, head_type: str, feature_head_kwargs: Dict[str, Any], adaptors_kwargs: Dict[str, Any]
    ):
        """
        Initialize prediction heads and adaptors.

        Args:
        - head_type (str): Head type, either "dpt" or "moge_conv"
        - feature_head_kwargs (Dict[str, Any]): Feature head configurations
        - adaptors_kwargs (Dict[str, Any]): Adaptor configurations

        Returns:
        - nn.Module: output head + adaptors
        """
        feature_processor: nn.Module

        if head_type == "dpt":
            feature_processor = nn.Sequential(
                DPTFeature(**feature_head_kwargs["dpt_feature"]),
                DPTRegressionProcessor(**feature_head_kwargs["dpt_processor"]),
            )
        elif head_type == "moge_conv":
            feature_processor = MoGeConvFeature(**feature_head_kwargs)
        else:
            raise ValueError(f"Head type {head_type} not supported.")

        adaptors = self._initialize_adaptors(adaptors_kwargs)

        return nn.Sequential(feature_processor, AdaptorMap(*adaptors.values()))

    def _initialize_adaptors(self, adaptors_kwargs: Dict[str, Any]):
        """
        Initialize a dict of adaptors.

        Args:
        - adaptors_kwargs (Dict[str, Any]): Adaptor configurations

        Returns:
        - Dict[str, nn.Module]: dict of adaptors, mapping each adaptor's name to the adaptor
        """
        return {
            name: CLASSNAME_TO_ADAPTOR_CLASS[configs["class"]](**configs["kwargs"])
            for name, configs in adaptors_kwargs.items()
        }
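
    # Illustrative adaptors_kwargs structure (the adaptor names and kwargs here are
    # assumptions for illustration, not taken from a real config):
    #   adaptors_kwargs = {
    #       "flow": {"class": "FlowAdaptor", "kwargs": {...}},
    #       "non_occluded_mask": {"class": "MaskAdaptor", "kwargs": {...}},
    #   }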

    def _encode_image_pairs(self, img1, img2, data_norm_type):
        "Encode two different batches of images (each batch can have a different image shape)"
        if img1.shape[-2:] == img2.shape[-2:]:
            encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
            encoder_output = self.encoder(encoder_input)

            out_list, out2_list = [], []
            for encoder_output_ in encoder_output:
                out, out2 = encoder_output_.features.chunk(2, dim=0)
                out_list.append(out)
                out2_list.append(out2)
        else:
            raise NotImplementedError("Unequal image sizes are not supported now")

        return out_list, out2_list

    def _encode_symmetrized(self, view1, view2, symmetrized=False):
        "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
        img1 = view1["img"]
        img2 = view2["img"]

        feat1_list, feat2_list = [], []

        if symmetrized:
            # Compute only half of the forward pass and interleave the results.
            # Modified in conjunction with UFM to avoid copying the images again.
            # Used to be: feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
            # Be very careful with this!
            feat1_list_, feat2_list_ = self._encode_image_pairs(
                img1[::2], img2[::2], data_norm_type=view1["data_norm_type"]
            )

            for feat1, feat2 in zip(feat1_list_, feat2_list_):
                feat1, feat2 = interleave(feat1, feat2)
                feat1_list.append(feat1)
                feat2_list.append(feat2)
        else:
            feat1_list, feat2_list = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])

        return feat1_list, feat2_list

    def forward(self, view1, view2) -> UFMOutputInterface:
        """
        Forward interface of correspondence prediction networks.

        Args:
        - view1 (Dict[str, Any]): Input view 1
            - img (torch.Tensor): BCHW image tensor normalized according to the encoder's data_norm_type
            - instance (List[int]): List of instance indices, or IDs of the input images
            - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
        - view2 (Dict[str, Any]): Input view 2
            - (same structure as view1)

        Returns:
        - Dict[str, Any]: Output results
            - flow [Required] (Dict[str, torch.Tensor]): Flow output
                - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
                - [Optional] flow_covariance
                - [Optional] flow_covariance_inv
                - [Optional] flow_covariance_log_det
            - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
                - [Optional] mask
                - [Optional] logits
        """
        # Get input shapes
        _, _, height1, width1 = view1["img"].shape
        _, _, height2, width2 = view2["img"].shape
        shape1 = (int(height1), int(width1))
        shape2 = (int(height2), int(width2))

        # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
        feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])

        # Pass the features through the info_sharing module
        info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])

        final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
            info_sharing_input
        )

        info_sharing_outputs = {
            "1": [
                feat1_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
                final_info_sharing_multi_view_feat.features[0].float().contiguous(),
            ],
            "2": [
                feat2_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
                final_info_sharing_multi_view_feat.features[1].float().contiguous(),
            ],
        }

        result = UFMOutputInterface()

        # The prediction heads need full precision, so we disable any autocasting here
        with torch.autocast("cuda", torch.float32):
            # Run the collected info_sharing features through the prediction heads
            head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)

            if "flow" in head_output1:
                result.flow = UFMFlowFieldOutput(flow_output=head_output1["flow"].value)

            if "flow_cov" in head_output1:
                result.flow.flow_covariance = head_output1["flow_cov"].covariance
                result.flow.flow_covariance_inv = head_output1["flow_cov"].inv_covariance
                result.flow.flow_covariance_log_det = head_output1["flow_cov"].log_det

            if "non_occluded_mask" in head_output1:
                result.covisibility = UFMMaskFieldOutput(
                    mask=head_output1["non_occluded_mask"].mask,
                    logits=head_output1["non_occluded_mask"].logits,
                )

        return result

    def _downstream_head(self, head_num, decout, img_shape):
        "Run the respective prediction head"
        # if self.info_sharing_and_head_structure == "dual+single":
        head = getattr(self, f"head{head_num}")

        if self.head_type == "linear":
            head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
        elif self.head_type in ["dpt", "moge_conv"]:
            head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)

        return head(head_input)

    def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
        """
        Get parameter groups for the optimizer. This method guides the optimizer
        to apply the correct learning rate to different parts of the model.

        Returns:
        - Dict[str, torch.nn.ParameterList]: Parameter groups for the optimizer
        """
        return {
            "encoder": torch.nn.ParameterList(self.encoder.parameters()),
            "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
            "output_head": torch.nn.ParameterList(self.head1.parameters()),
        }
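

# Illustrative use of the parameter groups returned above (the learning-rate
# values are assumptions, not from the source):
#   groups = model.get_parameter_groups()
#   optimizer = torch.optim.AdamW([
#       {"params": groups["encoder"], "lr": 1e-5},
#       {"params": groups["info_sharing"], "lr": 1e-4},
#       {"params": groups["output_head"], "lr": 1e-4},
#   ])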


class UniFlowMatchConfidence(UniFlowMatch, PyTorchModelHubMixin):
    """
    UniFlowMatch model with uncertainty estimation.
    """

    def __init__(
        self,
        # Encoder configurations
        encoder_str: str,
        encoder_kwargs: Dict[str, Any],
        # Info sharing & output head structure configurations
        info_sharing_and_head_structure: str = "dual+single",  # only dual+single is supported
        # Information sharing configurations
        info_sharing_str: str = "global_attention",
        info_sharing_kwargs: Dict[str, Any] = {},
        # Prediction heads & adaptors
        head_type: str = "dpt",
        feature_head_kwargs: Dict[str, Any] = {},
        adaptors_kwargs: Dict[str, Any] = {},
        # Uncertainty heads & adaptors
        detach_uncertainty_head: bool = True,
        uncertainty_head_type: str = "dpt",
        uncertainty_head_kwargs: Dict[str, Any] = {},
        uncertainty_adaptors_kwargs: Dict[str, Any] = {},
        # Load pretrained weights
        pretrained_backbone_checkpoint_path: Optional[str] = None,
        pretrained_checkpoint_path: Optional[str] = None,
        # Inference settings
        inference_resolution: Optional[Tuple[int, int]] = (560, 420),  # WH
        *args,
        **kwargs,
    ):
        UniFlowMatch.__init__(
            self,
            encoder_str=encoder_str,
            encoder_kwargs=encoder_kwargs,
            info_sharing_and_head_structure=info_sharing_and_head_structure,
            info_sharing_str=info_sharing_str,
            info_sharing_kwargs=info_sharing_kwargs,
            head_type=head_type,
            feature_head_kwargs=feature_head_kwargs,
            adaptors_kwargs=adaptors_kwargs,
            pretrained_checkpoint_path=pretrained_backbone_checkpoint_path,
            inference_resolution=inference_resolution,
            *args,
            **kwargs,
        )
        PyTorchModelHubMixin.__init__(self)

        # Initialize uncertainty heads
        assert uncertainty_head_type == "dpt", "Only DPT is supported for the uncertainty head now"

        self.uncertainty_head = self._initialize_prediction_heads(
            uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs
        )
        self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs)

        assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"

        self.detach_uncertainty_head = detach_uncertainty_head

    def forward(self, view1, view2) -> UFMOutputInterface:
        """
        Forward interface of correspondence prediction networks.

        Args:
        - view1 (Dict[str, Any]): Input view 1
            - img (torch.Tensor): BCHW image tensor normalized according to the encoder's data_norm_type
            - instance (List[int]): List of instance indices, or IDs of the input images
            - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
        - view2 (Dict[str, Any]): Input view 2
            - (same structure as view1)

        Returns:
        - Dict[str, Any]: Output results
            - flow [Required] (Dict[str, torch.Tensor]): Flow output
                - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
                - [Optional] flow_covariance
                - [Optional] flow_covariance_inv
                - [Optional] flow_covariance_log_det
            - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
                - [Optional] mask
                - [Optional] logits
        """
        # Get input shapes
        _, _, height1, width1 = view1["img"].shape
        _, _, height2, width2 = view2["img"].shape
        shape1 = (int(height1), int(width1))
        shape2 = (int(height2), int(width2))

        # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
        feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])

        # Pass the features through the info_sharing module
        info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])

        final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
            info_sharing_input
        )

        info_sharing_outputs = {
            "1": [
                feat1_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
                final_info_sharing_multi_view_feat.features[0].float().contiguous(),
            ],
            "2": [
                feat2_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
                final_info_sharing_multi_view_feat.features[1].float().contiguous(),
            ],
        }

        info_sharing_outputs_detached = {
            "1": [
                feat1_list[-1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(),
                final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(),
            ],
            "2": [
                feat2_list[-1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(),
                final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(),
            ],
        }

        result = UFMOutputInterface()

        # The prediction heads need full precision, so we disable any autocasting here
        with torch.autocast("cuda", torch.float32):
            # Run the collected info_sharing features through the prediction heads
            head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
            head_output_uncertainty = self._downstream_head(
                "uncertainty",
                info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs,
                shape1,
            )

            result.flow = UFMFlowFieldOutput(
                flow_output=head_output1["flow"].value,
            )

            if "flow_cov" in head_output_uncertainty:
                result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance
                result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance
                result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det

            if "keypoint_confidence" in head_output_uncertainty:
                result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1)

            if "non_occluded_mask" in head_output_uncertainty:
                result.covisibility = UFMMaskFieldOutput(
                    mask=head_output_uncertainty["non_occluded_mask"].mask,
                    logits=head_output_uncertainty["non_occluded_mask"].logits,
                )

        return result

    def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
        """
        Get parameter groups for the optimizer. This method guides the optimizer
        to apply the correct learning rate to different parts of the model.

        Returns:
        - Dict[str, torch.nn.ParameterList]: Parameter groups for the optimizer
        """
        return {
            "encoder": torch.nn.ParameterList(self.encoder.parameters()),
            "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
            "output_head": torch.nn.ParameterList(self.head1.parameters()),
            "uncertainty_head": torch.nn.ParameterList(self.uncertainty_head.parameters()),
        }

    def _downstream_head(self, head_num, decout, img_shape):
        "Run the respective prediction head"
        # if self.info_sharing_and_head_structure == "dual+single":
        head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head
        head_num = head_num if head_num != "uncertainty" else 1  # the uncertainty head always reads from branch 1

        if self.head_type == "linear":
            head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
        elif self.head_type in ["dpt", "moge_conv"]:
            head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)

        return head(head_input)


class UniFlowMatchClassificationRefinement(UniFlowMatch, PyTorchModelHubMixin):
    """
    Variant of UniFlowMatch that refines the regressed flow with local classification.
    """

    def __init__(
        self,
        # Encoder configurations
        encoder_str: str,
        encoder_kwargs: Dict[str, Any],
        # Info sharing & output head structure configurations
        info_sharing_and_head_structure: str = "dual+single",  # only dual+single is supported
        # Information sharing configurations
        info_sharing_str: str = "global_attention",
        info_sharing_kwargs: Dict[str, Any] = {},
        # Prediction heads & adaptors
        head_type: str = "dpt",
        feature_head_kwargs: Dict[str, Any] = {},
        adaptors_kwargs: Dict[str, Any] = {},
        # Uncertainty heads & adaptors
        detach_uncertainty_head: bool = True,
        uncertainty_head_type: str = "dpt",
        uncertainty_head_kwargs: Dict[str, Any] = {},
        uncertainty_adaptors_kwargs: Dict[str, Any] = {},
        # Classification heads & adaptors
        temperature: float = 4,
        use_unet_feature: bool = False,
        classification_head_type: str = "patch_mlp",
        classification_head_kwargs: Dict[str, Any] = {},
        feature_combine_method: str = "conv",
        # Refinement range
        refinement_range: int = 5,
        # Load pretrained weights
        pretrained_backbone_checkpoint_path: Optional[str] = None,
        pretrained_checkpoint_path: Optional[str] = None,
        # Inference settings
        inference_resolution: Optional[Tuple[int, int]] = (560, 420),  # WH
        *args,
        **kwargs,
    ):
        UniFlowMatch.__init__(
            self,
            encoder_str=encoder_str,
            encoder_kwargs=encoder_kwargs,
            info_sharing_and_head_structure=info_sharing_and_head_structure,
            info_sharing_str=info_sharing_str,
            info_sharing_kwargs=info_sharing_kwargs,
            head_type=head_type,
            feature_head_kwargs=feature_head_kwargs,
            adaptors_kwargs=adaptors_kwargs,
            pretrained_checkpoint_path=pretrained_backbone_checkpoint_path,
            inference_resolution=inference_resolution,
            *args,
            **kwargs,
        )
        PyTorchModelHubMixin.__init__(self)

        # Initialize the classification head
        assert classification_head_type == "patch_mlp", "Only patch_mlp is supported for the classification head now"

        self.classification_head_type = classification_head_type
        self.classification_head = self._initialize_classification_head(classification_head_kwargs)
        self.refinement_range = refinement_range
        self.temperature = temperature

        assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"

        self.use_unet_feature = use_unet_feature
        self.feature_combine_method = feature_combine_method

        # U-Net feature experiment
        if self.use_unet_feature:
            self.unet_feature = UNet(in_channels=3, out_channels=16, features=[64, 128, 256, 512])
            self.conv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0)

            if self.feature_combine_method == "conv":
                self.conv2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0)
            elif self.feature_combine_method == "modulate":
                self.conv2 = nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0)

        default_attention_bias = torch.zeros(self.refinement_range * self.refinement_range)
        self.classification_bias = nn.Parameter(default_attention_bias)

        # Initialize uncertainty heads
        if len(uncertainty_head_kwargs) > 0:
            assert uncertainty_head_type == "dpt", "Only DPT is supported for the uncertainty head now"

            self.uncertainty_head = self._initialize_prediction_heads(
                uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs
            )
            self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs)

            assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"

            self.detach_uncertainty_head = detach_uncertainty_head

    def forward(self, view1, view2) -> UFMOutputInterface:
        """
        Forward interface of correspondence prediction networks.

        Args:
        - view1 (Dict[str, Any]): Input view 1
            - img (torch.Tensor): BCHW image tensor normalized according to the encoder's data_norm_type
            - instance (List[int]): List of instance indices, or IDs of the input images
            - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
        - view2 (Dict[str, Any]): Input view 2
            - (same structure as view1)

        Returns:
        - Dict[str, Any]: Output results
            - flow [Required] (Dict[str, torch.Tensor]): Flow output
                - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
                - [Optional] flow_covariance
                - [Optional] flow_covariance_inv
                - [Optional] flow_covariance_log_det
            - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
                - [Optional] mask
                - [Optional] logits
            - classification [Optional]: Probabilities and targets of the classification head
        """
        # Get input shapes
        _, _, height1, width1 = view1["img"].shape
        _, _, height2, width2 = view2["img"].shape
        shape1 = (int(height1), int(width1))
        shape2 = (int(height2), int(width2))

        # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
        feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])

        # Pass the features through the info_sharing module
        info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])

        final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
            info_sharing_input
        )

        info_sharing_outputs = {
            "1": [
                feat1_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
                final_info_sharing_multi_view_feat.features[0].float().contiguous(),
            ],
            "2": [
                feat2_list[-1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
                final_info_sharing_multi_view_feat.features[1].float().contiguous(),
            ],
        }

        info_sharing_outputs_detached = {
            "1": [
                feat1_list[-1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(),
                final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(),
            ],
            "2": [
                feat2_list[-1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(),
                intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(),
                final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(),
            ],
        }

        # Optionally run inference for U-Net features
        if self.use_unet_feature:
            unet_feat1 = self.unet_feature(view1["img"])
            unet_feat2 = self.unet_feature(view2["img"])

        result = UFMOutputInterface()

        # The prediction heads need full precision, so we disable any autocasting here
        with torch.autocast("cuda", torch.float32):
            # Run the collected info_sharing features through the prediction heads
            head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
            flow_prediction = head_output1["flow"].value

            head_output_uncertainty = None
            if hasattr(self, "uncertainty_head"):
                # Run the uncertainty head
                head_output_uncertainty = self._downstream_head(
                    "uncertainty",
                    info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs,
                    shape1,
                )

                if "keypoint_confidence" in head_output_uncertainty:
                    result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1)

                if "non_occluded_mask" in head_output_uncertainty:
                    result.covisibility = UFMMaskFieldOutput(
                        mask=head_output_uncertainty["non_occluded_mask"].mask,
                        logits=head_output_uncertainty["non_occluded_mask"].logits,
                    )

            # We run the classification head inside the autocast environment because it is classification, not regression
            if self.classification_head_type == "patch_mlp":
                # Concatenate the first encoder feature with the final info_sharing feature.
                # The first encoder feature is used because it captures more low-level information,
                # which is needed for refining the regressed flow.
                classification_feat_1 = torch.cat(
                    [feat1_list[0].float().contiguous(), info_sharing_outputs["1"][-1]], dim=1
                )
                classification_feat_2 = torch.cat(
                    [feat2_list[0].float().contiguous(), info_sharing_outputs["2"][-1]], dim=1
                )

                classification_input = PredictionHeadInput(
                    torch.cat([classification_feat_1, classification_feat_2], dim=0)
                )

                classification_features = self.classification_head(classification_input).decoded_channels

                if self.use_unet_feature:
                    if self.feature_combine_method == "conv":
                        combined_features = torch.cat(
                            [classification_features, torch.cat([unet_feat1, unet_feat2], dim=0)], dim=1
                        )
                        combined_features = self.conv1(combined_features)
                        combined_features = nn.functional.relu(combined_features)
                        combined_features = self.conv2(combined_features)
                    elif self.feature_combine_method == "modulate":
                        combined_features = classification_features * torch.tanh(
                            torch.cat([unet_feat1, unet_feat2], dim=0)
                        )
                        combined_features = self.conv2(combined_features)

                    classification_features = combined_features

                classification_features0, classification_features1 = classification_features.chunk(2, dim=0)

                # Refine the flow prediction with features from the classification head
                for i in range(1):  # a single refinement iteration
                    residual, log_softmax_attention = self.classification_refinement(
                        flow_prediction, classification_features
                    )

                    flow_prediction = flow_prediction + residual

            # Fill in the result.
            # WARNING: due to how the residual is computed, the gradient through flow_prediction
            # cancels out mathematically, so the flow prediction receives no supervision here.
            # A specialized loss function is needed to supervise regression_flow_output.
            result.flow = UFMFlowFieldOutput(
                flow_output=flow_prediction,
            )

            # Attach the covariance estimates once result.flow has been constructed
            if head_output_uncertainty is not None and "flow_cov" in head_output_uncertainty:
                result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance
                result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance
                result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det

            result.classification_refinement = UFMClassificationRefinementOutput(
                regression_flow_output=flow_prediction,
                residual=residual,
                log_softmax=log_softmax_attention,
                feature_map_0=classification_features0,
                feature_map_1=classification_features1,
            )

        return result

    # @torch.compile()
    def classification_refinement(self, flow_prediction, classification_features) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Refine the flow prediction using the correlation between each pixel's own feature and
        the features in a local patch around its initial flow target.
        """
        classification_features1, classification_features2 = classification_features.chunk(2, dim=0)

        neighborhood_features, neighborhood_flow_residual = self.obtain_neighborhood_features(
            flow_estimation=flow_prediction, other_features=classification_features2, local_patch=self.refinement_range
        )

        residual, log_softmax_attention = self.compute_refinement_attention(
            classification_features1, neighborhood_features, neighborhood_flow_residual
        )

        return residual, log_softmax_attention

    def compute_refinement_attention(self, classification_features1, neighborhood_features, neighborhood_flow_residual):
        """
        Compute the refinement attention, reshaping the inputs to fit a batched
        query-key-value attention over each pixel's local neighborhood.
        """
        B, C, H, W = classification_features1.shape
        P = self.refinement_range

        # reshape Q to (B*H*W, 1, C)
        classification_features1 = classification_features1.permute(0, 2, 3, 1).reshape(B * H * W, 1, C)

        # reshape K from (B, H, W, P, P, C) to (B*H*W, P*P, C)
        assert neighborhood_features.shape[0] == B
        assert neighborhood_features.shape[1] == H
        assert neighborhood_features.shape[2] == W
        assert neighborhood_features.shape[3] == P
        assert neighborhood_features.shape[4] == P
        assert neighborhood_features.shape[5] == C

        neighborhood_features = neighborhood_features.reshape(B * H * W, P * P, C)

        # reshape V to (B*H*W, P*P, 2)
        neighborhood_flow_residual = neighborhood_flow_residual.reshape(-1, P * P, 2)

        # compute the attention scores
        attention_score = (
            torch.matmul(classification_features1, neighborhood_features.permute(0, 2, 1)) / self.temperature
        )
        attention_score = attention_score + self.classification_bias

        attention = torch.nn.functional.softmax(attention_score, dim=-1)
        log_softmax_attention = torch.nn.functional.log_softmax(attention_score, dim=-1)

        # compute the weighted sum of the candidate flow residuals
        residual = torch.matmul(attention, neighborhood_flow_residual)

        # reshape the residual to (B, H, W, 2), then (B, 2, H, W)
        residual = residual.reshape(B, H, W, 2).permute(0, 3, 1, 2)

        return residual, log_softmax_attention.reshape(B, H, W, P, P)
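
    # In effect, for each pixel q with feature f_q, neighborhood features {k_i}, and
    # candidate offsets {v_i} (i = 1..P^2), the refinement residual computed above is
    #   residual(q) = sum_i softmax_i(f_q . k_i / temperature + bias_i) * v_i,
    # i.e. a softmax-weighted average of the discrete offsets in the P x P window.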

    def _downstream_head(self, head_num, decout, img_shape):
        "Run the respective prediction head"
        # if self.info_sharing_and_head_structure == "dual+single":
        head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head
        head_num = head_num if head_num != "uncertainty" else 1  # the uncertainty head always reads from branch 1

        if self.head_type == "linear":
            head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
        elif self.head_type in ["dpt", "moge_conv"]:
            head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)

        return head(head_input)

    def obtain_neighborhood_features(
        self, flow_estimation: torch.Tensor, other_features: torch.Tensor, local_patch: int = 5
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample the other view's features in a local patch around each pixel's flow target.
        """
        assert local_patch % 2 == 1, "local_patch should be an odd number"

        P = local_patch
        R = (P - 1) // 2
        B, C, H, W = other_features.shape
        device = other_features.device

        # expected_output = torch.zeros(B, H, W, P, P, C, device=other_features.device, dtype=torch.float32)

        neighborhood_grid_ij: torch.Tensor
        i_local, j_local = torch.meshgrid(
            torch.arange(-R, R + 1, device=device), torch.arange(-R, R + 1, device=device), indexing="ij"
        )
        ij_local = torch.stack((i_local, j_local), dim=0)  # (2, P, P) tensor of local offsets

        # compute the indices to fetch
        base_grid_xy = get_meshgrid_torch(W=W, H=H, device=device).permute(2, 0, 1).reshape(1, 2, H, W)
        target_coordinate_xy_float = flow_estimation + base_grid_xy

        target_coordinate_xy = target_coordinate_xy_float.view(B, 2, H, W, 1, 1)
        target_coordinate_ij = target_coordinate_xy[:, [1, 0], ...]

        # compute the neighborhood grid
        neighborhood_grid_ij = target_coordinate_ij + ij_local.view(1, 2, 1, 1, P, P)

        grid_for_sample = neighborhood_grid_ij[:, [1, 0], ...].permute(0, 2, 3, 4, 5, 1).reshape(B, H, W * P * P, 2)
        grid_for_sample = (grid_for_sample + 0.5) / torch.tensor([W, H], device=device).view(1, 1, 1, 2)
        grid_for_sample = grid_for_sample * 2 - 1

        expected_output = torch.nn.functional.grid_sample(
            other_features, grid=grid_for_sample, mode="bicubic", padding_mode="zeros", align_corners=False
        ).view(B, C, H, W, P, P)

        # transform BCHWPP to BHWPPC
        expected_output = expected_output.permute(0, 2, 3, 4, 5, 1)

        neighborhood_grid_xy_residual = ij_local[[1, 0], ...].view(1, 2, 1, 1, P, P).to(device).float()
        neighborhood_grid_xy_residual = neighborhood_grid_xy_residual.permute(0, 2, 3, 4, 5, 1).float()

        return expected_output, neighborhood_grid_xy_residual
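
    # Note on the sampling convention used above: with align_corners=False,
    # grid_sample treats pixel centers as lying at ((index + 0.5) / size) * 2 - 1
    # in normalized [-1, 1] coordinates, which is exactly the mapping applied to
    # grid_for_sample before sampling.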

    def _initialize_classification_head(self, classification_head_kwargs: Dict[str, Any]):
        """
        Initialize the classification head.

        Args:
        - classification_head_kwargs (Dict[str, Any]): Classification head configurations

        Returns:
        - nn.Module: Classification head
        """
        if self.classification_head_type == "patch_mlp":
            return MLPFeature(**classification_head_kwargs)
        else:
            raise ValueError(f"Classification head type {self.classification_head_type} not supported.")

    def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
        """
        Get parameter groups for the optimizer. This method guides the optimizer
        to apply the correct learning rate to different parts of the model.

        Returns:
        - Dict[str, torch.nn.ParameterList]: Parameter groups for the optimizer
        """
        if self.use_unet_feature:
            params_dict = {
                "encoder": torch.nn.ParameterList(self.encoder.parameters()),
                "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
                "output_head": torch.nn.ParameterList(self.head1.parameters()),
                "classification_head": torch.nn.ParameterList(self.classification_head.parameters()),
                "unet_feature": torch.nn.ParameterList(
                    list(self.unet_feature.parameters())
                    + list(self.conv1.parameters())
                    + list(self.conv2.parameters())
                    + [self.classification_bias]
                ),
            }
        else:
            params_dict = {
                "encoder": torch.nn.ParameterList(self.encoder.parameters()),
                "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
                "output_head": torch.nn.ParameterList(self.head1.parameters()),
                "classification_head": torch.nn.ParameterList(self.classification_head.parameters()),
            }

        if hasattr(self, "uncertainty_head"):
            params_dict["uncertainty_head"] = torch.nn.ParameterList(self.uncertainty_head.parameters())

        return params_dict


if __name__ == "__main__":
    import cv2
    import flow_vis
    import matplotlib.pyplot as plt
    import numpy as np

    from uniflowmatch.utils.viz import warp_image_with_flow

    USE_REFINEMENT_MODEL = False

    if USE_REFINEMENT_MODEL:
        model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
    else:
        model = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base")

    # === Load and Prepare Images ===
    source_path = "examples/image_pairs/fire_academy_0.png"
    target_path = "examples/image_pairs/fire_academy_1.png"

    source_image = cv2.imread(source_path)
    target_image = cv2.imread(target_path)

    source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB)
    target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)

    # === Predict Correspondences ===
    result = model.predict_correspondences_batched(
        source_image=torch.from_numpy(source_image),
        target_image=torch.from_numpy(target_image),
    )

    flow_output = result.flow.flow_output[0].cpu().numpy()
    covisibility = result.covisibility.mask[0].cpu().numpy()

    # === Visualize Results ===
    fig, axs = plt.subplots(2, 3, figsize=(15, 5))

    axs[0, 0].imshow(source_image)
    axs[0, 0].set_title("Source Image")

    axs[0, 1].imshow(target_image)
    axs[0, 1].set_title("Target Image")

    # Warp the source image to the target using the predicted flow
    warped_image = warp_image_with_flow(source_image, None, target_image, flow_output.transpose(1, 2, 0))
    warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
        warped_image
    )
    warped_image /= 255.0

    axs[0, 2].imshow(warped_image)
    axs[0, 2].set_title("Warped Image")

    # Flow visualization
    flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0))
    axs[1, 0].imshow(flow_vis_image)
    axs[1, 0].set_title("Flow Output (Valid in covisible region)")

    # Covisibility mask
    axs[1, 1].imshow(covisibility > 0.5, cmap="gray", vmin=0, vmax=1)
    axs[1, 1].set_title("Covisibility Mask (Thresholded at 0.5)")

    heatmap = axs[1, 2].imshow(covisibility, cmap="gray", vmin=0, vmax=1)
    axs[1, 2].set_title("Covisibility Mask")
    plt.colorbar(heatmap, ax=axs[1, 2])

    plt.tight_layout()
    plt.savefig("ufm_output.png")
    plt.show()

    print("Saved ufm_output.png")