Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,064 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
Linear head implementation
Downstream heads assume inputs of shape BCHW (B: batch, C: channels, H: height, W: width).
The linear head implementation is based on DUSt3R and CroCoV2
References: https://github.com/naver/dust3r
"""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput
class LinearFeature(nn.Module):
    """
    This class implements a linear mapping from the low resolution patch features
    to pixel-wise features.

    A 1x1 convolution (a per-token linear layer) expands every patch feature to
    ``output_dim * patch_size**2`` channels. ``F.pixel_shuffle`` then rearranges
    those channels into a ``patch_size x patch_size`` block of pixels. An input of
    shape (B, input_feature_dim, H, W) therefore becomes
    (B, output_dim, H * patch_size, W * patch_size).
    """

    def __init__(
        self,
        input_feature_dim: int,
        output_dim: int,
        patch_size: int,
        pretrained_checkpoint_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the linear feature mapping.

        Args:
            input_feature_dim : int, the input feature (channel) dimension
            output_dim : int, the number of output channels per pixel
            patch_size : int, the spatial upsampling factor applied by pixel shuffle
            pretrained_checkpoint_path : str, optional path to a checkpoint whose
                ``"model"`` entry is a state dict for this module (default: None)
        """
        super().__init__(*args, **kwargs)
        self.input_feature_dim = input_feature_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.pretrained_checkpoint_path = pretrained_checkpoint_path

        # 1x1 conv acts as a linear layer applied independently to every patch token;
        # it produces patch_size**2 sub-pixel values for each output channel.
        self.linear = nn.Conv2d(
            in_channels=self.input_feature_dim,
            out_channels=self.output_dim * (self.patch_size**2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        if self.pretrained_checkpoint_path is not None:
            print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}")
            # map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only host;
            # load_state_dict copies the tensors into this module's existing parameters.
            # SECURITY: weights_only=False unpickles arbitrary objects, so only load trusted checkpoints.
            ckpt = torch.load(self.pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            # Prints the missing/unexpected key report returned by load_state_dict.
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, feature_input: PredictionHeadInput) -> PixelTaskOutput:
        """
        Forward interface for the linear feature mapping.

        Args:
            feature_input : PredictionHeadInput, the input features
                - last_feature : torch.Tensor of shape (B, input_feature_dim, H, W)
        Returns:
            PixelTaskOutput, the output of the linear feature mapping
                - decoded_channels : torch.Tensor of shape
                  (B, output_dim, H * patch_size, W * patch_size)
        Raises:
            ValueError: if last_feature is not 4D or its channel dimension
                does not equal input_feature_dim.
        """
        x = feature_input.last_feature
        # Explicit checks (not assert) so validation survives `python -O`.
        # The rank check keeps shape[1] meaning "channels" below.
        if x.ndim != 4:
            raise ValueError(f"Expected a 4D (B, C, H, W) input tensor, got shape {tuple(x.shape)}")
        if x.shape[1] != self.input_feature_dim:
            raise ValueError(f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}")
        x = self.linear(x)
        # (B, output_dim * p**2, H, W) -> (B, output_dim, H * p, W * p)
        x = F.pixel_shuffle(x, self.patch_size)
        return PixelTaskOutput(decoded_channels=x)
if __name__ == "__main__":
    # Init an example linear feature head: 768-dim tokens -> 4 channels per pixel, 16x16 patches
    patch_size = 16
    linear_prediction_head = LinearFeature(input_feature_dim=768, output_dim=4, patch_size=patch_size)

    # Create a dummy input tensor with shape (B, C, H, W): a 14x14 grid of patch features
    dummy_input = torch.randn(1, 768, 14, 14)

    # Run dummy forward pass and sanity-check the pixel-shuffled output resolution
    output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input))
    output_shape = tuple(output.decoded_channels.shape)
    expected_shape = (1, 4, 14 * patch_size, 14 * patch_size)
    assert output_shape == expected_shape, f"Unexpected output shape: {output_shape} != {expected_shape}"
    print(f"Output shape: {output_shape}")
|