Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,662 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""
Linear head with MLP implementation
Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width)
"""
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput
from uniception.models.utils.transformer_blocks import Mlp
class MLPFeature(nn.Module):
    """
    Linear mapping from low-resolution patch features to pixel-wise features,
    with an intermediate MLP applied before the 1x1 projection.

    Downstream heads assume inputs of size BCHW (B: batch, C: channels,
    H: height, W: width). The output spatial size is (H * patch_size,
    W * patch_size) with `output_dim` channels.
    """

    def __init__(
        self,
        input_feature_dim: Union[int, str],
        patch_size: int,
        output_dim: int,
        mlp_ratio: int = 4,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
        pretrained_checkpoint_path: Union[str, None] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the MLP + linear feature mapping.

        Args:
            input_feature_dim (Union[int, str]): Input feature (channel) dimension.
                A string is evaluated to an int (see security note below).
            patch_size (int): Patch size; each input feature pixel is expanded to a
                patch_size x patch_size block of output pixels via pixel shuffle.
            output_dim (int): Number of output channels after pixel shuffle.
            mlp_ratio (int): Hidden-layer expansion ratio of the intermediate MLP
                (default: 4).
            act_layer: Activation layer class used inside the MLP (default: nn.GELU).
            bias (bool): Whether the MLP linear layers use a bias term (default: True).
            drop (float): Dropout probability inside the MLP (default: 0.0).
            pretrained_checkpoint_path (str, optional): Path to a pretrained
                checkpoint containing a "model" state dict (default: None).
        """
        super().__init__(*args, **kwargs)

        if isinstance(input_feature_dim, str):
            # SECURITY: eval() executes arbitrary code. Kept (not replaced with
            # int()) because config strings may be expressions, e.g. "2 * 384",
            # but only trusted configuration values must ever reach this path.
            input_feature_dim = eval(input_feature_dim)
        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # Channel-wise MLP; applied on the last dimension, so forward() permutes
        # BCHW -> BHWC around it.
        self.mlp = Mlp(
            in_features=self.input_feature_dim,
            hidden_features=int(mlp_ratio * self.input_feature_dim),
            act_layer=act_layer,
            drop=drop,
            bias=bias,
        )
        # 1x1 conv producing output_dim * patch_size^2 channels, which
        # F.pixel_shuffle rearranges into a patch_size-times larger spatial map.
        # NOTE(review): the `bias` argument is only forwarded to the MLP; this
        # conv always uses bias=True. Kept as-is for checkpoint compatibility —
        # confirm whether `bias` was meant to apply here too.
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code; only load checkpoints from trusted sources.
            ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
            # load_state_dict returns (missing_keys, unexpected_keys); printed
            # so mismatches are visible at load time.
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput):
        """
        Forward interface for the MLP + linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
                - last_feature : torch.Tensor of shape (B, input_feature_dim, H, W)

        Returns:
            PixelTaskOutput, the output of the feature mapping
                - decoded_channels : torch.Tensor of shape
                  (B, output_dim, H * patch_size, W * patch_size)
        """
        x = feature_input.last_feature
        assert (
            x.shape[1] == self.input_feature_dim
        ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}"
        # BCHW -> BHWC for the channel-wise MLP, then back to BCHW.
        x = self.mlp(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
        x = self.linear(x)
        # (B, output_dim * p^2, H, W) -> (B, output_dim, H*p, W*p).
        x = F.pixel_shuffle(x, self.patch_size)
        return PixelTaskOutput(decoded_channels=x)
if __name__ == "__main__":
    # Smoke test: instantiate an example MLP feature head.
    linear_prediction_head = MLPFeature(
        input_feature_dim=768, mlp_ratio=4, act_layer=nn.GELU, output_dim=4, patch_size=16
    )

    # Dummy BCHW input matching the expected feature dimension (C=768).
    dummy_input = torch.randn(1, 768, 14, 14)

    # Run a dummy forward pass and verify the decoded shape: pixel shuffle
    # upsamples 14x14 by patch_size=16 -> 224x224 with output_dim=4 channels.
    output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))
    expected_shape = (1, 4, 14 * 16, 14 * 16)
    actual_shape = tuple(output.decoded_channels.shape)
    assert actual_shape == expected_shape, f"Unexpected output shape: {actual_shape} != {expected_shape}"
    print(f"Dummy forward pass successful, output shape: {actual_shape}")
|