"""
Global quantity prediction head implementation
Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width)
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from uniception.models.prediction_heads.base import PredictionHeadInput, SummaryTaskOutput
from uniception.models.prediction_heads.pose_head import ResConvBlock
class GlobalHead(nn.Module):
"""
    Global quantity regression head implementation
"""
def __init__(
self,
patch_size: int,
input_feature_dim: int,
num_resconv_block: int = 2,
output_representation_dim: int = 1,
        pretrained_checkpoint_path: Optional[str] = None,
*args,
**kwargs,
):
"""
Initialize the global head.
Args:
patch_size : int, the patch size of the transformer used to generate the input features
input_feature_dim : int, the input feature dimension
num_resconv_block : int, the number of residual convolution blocks
output_representation_dim : int, the dimension of the output representation
            pretrained_checkpoint_path : Optional[str], path to a pretrained checkpoint (default: None)
"""
super().__init__()
self.patch_size = patch_size
self.input_feature_dim = input_feature_dim
self.num_resconv_block = num_resconv_block
self.output_representation_dim = output_representation_dim
self.pretrained_checkpoint_path = pretrained_checkpoint_path
# Initialize the hidden dimension of the global head based on the patch size
self.output_dim = 4 * (self.patch_size**2)
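        # (4 * patch_size**2 channels correspond to 4 values per pixel covered by each patch token)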
# Initialize the projection layer for the hidden dimension of the global head
self.proj = nn.Conv2d(
in_channels=self.input_feature_dim,
out_channels=self.output_dim,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
# Initialize sequential layers of the global head
        self.res_conv = nn.ModuleList(
            [ResConvBlock(self.output_dim, self.output_dim) for _ in range(self.num_resconv_block)]
        )
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.more_mlps = nn.Sequential(
nn.Linear(self.output_dim, self.output_dim),
nn.ReLU(),
nn.Linear(self.output_dim, self.output_dim),
nn.ReLU(),
)
self.fc_output = nn.Linear(self.output_dim, self.output_representation_dim)
# Load the pretrained checkpoint if provided
if self.pretrained_checkpoint_path is not None:
print(f"Loading pretrained global head from {self.pretrained_checkpoint_path}")
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
print(self.load_state_dict(ckpt["model"]))
def forward(self, feature_input: PredictionHeadInput):
"""
Forward interface for the global quantity prediction head.
The head requires an adapter on the final output.
Args:
feature_input : PredictionHeadInput, the input features
- last_feature : torch.Tensor, the last feature tensor
Returns:
SummaryTaskOutput, the output of the global head
- decoded_channels : torch.Tensor, the decoded channels
"""
# Get the patch-level features from the input
feat = feature_input.last_feature # (B, C, H, W)
# Check the input dimensions
assert (
feat.shape[1] == self.input_feature_dim
), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}"
# Apply the projection layer to the patch-level features
feat = self.proj(feat) # (B, PC, H, W)
# Apply the residual convolution blocks to the projected features
        for res_conv_block in self.res_conv:
            feat = res_conv_block(feat)
# Apply the average pooling layer to the residual convolution output
feat = self.avgpool(feat) # (B, PC, 1, 1)
# Flatten the average pooled features
feat = feat.view(feat.size(0), -1) # (B, PC)
        # Apply the MLP layers to the flattened features
        feat = self.more_mlps(feat)  # (B, PC)
        # Apply the final linear layer to the MLP output
        output_feat = self.fc_output(feat)  # (B, self.output_representation_dim)
return SummaryTaskOutput(decoded_channels=output_feat)
if __name__ == "__main__":
# Init an example global head
global_head = GlobalHead(
patch_size=14,
input_feature_dim=1024,
num_resconv_block=2,
output_representation_dim=1,
pretrained_checkpoint_path=None,
)
# Create a dummy input tensor with shape (B, C, H, W)
dummy_input = torch.randn(4, 1024, 14, 14) # Example input
# Run dummy forward pass
output = global_head(PredictionHeadInput(last_feature=dummy_input))
# Check the output shape
assert output.decoded_channels.shape == (4, 1), "Output shape mismatch"
print("Global head test passed!")