infinity1096
initial commit
c8b42eb
"""
Linear head with MLP implementation
Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width)
"""
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput
from uniception.models.utils.transformer_blocks import Mlp
class MLPFeature(nn.Module):
"""
This class implements a linear mapping from the low resolution patch features
to pixel-wise features with an additional intermediate MLP layer.
"""
def __init__(
self,
input_feature_dim: Union[int, str],
patch_size: int,
output_dim: int,
mlp_ratio: int = 4,
act_layer=nn.GELU,
bias=True,
drop=0.0,
pretrained_checkpoint_path: str = None,
*args,
**kwargs,
):
"""
Initialize the linear feature mapping.
Args:
input_feature_dim : int, the input feature dimension
output_dim : int, the output feature dimension
patch_size : int, the patch size
pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None)
"""
super().__init__(*args, **kwargs)
if isinstance(input_feature_dim, str):
input_feature_dim = eval(input_feature_dim)
self.input_feature_dim = input_feature_dim
self.output_dim = output_dim
self.patch_size = patch_size
self.pretrained_checkpoint_path = pretrained_checkpoint_path
self.mlp = Mlp(
in_features=self.input_feature_dim,
hidden_features=int(mlp_ratio * self.input_feature_dim),
act_layer=act_layer,
drop=drop,
bias=bias,
)
self.linear = nn.Conv2d(
in_channels=self.input_feature_dim,
out_channels=self.output_dim * (self.patch_size**2),
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
if self.pretrained_checkpoint_path is not None:
print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
print(self.load_state_dict(ckpt["model"]))
def forward(self, feature_input: PredictionHeadInput):
"""
Forward interface for the linear feature mapping.
Args:
feature_input : PredictionHeadInput, the input features
- last_feature : torch.Tensor, the last feature tensor
Returns:
PixelTaskOutput, the output of the linear feature mapping
- decoded_channels : torch.Tensor, the decoded channels
"""
x = feature_input.last_feature
assert (
x.shape[1] == self.input_feature_dim
), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}"
x = self.mlp(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
x = self.linear(x)
x = F.pixel_shuffle(x, self.patch_size)
return PixelTaskOutput(decoded_channels=x)
if __name__ == "__main__":
# Init an example linear feature head
linear_prediction_head = MLPFeature(
input_feature_dim=768, mlp_ratio=4, act_layer=nn.GELU, output_dim=4, patch_size=16
)
# Create a dummy input tensor with shape (B, C, H, W)
dummy_input = torch.randn(1, 768, 14, 14) # Example input
# Run dummy forward pass
output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))