infinity1096
initial commit
c8b42eb
"""
Linear head implementation
Downstream heads assume inputs of shape BCHW (B: batch, C: channels, H: height, W: width).
The linear head implementation is based on DUSt3R and CroCo v2.
References: https://github.com/naver/dust3r
"""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput
class LinearFeature(nn.Module):
    """
    This class implements a linear mapping from the low resolution patch features
    to pixel-wise features.

    Each spatial location of the input (one patch feature) is projected by a 1x1
    convolution to ``output_dim * patch_size**2`` channels, and ``F.pixel_shuffle``
    then rearranges those channels into a ``patch_size x patch_size`` block of
    ``output_dim``-channel pixels. An input of shape (B, input_feature_dim, H, W)
    therefore yields an output of shape (B, output_dim, H * patch_size, W * patch_size).
    """

    def __init__(
        self,
        input_feature_dim: int,
        output_dim: int,
        patch_size: int,
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the linear feature mapping.

        Args:
            input_feature_dim : int, the input feature dimension (channels C of the input)
            output_dim : int, the output feature dimension (channels per output pixel)
            patch_size : int, the patch size, i.e. the spatial upsampling factor
            pretrained_checkpoint_path : str, optional path to a checkpoint whose
                ``"model"`` entry is a state dict for this module (default: None)
        """
        super().__init__(*args, **kwargs)
        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # A 1x1 conv is a per-location linear layer; the patch_size**2 factor holds
        # the sub-pixel values that pixel_shuffle unpacks in forward().
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved from GPU tensors load on any
            # machine; load_state_dict copies the values into this module's parameters.
            # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
            # only load checkpoints from trusted sources.
            ckpt = torch.load(
                self.pretrained_checkpoint_path,
                map_location="cpu",
                weights_only=False,
            )
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput):
        """
        Forward interface for the linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
                - last_feature : torch.Tensor of shape (B, input_feature_dim, H, W)

        Returns:
            PixelTaskOutput, the output of the linear feature mapping
                - decoded_channels : torch.Tensor of shape
                  (B, output_dim, H * patch_size, W * patch_size)
        """
        x = feature_input.last_feature
        # Check the rank first so a token-layout (B, N, C) tensor is not reported as a
        # misleading channel mismatch on dim 1.
        assert x.ndim == 4, f"Expected a BCHW input tensor, got shape {tuple(x.shape)}"
        assert (
            x.shape[1] == self.input_feature_dim
        ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}"

        x = self.linear(x)
        # (B, output_dim * p**2, H, W) -> (B, output_dim, H * p, W * p)
        x = F.pixel_shuffle(x, self.patch_size)

        return PixelTaskOutput(decoded_channels=x)
if __name__ == "__main__":
    # Init an example linear feature head
    linear_prediction_head = LinearFeature(input_feature_dim=768, output_dim=4, patch_size=16)

    # Create a dummy input tensor with shape (B, C, H, W): a 14x14 grid of 768-dim features
    dummy_input = torch.randn(1, 768, 14, 14)

    # Run dummy forward pass; no autograd graph is needed for a smoke test
    with torch.no_grad():
        output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))

    # Each of the 14x14 locations expands to a 16x16 block of 4-channel pixels
    expected_shape = (1, 4, 14 * 16, 14 * 16)
    actual_shape = tuple(output.decoded_channels.shape)
    assert actual_shape == expected_shape, f"Unexpected output shape: {actual_shape} != {expected_shape}"
    print(f"Linear head output shape: {actual_shape}")