import os
from typing import Any, Dict, List, Optional, Tuple
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
# Only enable flash attention backend
from uniception.models.encoders import ViTEncoderInput, feature_returner_encoder_factory
from uniception.models.info_sharing import INFO_SHARING_CLASSES, MultiViewTransformerInput
from uniception.models.prediction_heads.adaptors import (
ConfidenceAdaptor,
Covariance2DAdaptor,
FlowAdaptor,
FlowWithConfidenceAdaptor,
MaskAdaptor,
)
from uniception.models.prediction_heads.base import AdaptorMap, PredictionHeadInput, PredictionHeadLayeredInput
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
from uniception.models.prediction_heads.mlp_feature import MLPFeature
from uniception.models.prediction_heads.moge_conv import MoGeConvFeature
from uniflowmatch.models.base import (
UFMClassificationRefinementOutput,
UFMFlowFieldOutput,
UFMMaskFieldOutput,
UFMOutputInterface,
UniFlowMatchModelsBase,
)
from uniflowmatch.models.unet_encoder import UNet
from uniflowmatch.models.utils import get_meshgrid_torch
CLASSNAME_TO_ADAPTOR_CLASS = {
"FlowWithConfidenceAdaptor": FlowWithConfidenceAdaptor,
"FlowAdaptor": FlowAdaptor,
"MaskAdaptor": MaskAdaptor,
"Covariance2DAdaptor": Covariance2DAdaptor,
"ConfidenceAdaptor": ConfidenceAdaptor,
}
# DUSt3R-style helpers to avoid passing duplicate images through the encoder
def is_symmetrized(gt1, gt2):
"Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
x = gt1["instance"]
y = gt2["instance"]
if len(x) == len(y) and len(x) == 1:
return False # special case of batchsize 1
ok = True
for i in range(0, len(x), 2):
ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
return ok
def interleave(tensor1, tensor2):
"Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
return res1, res2
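# Illustrative sketch of the symmetrization helpers above (values are
# hypothetical). Encoding only every other pair and then interleaving
# recovers features for both the (a, b) and (b, a) orderings:
#   >>> a = torch.zeros(2, 3)  # features of images a0, a1
#   >>> b = torch.ones(2, 3)   # features of images b0, b1
#   >>> r1, r2 = interleave(a, b)
#   >>> r1.shape  # torch.Size([4, 3]), rows ordered a0, b0, a1, b1
#   >>> r2.shape  # torch.Size([4, 3]), rows ordered b0, a0, b1, a1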
def modify_state_dict(original_state_dict, mappings):
"""
Modify state dict keys according to replacement mappings
Args:
original_state_dict: Loaded checkpoint state dict
mappings: Dictionary of {old_key_substr: new_key_substr_or_None}
Returns:
Modified state dictionary with updated keys
"""
new_state_dict = {}
for k, v in original_state_dict.items():
new_key = None
skip = False
# Check for all possible replacements
for replace_key, replace_value in mappings.items():
if replace_key in k:
if replace_value is None:
skip = True
break # Skip this key entirely
else:
new_key = k.replace(replace_key, replace_value)
break # Only apply first matching replacement
if skip:
continue
new_state_dict[new_key if new_key is not None else k] = v
return new_state_dict
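# Illustrative sketch of modify_state_dict (keys below are hypothetical):
# a string value renames matching keys, a None value drops them entirely.
#   >>> sd = {"covocc_head.w": 1, "feature_matching_proj.w": 2, "encoder.w": 3}
#   >>> modify_state_dict(sd, {"covocc_head": "uncertainty_head", "feature_matching_proj": None})
#   {'uncertainty_head.w': 1, 'encoder.w': 3}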
class UniFlowMatch(UniFlowMatchModelsBase, PyTorchModelHubMixin):
"""
UniFlowMatch model.
"""
def __init__(
self,
# Encoder configurations
encoder_str: str,
encoder_kwargs: Dict[str, Any],
# Info sharing & output head structure configurations
info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported
# Information sharing configurations
info_sharing_str: str = "global_attention",
info_sharing_kwargs: Dict[str, Any] = {},
# skip-connections between encoder and info-sharing
encoder_skip_connection: Optional[List[int]] = None,
info_sharing_skip_connection: Optional[List[int]] = None,
# Prediction Heads & Adaptors
head_type: str = "dpt",
feature_head_kwargs: Dict[str, Any] = {},
adaptors_kwargs: Dict[str, Any] = {},
# Load Pretrained Weights
pretrained_checkpoint_path: Optional[str] = None,
# Inference Settings
inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH
*args,
**kwargs,
):
"""
Initialize the UniFlowMatch model
- encoder_str (str): Encoder string
- encoder_kwargs (Dict[str, Any]): Encoder configurations
- info_sharing_and_head_structure (str): Info sharing and head structure configurations
- "dual+single": Dual view info sharing and single view prediction head
- info_sharing_str (str): Info sharing method
- "global_attention_transformer": Global attention transformer
- info_sharing_kwargs (Dict[str, Any]): Info sharing configurations
"""
UniFlowMatchModelsBase.__init__(self, inference_resolution=inference_resolution, *args, **kwargs)
PyTorchModelHubMixin.__init__(self)
# assertion on architectures
assert info_sharing_and_head_structure == "dual+single", "Only dual+single is supported now"
# initialize the skip-connections
self.encoder_skip_connection = encoder_skip_connection
self.info_sharing_skip_connection = info_sharing_skip_connection
# initialize encoder
self.encoder: nn.Module = feature_returner_encoder_factory(encoder_str, **encoder_kwargs)
# initialize info-sharing module
assert head_type != "linear", "Linear head is not supported because it has major disadvantages compared to DPT heads"
self.head_type = head_type
self.info_sharing: nn.Module = INFO_SHARING_CLASSES[info_sharing_str][1](**info_sharing_kwargs)
self.head1: nn.Module = self._initialize_prediction_heads(head_type, feature_head_kwargs, adaptors_kwargs)
# load pretrained weights
if pretrained_checkpoint_path is not None:
ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu")
if "state_dict" in ckpt:
# we are loading from training checkpoint directly.
model_state_dict = ckpt["state_dict"]
model_state_dict = {
k[6:]: v for k, v in model_state_dict.items() if k.startswith("model.")
} # remove "model." prefix
model_state_dict = modify_state_dict(
model_state_dict, {"feature_matching_proj": None, "encoder.model.mask_token": None}
)
self.load_state_dict(model_state_dict, strict=True)
else:
model_state_dict = ckpt["model"]
load_result = self.load_state_dict(model_state_dict, strict=False)
assert len(load_result.missing_keys) == 0, f"Missing keys: {load_result.missing_keys}"
@classmethod
def from_pretrained_ckpt(cls, pretrained_model_name_or_path, strict=True, **kw):
if os.path.isfile(pretrained_model_name_or_path):
ckpt = torch.load(pretrained_model_name_or_path, map_location="cpu")
# remove base_pretrained_checkpoint_path from the model args
if "base_pretrained_checkpoint_path" in ckpt["model_args"]:
ckpt["model_args"].pop("base_pretrained_checkpoint_path")
# convert old model args into new definition
if "img_size" in ckpt["model_args"]:
# we are loading from an old benchmark checkpoint
print("Converting from an old benchmark checkpoint")
model_args = {
# Encoder args
"encoder_str": ckpt["model_args"]["encoder_str"],
"encoder_kwargs": ckpt["model_args"]["encoder_kwargs"],
# Info-sharing args
"info_sharing_and_head_structure": "dual+single",
"info_sharing_str": ckpt["model_args"]["info_sharing_type"],
"info_sharing_kwargs": {
"name": "info_sharing",
"input_embed_dim": ckpt["model_args"]["input_embed_dim"],
"num_views": 2,
"use_rand_idx_pe_for_non_reference_views": False,
"depth": ckpt["model_args"]["num_layers"],
"dim": ckpt["model_args"]["transformer_dim"],
"num_heads": ckpt["model_args"]["num_heads"],
"mlp_ratio": ckpt["model_args"]["mlp_ratio"],
"qkv_bias": ckpt["model_args"]["qkv_bias"],
"qk_norm": ckpt["model_args"]["qk_norm"],
"custom_positional_encoding": ckpt["model_args"]["position_encoding"],
"norm_intermediate": ckpt["model_args"]["normalize_intermediate"],
"indices": ckpt["model_args"]["returned_intermediate_layers"],
},
# flow head args
"head_type": "dpt",
"feature_head_kwargs": ckpt["model_args"]["feature_head_kwargs"],
"adaptors_kwargs": ckpt["model_args"]["adaptors_kwargs"],
}
if "covocc_feature_head_kwargs" in ckpt["model_args"]:
# if the model has a covocc head, we need to convert it to the new format
model_args["uncertainty_head_type"] = "dpt"
model_args["uncertainty_head_kwargs"] = {
"dpt_feature": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_feature"],
"dpt_processor": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_regr_processor"],
}
model_args["uncertainty_adaptors_kwargs"] = {
"flow_cov": ckpt["model_args"]["covocc_adaptors_kwargs"]["flow_cov"]
}
ckpt["model_args"] = model_args
# Update the old weights into the current format
ckpt["model"] = modify_state_dict(
ckpt["model"],
{
"covocc_head.dpt_feature": "uncertainty_head.0.0",
"covocc_head.dpt_regr_processor": "uncertainty_head.0.1",
"covocc_head.dpt_segm_processor": None,
"feature_matching_proj": None,
"encoder.model.mask_token": None,
},
)
# remove the key "pretrained_backbone_checkpoint_path" from the model args
if "pretrained_backbone_checkpoint_path" in ckpt["model_args"]:
ckpt["model_args"].pop("pretrained_backbone_checkpoint_path")
model = cls(**ckpt["model_args"])
model.load_state_dict(ckpt["model"], strict=strict)
return model
else:
raise ValueError(f"Pretrained model {pretrained_model_name_or_path} not found.")
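# Sketch of intended usage (the checkpoint path below is hypothetical):
#   >>> model = UniFlowMatch.from_pretrained_ckpt("checkpoints/ufm.pth")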
def _initialize_prediction_heads(
self, head_type: str, feature_head_kwargs: Dict[str, Any], adaptors_kwargs: Dict[str, Any]
):
"""
Initialize prediction heads and adaptors
Args:
- head_type (str): Head type, either "dpt" or "linear"
- feature_head_kwargs (Dict[str, Any]): Feature head configurations
- adaptors_kwargs (Dict[str, Any]): Adaptors configurations
Returns:
- nn.Module: output head + adaptors
"""
feature_processor: nn.Module
if head_type == "dpt":
feature_processor = nn.Sequential(
DPTFeature(**feature_head_kwargs["dpt_feature"]),
DPTRegressionProcessor(**feature_head_kwargs["dpt_processor"]),
)
elif head_type == "moge_conv":
feature_processor = MoGeConvFeature(**feature_head_kwargs)
else:
raise ValueError(f"Head type {head_type} not supported.")
adaptors = self._initialize_adaptors(adaptors_kwargs)
return nn.Sequential(feature_processor, AdaptorMap(*adaptors.values()))
def _initialize_adaptors(self, adaptors_kwargs: Dict[str, Any]):
"""
Initialize a dict of adaptors
Args:
- adaptors_kwargs (Dict[str, Any]): Adaptors configurations
Returns:
- Dict[str, nn.Module]: dict of adaptors, from adaptor's name to the adaptor
"""
return {
name: CLASSNAME_TO_ADAPTOR_CLASS[configs["class"]](**configs["kwargs"])
for name, configs in adaptors_kwargs.items()
}
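# Expected adaptors_kwargs schema (adaptor names and kwargs below are
# illustrative; only the {"class": ..., "kwargs": ...} structure is required):
#   adaptors_kwargs = {
#       "flow": {"class": "FlowAdaptor", "kwargs": {...}},
#       "non_occluded_mask": {"class": "MaskAdaptor", "kwargs": {...}},
#   }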
def _encode_image_pairs(self, img1, img2, data_norm_type):
"Encode two different batches of images (each batch can have different image shape)"
if img1.shape[-2:] == img2.shape[-2:]:
encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
encoder_output = self.encoder(encoder_input)
out_list, out2_list = [], []
for encoder_output_ in encoder_output:
out, out2 = encoder_output_.features.chunk(2, dim=0)
out_list.append(out)
out2_list.append(out2)
else:
raise NotImplementedError("Unequal image sizes are not supported yet")
return out_list, out2_list
def _encode_symmetrized(self, view1, view2, symmetrized=False):
"Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
img1 = view1["img"]
img2 = view2["img"]
feat1_list, feat2_list = [], []
if symmetrized:
# Compute only half of the forward pass;
# modified in conjunction with UFM to avoid copying the images again.
# used to be: feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
# be very careful with this!!!
feat1_list_, feat2_list_ = self._encode_image_pairs(
img1[::2], img2[::2], data_norm_type=view1["data_norm_type"]
)
for feat1, feat2 in zip(feat1_list_, feat2_list_):
feat1, feat2 = interleave(feat1, feat2)
feat1_list.append(feat1)
feat2_list.append(feat2)
else:
feat1_list, feat2_list = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])
return feat1_list, feat2_list
def forward(self, view1, view2) -> UFMOutputInterface:
"""
Forward interface of correspondence prediction networks.
Args:
- view1 (Dict[str, Any]): Input view 1
- img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type
- instance (List[int]): List of instance indices, or id of the input image
- data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
- view2 (Dict[str, Any]): Input view 2
- (same structure as view1)
Returns:
- Dict[str, Any]: Output results
- flow [Required] (Dict[str, torch.Tensor]): Flow output
- [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
- [Optional] flow_covariance
- [Optional] flow_covariance_inv
- [Optional] flow_covariance_log_det
- covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
- [Optional] mask
- [Optional] logits
"""
# Get input shapes
_, _, height1, width1 = view1["img"].shape
_, _, height2, width2 = view2["img"].shape
shape1 = (int(height1), int(width1))
shape2 = (int(height2), int(width2))
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])
# Pass the features through the info_sharing
info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
info_sharing_input
)
info_sharing_outputs = {
"1": [
feat1_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
final_info_sharing_multi_view_feat.features[0].float().contiguous(),
],
"2": [
feat2_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
final_info_sharing_multi_view_feat.features[1].float().contiguous(),
],
}
result = UFMOutputInterface()
# The predictions need precision, so we disable any autocasting here
with torch.autocast("cuda", torch.float32):
# run the collected info_sharing features through the prediction heads
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
if "flow" in head_output1:
# output is flow only
result.flow = UFMFlowFieldOutput(flow_output=head_output1["flow"].value)
if "flow_cov" in head_output1:
result.flow.flow_covariance = head_output1["flow_cov"].covariance
result.flow.flow_covariance_inv = head_output1["flow_cov"].inv_covariance
result.flow.flow_covariance_log_det = head_output1["flow_cov"].log_det
if "non_occluded_mask" in head_output1:
result.covisibility = UFMMaskFieldOutput(
mask=head_output1["non_occluded_mask"].mask,
logits=head_output1["non_occluded_mask"].logits,
)
return result
def _downstream_head(self, head_num, decout, img_shape):
"Run the respective prediction heads"
# if self.info_sharing_and_head_structure == "dual+single":
head = getattr(self, f"head{head_num}")
if self.head_type == "linear":
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
elif self.head_type in ["dpt", "moge_conv"]:
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
return head(head_input)
def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
"""
Get parameter groups for the optimizer. This method guides the optimizer
to apply the correct learning rate to different parts of the model.
Returns:
- Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer
"""
return {
"encoder": torch.nn.ParameterList(self.encoder.parameters()),
"info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
"output_head": torch.nn.ParameterList(self.head1.parameters()),
}
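# Sketch of how a trainer might consume these groups (learning rates below are
# hypothetical; any torch optimizer accepting parameter groups works):
#   >>> groups = model.get_parameter_groups()
#   >>> optimizer = torch.optim.AdamW([
#   ...     {"params": groups["encoder"], "lr": 1e-5},
#   ...     {"params": groups["info_sharing"], "lr": 1e-4},
#   ...     {"params": groups["output_head"], "lr": 1e-4},
#   ... ])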
class UniFlowMatchConfidence(UniFlowMatch, PyTorchModelHubMixin):
"""
UniFlowMatch model with uncertainty estimation.
"""
def __init__(
self,
# Encoder configurations
encoder_str: str,
encoder_kwargs: Dict[str, Any],
# Info sharing & output head structure configurations
info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported
# Information sharing configurations
info_sharing_str: str = "global_attention",
info_sharing_kwargs: Dict[str, Any] = {},
# Prediction Heads & Adaptors
head_type: str = "dpt",
feature_head_kwargs: Dict[str, Any] = {},
adaptors_kwargs: Dict[str, Any] = {},
# Uncertainty Heads & Adaptors
detach_uncertainty_head: bool = True,
uncertainty_head_type: str = "dpt",
uncertainty_head_kwargs: Dict[str, Any] = {},
uncertainty_adaptors_kwargs: Dict[str, Any] = {},
# Load Pretrained Weights
pretrained_backbone_checkpoint_path: Optional[str] = None,
pretrained_checkpoint_path: Optional[str] = None,
# Inference Settings
inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH
*args,
**kwargs,
):
UniFlowMatch.__init__(
self,
encoder_str=encoder_str,
encoder_kwargs=encoder_kwargs,
info_sharing_and_head_structure=info_sharing_and_head_structure,
info_sharing_str=info_sharing_str,
info_sharing_kwargs=info_sharing_kwargs,
head_type=head_type,
feature_head_kwargs=feature_head_kwargs,
adaptors_kwargs=adaptors_kwargs,
pretrained_checkpoint_path=pretrained_backbone_checkpoint_path,
inference_resolution=inference_resolution,
*args,
**kwargs,
)
PyTorchModelHubMixin.__init__(self)
# initialize uncertainty heads
assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now"
self.uncertainty_head = self._initialize_prediction_heads(
uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs
)
self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs)
assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"
self.detach_uncertainty_head = detach_uncertainty_head
def forward(self, view1, view2) -> UFMOutputInterface:
"""
Forward interface of correspondence prediction networks.
Args:
- view1 (Dict[str, Any]): Input view 1
- img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type
- instance (List[int]): List of instance indices, or id of the input image
- data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
- view2 (Dict[str, Any]): Input view 2
- (same structure as view1)
Returns:
- Dict[str, Any]: Output results
- flow [Required] (Dict[str, torch.Tensor]): Flow output
- [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
- [Optional] flow_covariance
- [Optional] flow_covariance_inv
- [Optional] flow_covariance_log_det
- covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
- [Optional] mask
- [Optional] logits
"""
# Get input shapes
_, _, height1, width1 = view1["img"].shape
_, _, height2, width2 = view2["img"].shape
shape1 = (int(height1), int(width1))
shape2 = (int(height2), int(width2))
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])
# Pass the features through the info_sharing
info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
info_sharing_input
)
info_sharing_outputs = {
"1": [
feat1_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
final_info_sharing_multi_view_feat.features[0].float().contiguous(),
],
"2": [
feat2_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
final_info_sharing_multi_view_feat.features[1].float().contiguous(),
],
}
info_sharing_outputs_detached = {
"1": [
feat1_list[-1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(),
final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(),
],
"2": [
feat2_list[-1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(),
final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(),
],
}
result = UFMOutputInterface()
# The predictions need precision, so we disable any autocasting here
with torch.autocast("cuda", torch.float32):
# run the collected info_sharing features through the prediction heads
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
head_output_uncertainty = self._downstream_head(
"uncertainty",
info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs,
shape1,
)
result.flow = UFMFlowFieldOutput(
flow_output=head_output1["flow"].value,
)
if "flow_cov" in head_output_uncertainty:
result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance
result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance
result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det
if "keypoint_confidence" in head_output_uncertainty:
result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1)
if "non_occluded_mask" in head_output_uncertainty:
result.covisibility = UFMMaskFieldOutput(
mask=head_output_uncertainty["non_occluded_mask"].mask,
logits=head_output_uncertainty["non_occluded_mask"].logits,
)
return result
def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
"""
Get parameter groups for the optimizer. This method guides the optimizer
to apply the correct learning rate to different parts of the model.
Returns:
- Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer
"""
return {
"encoder": torch.nn.ParameterList(self.encoder.parameters()),
"info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
"output_head": torch.nn.ParameterList(self.head1.parameters()),
"uncertainty_head": torch.nn.ParameterList(self.uncertainty_head.parameters()),
}
def _downstream_head(self, head_num, decout, img_shape):
"Run the respective prediction heads"
# if self.info_sharing_and_head_structure == "dual+single":
head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head
head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1
if self.head_type == "linear":
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
elif self.head_type in ["dpt", "moge_conv"]:
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
return head(head_input)
class UniFlowMatchClassificationRefinement(UniFlowMatch, PyTorchModelHubMixin):
"""
The variant of UniFlowMatch with local classification for refinement.
"""
def __init__(
self,
# Encoder configurations
encoder_str: str,
encoder_kwargs: Dict[str, Any],
# Info sharing & output head structure configurations
info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported
# Information sharing configurations
info_sharing_str: str = "global_attention",
info_sharing_kwargs: Dict[str, Any] = {},
# Prediction Heads & Adaptors
head_type: str = "dpt",
feature_head_kwargs: Dict[str, Any] = {},
adaptors_kwargs: Dict[str, Any] = {},
# Uncertainty Heads & Adaptors
detach_uncertainty_head: bool = True,
uncertainty_head_type: str = "dpt",
uncertainty_head_kwargs: Dict[str, Any] = {},
uncertainty_adaptors_kwargs: Dict[str, Any] = {},
# Classification Heads & Adaptors
temperature: float = 4,
use_unet_feature: bool = False,
classification_head_type: str = "patch_mlp",
classification_head_kwargs: Dict[str, Any] = {},
feature_combine_method: str = "conv",
# Refinement Range
refinement_range: int = 5,
# Load Pretrained Weights
pretrained_backbone_checkpoint_path: Optional[str] = None,
pretrained_checkpoint_path: Optional[str] = None,
# Inference Settings
inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH
*args,
**kwargs,
):
UniFlowMatch.__init__(
self,
encoder_str=encoder_str,
encoder_kwargs=encoder_kwargs,
info_sharing_and_head_structure=info_sharing_and_head_structure,
info_sharing_str=info_sharing_str,
info_sharing_kwargs=info_sharing_kwargs,
head_type=head_type,
feature_head_kwargs=feature_head_kwargs,
adaptors_kwargs=adaptors_kwargs,
pretrained_checkpoint_path=pretrained_backbone_checkpoint_path,
inference_resolution=inference_resolution,
*args,
**kwargs,
)
PyTorchModelHubMixin.__init__(self)
# initialize classification head
assert classification_head_type == "patch_mlp", "Only patch_mlp is supported for the classification head now"
self.classification_head_type = classification_head_type
self.classification_head = self._initialize_classification_head(classification_head_kwargs)
self.refinement_range = refinement_range
self.temperature = temperature
assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"
self.use_unet_feature = use_unet_feature
self.feature_combine_method = feature_combine_method
# Unet experiment
if self.use_unet_feature:
self.unet_feature = UNet(in_channels=3, out_channels=16, features=[64, 128, 256, 512])
self.conv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0)
if self.feature_combine_method == "conv":
self.conv2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0)
elif self.feature_combine_method == "modulate":
self.conv2 = nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0)
default_attention_bias = torch.zeros(self.refinement_range * self.refinement_range)
self.classification_bias = nn.Parameter(default_attention_bias)
# initialize uncertainty heads
if len(uncertainty_head_kwargs) > 0:
assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now"
self.uncertainty_head = self._initialize_prediction_heads(
uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs
)
self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs)
assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now"
self.detach_uncertainty_head = detach_uncertainty_head
def forward(self, view1, view2) -> UFMOutputInterface:
"""
Forward interface of correspondence prediction networks.
Args:
- view1 (Dict[str, Any]): Input view 1
- img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type
- instance (List[int]): List of instance indices, or id of the input image
- data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT
- view2 (Dict[str, Any]): Input view 2
- (same structure as view1)
Returns:
- Dict[str, Any]: Output results
- flow [Required] (Dict[str, torch.Tensor]): Flow output
- [Required] flow_output (torch.Tensor): Flow output tensor, BCHW
- [Optional] flow_covariance
- [Optional] flow_covariance_inv
- [Optional] flow_covariance_log_det
- covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output
- [Optional] mask
- [Optional] logits
- classification [Optional]: Probability and targets of the classification head
"""
# Get input shapes
_, _, height1, width1 = view1["img"].shape
_, _, height2, width2 = view2["img"].shape
shape1 = (int(height1), int(width1))
shape2 = (int(height2), int(width2))
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"])
# Pass the features through the info_sharing
info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]])
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
info_sharing_input
)
info_sharing_outputs = {
"1": [
feat1_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(),
final_info_sharing_multi_view_feat.features[0].float().contiguous(),
],
"2": [
feat2_list[-1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(),
final_info_sharing_multi_view_feat.features[1].float().contiguous(),
],
}
info_sharing_outputs_detached = {
"1": [
feat1_list[-1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(),
final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(),
],
"2": [
feat2_list[-1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(),
intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(),
final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(),
],
}
# optionally inference for U-Net Features
if self.use_unet_feature:
unet_feat1 = self.unet_feature(view1["img"])
unet_feat2 = self.unet_feature(view2["img"])
result = UFMOutputInterface()
# The predictions need precision, so we disable any autocasting here
with torch.autocast("cuda", torch.float32):
# run the collected info_sharing features through the prediction heads
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
flow_prediction = head_output1["flow"].value
if hasattr(self, "uncertainty_head"):
# run the uncertainty head
head_output_uncertainty = self._downstream_head(
"uncertainty",
info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs,
shape1,
)
if "flow_cov" in head_output_uncertainty:
result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance
result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance
result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det
if "keypoint_confidence" in head_output_uncertainty:
result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1)
if "non_occluded_mask" in head_output_uncertainty:
result.covisibility = UFMMaskFieldOutput(
mask=head_output_uncertainty["non_occluded_mask"].mask,
logits=head_output_uncertainty["non_occluded_mask"].logits,
)
# we run the classification head inside the autocast context because it is classification, not regression
if self.classification_head_type == "patch_mlp":
# concatenate the last encoder feature with final info_sharing feature
# use the first encoder feature, because it captures more low-level information, which is needed
# for refinement of the regressed flow.
classification_feat_1 = torch.cat(
[feat1_list[0].float().contiguous(), info_sharing_outputs["1"][-1]], dim=1
)
classification_feat_2 = torch.cat(
[feat2_list[0].float().contiguous(), info_sharing_outputs["2"][-1]], dim=1
)
classification_input = PredictionHeadInput(
torch.cat([classification_feat_1, classification_feat_2], dim=0)
)
classification_features = self.classification_head(classification_input).decoded_channels
if self.use_unet_feature:
if self.feature_combine_method == "conv":
combined_features = torch.cat(
[classification_features, torch.cat([unet_feat1, unet_feat2], dim=0)], dim=1
)
combined_features = self.conv1(combined_features)
combined_features = nn.functional.relu(combined_features)
combined_features = self.conv2(combined_features)
elif self.feature_combine_method == "modulate":
combined_features = classification_features * torch.tanh(
torch.cat([unet_feat1, unet_feat2], dim=0)
)
combined_features = self.conv2(combined_features)
classification_features = combined_features
classification_features0, classification_features1 = classification_features.chunk(2, dim=0)
# refine the flow prediction with features from the classification head
for i in range(1):  # single refinement iteration
residual, log_softmax_attention = self.classification_refinement(
flow_prediction, classification_features
)
flow_prediction = flow_prediction + residual
# Fill in the result
# WARNING: based on how the residual is computed, the gradient through flow_prediction cancels out
# mathematically, so the flow prediction itself receives no supervision at all. A specialized loss
# function is needed to supervise the regression_flow_output.
result.flow = UFMFlowFieldOutput(
flow_output=flow_prediction,
)
result.classification_refinement = UFMClassificationRefinementOutput(
regression_flow_output=flow_prediction,
residual=residual,
log_softmax=log_softmax_attention,
feature_map_0=classification_features0,
feature_map_1=classification_features1,
)
return result
# @torch.compile()
def classification_refinement(self, flow_prediction, classification_features) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Use correlation between self feature and features around a local patch of the initial flow prediction
to refine the flow prediction.
"""
classification_features1, classification_features2 = classification_features.chunk(2, dim=0)
neighborhood_features, neighborhood_flow_residual = self.obtain_neighborhood_features(
flow_estimation=flow_prediction, other_features=classification_features2, local_patch=self.refinement_range
)
residual, log_softmax_attention = self.compute_refinement_attention(
classification_features1, neighborhood_features, neighborhood_flow_residual
)
return residual, log_softmax_attention
def compute_refinement_attention(self, classification_features1, neighborhood_features, neighborhood_flow_residual):
"""
Compute the attention for the refinement, reshaping Q, K, and V
to fit a batched matrix multiplication.
"""
B, C, H, W = classification_features1.shape
P = self.refinement_range
# reshape Q to B, H, W, 1, 1, C
classification_features1 = classification_features1.permute(0, 2, 3, 1).reshape(B * H * W, 1, C)
# reshape K to B, H, W, 1, P^2, C
assert neighborhood_features.shape[0] == B
assert neighborhood_features.shape[1] == H
assert neighborhood_features.shape[2] == W
assert neighborhood_features.shape[3] == P
assert neighborhood_features.shape[4] == P
assert neighborhood_features.shape[5] == C
neighborhood_features = neighborhood_features.reshape(B * H * W, P * P, C)
# reshape V to B, H, W, 1, P^2, 2
neighborhood_flow_residual = neighborhood_flow_residual.reshape(-1, P * P, 2)
# compute the attention
attention_score = (
torch.matmul(classification_features1, neighborhood_features.permute(0, 2, 1)) / self.temperature
)
attention_score = attention_score + self.classification_bias
attention = torch.nn.functional.softmax(attention_score, dim=-1)
log_softmax_attention = torch.nn.functional.log_softmax(attention_score, dim=-1)
# compute the weighted sum
residual = torch.matmul(attention, neighborhood_flow_residual)
# reshape the residual to B, H, W, 2, then B, 2, H, W
residual = residual.reshape(B, H, W, 2).permute(0, 3, 1, 2)
return residual, log_softmax_attention.reshape(B, H, W, P, P)
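# Standalone shape walkthrough of the attention math above (sizes are
# hypothetical; the learned classification_bias is omitted for brevity):
#   >>> B, H, W, C, P = 1, 4, 4, 8, 5
#   >>> q = torch.randn(B * H * W, 1, C)      # one query feature per pixel
#   >>> k = torch.randn(B * H * W, P * P, C)  # P*P neighborhood features
#   >>> v = torch.randn(B * H * W, P * P, 2)  # flow residual of each neighbor
#   >>> attn = torch.softmax(q @ k.transpose(1, 2) / 4.0, dim=-1)  # temperature = 4
#   >>> (attn @ v).reshape(B, H, W, 2).permute(0, 3, 1, 2).shape
#   torch.Size([1, 2, 4, 4])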
def _downstream_head(self, head_num, decout, img_shape):
"Run the respective prediction heads"
# if self.info_sharing_and_head_structure == "dual+single":
head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head
head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1
if self.head_type == "linear":
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
elif self.head_type in ["dpt", "moge_conv"]:
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
return head(head_input)
def obtain_neighborhood_features(
self, flow_estimation: torch.Tensor, other_features: torch.Tensor, local_patch: int = 5
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Query the other features according to flow estimation.
"""
assert local_patch % 2 == 1, "local_patch should be an odd number"
P = local_patch
R = (P - 1) // 2
B, C, H, W = other_features.shape
device = other_features.device
# expected_output = torch.zeros(B, H, W, P, P, C, device=other_features.device, dtype=torch.float32)
neighborhood_grid_ij: torch.Tensor
i_local, j_local = torch.meshgrid(
torch.arange(-R, R + 1, device=device), torch.arange(-R, R + 1, device=device), indexing="ij"
)
ij_local = torch.stack((i_local, j_local), dim=0) # 2, P, P tensor
# compute the indices of the fetch
base_grid_xy = get_meshgrid_torch(W=W, H=H, device=device).permute(2, 0, 1).reshape(1, 2, H, W)
target_coordinate_xy_float = flow_estimation + base_grid_xy
target_coordinate_xy = target_coordinate_xy_float.view(B, 2, H, W, 1, 1)
target_coordinate_ij = target_coordinate_xy[:, [1, 0], ...]
# compute the neighborhood grid
neighborhood_grid_ij = target_coordinate_ij + ij_local.view(1, 2, 1, 1, P, P)
grid_for_sample = neighborhood_grid_ij[:, [1, 0], ...].permute(0, 2, 3, 4, 5, 1).reshape(B, H, W * P * P, 2)
grid_for_sample = (grid_for_sample + 0.5) / torch.tensor([W, H], device=device).view(1, 1, 1, 2)
grid_for_sample = grid_for_sample * 2 - 1
expected_output = torch.nn.functional.grid_sample(
other_features, grid=grid_for_sample, mode="bicubic", padding_mode="zeros", align_corners=False
).view(B, C, H, W, P, P)
# transform BCHWPP to BHWPPC
expected_output = expected_output.permute(0, 2, 3, 4, 5, 1)
neighborhood_grid_xy_residual = ij_local[[1, 0], ...].view(1, 2, 1, 1, P, P).to(device).float()
neighborhood_grid_xy_residual = neighborhood_grid_xy_residual.permute(0, 2, 3, 4, 5, 1).float()
return expected_output, neighborhood_grid_xy_residual
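# Note on the index conventions used above: tensors suffixed _xy store
# (x, y) = (col, row) order, tensors suffixed _ij store (i, j) = (row, col),
# and the fancy index [1, 0] flips between the two, e.g.
#   >>> xy = torch.tensor([3.0, 7.0])  # x=3 (col), y=7 (row)
#   >>> xy[[1, 0]]                     # tensor([7., 3.]) in (i, j) order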
def _initialize_classification_head(self, classification_head_kwargs: Dict[str, Any]):
"""
Initialize classification head
Args:
- classification_head_kwargs (Dict[str, Any]): Classification head configurations
Returns:
- nn.Module: Classification head
"""
if self.classification_head_type == "patch_mlp":
return MLPFeature(**classification_head_kwargs)
else:
raise ValueError(f"Classification head type {self.classification_head_type} not supported.")
def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]:
"""
Get parameter groups for the optimizer. This method guides the optimizer
to apply the correct learning rate to different parts of the model.
Returns:
- Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer
"""
if self.use_unet_feature:
params_dict = {
"encoder": torch.nn.ParameterList(self.encoder.parameters()),
"info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
"output_head": torch.nn.ParameterList(self.head1.parameters()),
"classification_head": torch.nn.ParameterList(self.classification_head.parameters()),
"unet_feature": torch.nn.ParameterList(
list(self.unet_feature.parameters())
+ list(self.conv1.parameters())
+ list(self.conv2.parameters())
+ [self.classification_bias]
),
}
else:
params_dict = {
"encoder": torch.nn.ParameterList(self.encoder.parameters()),
"info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()),
"output_head": torch.nn.ParameterList(self.head1.parameters()),
"classification_head": torch.nn.ParameterList(self.classification_head.parameters()),
}
if hasattr(self, "uncertainty_head"):
params_dict["uncertainty_head"] = torch.nn.ParameterList(self.uncertainty_head.parameters())
return params_dict
if __name__ == "__main__":
import cv2
import flow_vis
import matplotlib.pyplot as plt
import numpy as np
import torch
from uniflowmatch.utils.geometry import get_meshgrid_torch
from uniflowmatch.utils.viz import warp_image_with_flow
USE_REFINEMENT_MODEL = False
if USE_REFINEMENT_MODEL:
model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
else:
model = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base")
# === Load and Prepare Images ===
source_path = "examples/image_pairs/fire_academy_0.png"
target_path = "examples/image_pairs/fire_academy_1.png"
source_image = cv2.imread(source_path)
target_image = cv2.imread(target_path)
source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB)
target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)
# === Predict Correspondences ===
result = model.predict_correspondences_batched(
source_image=torch.from_numpy(source_image),
target_image=torch.from_numpy(target_image),
)
flow_output = result.flow.flow_output[0].cpu().numpy()
covisibility = result.covisibility.mask[0].cpu().numpy()
# === Visualize Results ===
fig, axs = plt.subplots(2, 3, figsize=(15, 5))
axs[0, 0].imshow(source_image)
axs[0, 0].set_title("Source Image")
axs[0, 1].imshow(target_image)
axs[0, 1].set_title("Target Image")
# Warp the image using flow
warped_image = warp_image_with_flow(source_image, None, target_image, flow_output.transpose(1, 2, 0))
warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
warped_image
)
warped_image /= 255.0
axs[0, 2].imshow(warped_image)
axs[0, 2].set_title("Warped Image")
# Flow visualization
flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0))
axs[1, 0].imshow(flow_vis_image)
axs[1, 0].set_title("Flow Output (Valid at covisible region)")
# Covisibility mask
axs[1, 1].imshow(covisibility > 0.5, cmap="gray", vmin=0, vmax=1)
axs[1, 1].set_title("Covisibility Mask (Thresholded by 0.5)")
heatmap = axs[1, 2].imshow(covisibility, cmap="gray", vmin=0, vmax=1)
axs[1, 2].set_title("Covisibility Mask")
plt.colorbar(heatmap, ax=axs[1, 2])
plt.tight_layout()
plt.savefig("ufm_output.png")
plt.show()
print("Saved ufm_output.png")