Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from model_utils.resnet import ResNet, BottleneckBlock | |
| import torch.nn.functional as F | |
class DummyAggregationNetwork(nn.Module):  # for testing, return the input
    """Stand-in aggregation network used for testing.

    The forward pass multiplies the input by a single scalar parameter that
    is initialised to one, so at initialisation the input values come back
    unchanged.
    """

    def __init__(self):
        super().__init__()
        # Dummy scalar (0-dim) parameter, initialised to 1.
        self.dummy = nn.Parameter(torch.ones(()))

    def forward(self, batch, pose=None):
        """Return ``batch`` scaled by the dummy parameter; ``pose`` is ignored."""
        scaled = batch * self.dummy
        return scaled
class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.
    Design inspired by the Feature Extractor from ODISE (Xu et. al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py

    One bottleneck stage is built per entry of ``feature_dims`` and shared
    across timesteps. The forward pass slices the channel-concatenated input
    into per-(timestep, layer) chunks, projects each chunk with its layer's
    bottleneck stage and sums the results, weighted by a softmax over the
    learnable ``mixing_weights``.
    """

    def __init__(
        self,
        device,
        feature_dims=(640, 1280, 1280, 768),
        projection_dim=384,
        num_norm_groups=32,
        save_timestep=(1,),
        kernel_size=(1, 3, 1),
        contrastive_temp=10,
        feat_map_dropout=0.0,
        num_blocks=None,
        bottleneck_channels=None,
    ):
        """
        Args:
            device: device the bottleneck layers and mixing weights are moved to.
            feature_dims: channel count of each layer's slice of the input;
                one bottleneck stage is created per entry.
            projection_dim: output channel count of every bottleneck stage.
            num_norm_groups: forwarded to ``ResNet.make_stage`` (norm is "GN").
            save_timestep: timesteps whose features are concatenated in the
                input; there is one mixing weight per (timestep, layer) pair.
            kernel_size: forwarded to ``ResNet.make_stage``.
            contrastive_temp: ``self_logit_scale`` is initialised to its log.
            feat_map_dropout: dropout probability applied to the input while
                training (disabled when 0).
            num_blocks: blocks per bottleneck stage; defaults to 1.
            bottleneck_channels: defaults to ``projection_dim // 4``.
        """
        super().__init__()
        # Copy the sequences: the defaults are immutable tuples and the module
        # must never alias (and later be affected by) a caller's list.
        feature_dims = list(feature_dims)
        save_timestep = list(save_timestep)
        kernel_size = list(kernel_size)

        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        self.azimuth_embedding = None
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        self.feature_dims = feature_dims
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = (
            bottleneck_channels if bottleneck_channels is not None else projection_dim // 4
        )
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = save_timestep
        self.mixing_weights_names = self._mixing_weight_names(save_timestep, len(feature_dims))

        for feature_dim in self.feature_dims:
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size,
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)

        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))

        # count number of parameters
        num_params = sum(param.numel() for param in self.parameters())
        print(f"AggregationNetwork has {num_params} parameters.")

    @staticmethod
    def _mixing_weight_names(save_timestep, num_layers):
        """Return one name per mixing weight, in the order ``forward`` uses them.

        ``forward`` maps weight ``i`` to layer ``i % num_layers``, i.e. the
        weights are ordered timestep-major, so the names are generated in the
        same order. Layers are 1-indexed in the name following prior work.
        """
        return [
            f"timestep-{t}_layer-{layer + 1}"
            for t in save_timestep
            for layer in range(num_layers)
        ]

    def load_pretrained_weights(self, pretrained_dict):
        """Load a pretrained state dict, tolerating a different number of mixing weights.

        The overlapping prefix of ``mixing_weights`` is copied from the
        checkpoint. If the shapes differ, any entries of this module's
        ``mixing_weights`` beyond that prefix are set to zero. A checkpoint
        without ``mixing_weights`` leaves the current values untouched.
        Every other checkpoint entry whose key exists in this module is
        loaded as-is (a shape mismatch there still raises in
        ``load_state_dict``).
        """
        custom_dict = self.state_dict()

        pretrained_mix = pretrained_dict.get('mixing_weights')
        if 'mixing_weights' in custom_dict and pretrained_mix is not None:
            # Work on a copy instead of mutating the tensor returned by
            # state_dict() in place; load_state_dict below writes it back.
            own_mix = custom_dict['mixing_weights'].clone()
            shared = min(own_mix.shape[0], pretrained_mix.shape[0])
            own_mix[:shared] = pretrained_mix[:shared]
            if own_mix.shape != pretrained_mix.shape:
                # Weights with no pretrained counterpart start from zero.
                own_mix[shared:] = 0
            custom_dict['mixing_weights'] = own_mix

        # Load the weights that do match
        matching_keys = {
            k: v
            for k, v in pretrained_dict.items()
            if k in custom_dict and k != 'mixing_weights'
        }
        custom_dict.update(matching_keys)
        # Now load the updated state_dict
        self.load_state_dict(custom_dict, strict=False)

    def forward(self, batch, pose=None):
        """
        Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.

        Channels are consumed in order, one chunk per mixing weight, cycling
        through ``feature_dims`` once per timestep. ``pose`` is accepted but
        not used.
        """
        if self.feat_map_dropout > 0 and self.training:
            batch = F.dropout(batch, p=self.feat_map_dropout)
        if self.pos_embedding is not None:  # position embedding
            batch = torch.cat((batch, self.pos_embedding), dim=1)

        mixing_weights = F.softmax(self.mixing_weights, dim=0)
        num_layers = len(self.feature_dims)
        output_feature = None
        start = 0
        for i, weight in enumerate(mixing_weights):
            # Share bottleneck layers across timesteps
            layer_idx = i % num_layers
            # Chunk the batch according to the layer
            # Account for looping if there are multiple timesteps
            end = start + self.feature_dims[layer_idx]
            feats = batch[:, start:end, :, :]
            start = end
            # Downsample the number of channels and weight the layer
            weighted_feature = weight * self.bottleneck_layers[layer_idx](feats)
            if output_feature is None:
                output_feature = weighted_feature
            else:
                output_feature += weighted_feature

        if self.last_layer is not None:
            output_feature_after = self.last_layer(output_feature)
            if self.skip_connection:
                # skip connection
                output_feature = output_feature + output_feature_after
        return output_feature
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with no padding."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )
    return layer
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The input is added to the output of the second conv/BN pair before the
    final ReLU. When the main branch changes the tensor shape (``stride != 1``
    or ``in_planes != planes``) the shortcut is projected with a strided 1x1
    convolution plus batch norm so that the addition is well-defined;
    otherwise the input is added unchanged.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the shortcut
            projection, when one is needed).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1 and in_planes == planes:
            # Shapes already match: identity shortcut.
            self.downsample = None
        else:
            # A channel change needs a projection even at stride 1; without
            # it the residual addition in forward() fails on a shape mismatch.
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        """Apply the two conv/BN stages and add the (possibly projected) input."""
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)