"""
Linear head implementation

Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width);
The linear head implementation is based on DUSt3R and CroCoV2
References: https://github.com/naver/dust3r
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput


class LinearFeature(nn.Module):
    """
    This class implements a linear mapping from the low resolution patch features to pixel-wise features.

    A 1x1 convolution expands every patch feature to ``output_dim * patch_size**2`` channels,
    and a pixel shuffle with upscale factor ``patch_size`` rearranges those channels into
    space, so an input of shape (B, input_feature_dim, H, W) yields an output of shape
    (B, output_dim, H * patch_size, W * patch_size).
    """

    def __init__(
        self,
        input_feature_dim: int,
        output_dim: int,
        patch_size: int,
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Initialize the linear feature mapping.

        Args:
            input_feature_dim : int, the input feature dimension
            output_dim : int, the output feature dimension
            patch_size : int, the patch size
            pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None).
                The checkpoint is read with ``torch.load`` and its ``"model"`` entry is
                passed to ``load_state_dict``.
            *args, **kwargs : forwarded unchanged to ``nn.Module.__init__``
        """
        super().__init__(*args, **kwargs)
        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # A 1x1 convolution is a per-location linear layer; it emits patch_size**2 values
        # per output channel so that pixel_shuffle can unfold them spatially in forward().
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # map_location="cpu" lets checkpoints saved on a GPU be read on hosts without
            # that device; load_state_dict then copies the values into the existing
            # parameters, so the module's own device is unaffected.
            # NOTE(security): weights_only=False performs full pickle deserialization and
            # can execute arbitrary code -- only load checkpoints from trusted sources.
            ckpt = torch.load(self.pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput) -> PixelTaskOutput:
        """
        Forward interface for the linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
                - last_feature : torch.Tensor, the last feature tensor; its dimension 1
                  must equal ``input_feature_dim``

        Returns:
            PixelTaskOutput, the output of the linear feature mapping
                - decoded_channels : torch.Tensor, the decoded channels, i.e. the
                  convolution output after pixel shuffle with factor ``patch_size``

        Raises:
            AssertionError: if the channel dimension of ``last_feature`` does not match
                ``input_feature_dim``.
        """
        x = feature_input.last_feature
        assert (
            x.shape[1] == self.input_feature_dim
        ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}"

        x = self.linear(x)
        # Move the patch_size**2 channel groups into the spatial dimensions.
        x = F.pixel_shuffle(x, self.patch_size)

        return PixelTaskOutput(decoded_channels=x)


if __name__ == "__main__":
    # Init an example linear feature head
    linear_prediction_head = LinearFeature(input_feature_dim=768, output_dim=4, patch_size=16)

    # Create a dummy input tensor with shape (B, C, H, W)
    dummy_input = torch.randn(1, 768, 14, 14)  # Example input

    # Run dummy forward pass
    output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))