Spaces:
Sleeping
Sleeping
| from .backbone import Backbone | |
| from .transformer import TransformerEncoder | |
| from .ope import OPEModule | |
| from .positional_encoding import PositionalEncodingsFixed | |
| from .regression_head import DensityMapRegressor | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class LOCA(nn.Module):
    """LOCA few-shot / zero-shot object counting model.

    Pipeline: CNN backbone -> 1x1 channel projection -> optional transformer
    encoder -> iterative object prototype extraction (OPE) -> per-iteration
    prototype/feature correlation -> density-map regression (one final head
    plus auxiliary heads supervising the intermediate OPE iterations).
    """

    def __init__(
        self,
        image_size: int,
        num_encoder_layers: int,
        num_ope_iterative_steps: int,
        num_objects: int,
        emb_dim: int,
        num_heads: int,
        kernel_dim: int,
        backbone_name: str,
        swav_backbone: bool,
        train_backbone: bool,
        reduction: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):
        super().__init__()

        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers

        self.backbone = Backbone(
            backbone_name, pretrained=True, dilation=False, reduction=reduction,
            swav=swav_backbone, requires_grad=train_backbone
        )
        # 1x1 conv projecting backbone channels to the embedding dimension.
        self.input_proj = nn.Conv2d(
            self.backbone.num_channels, emb_dim, kernel_size=1
        )

        if num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, norm
            )

        self.ope = OPEModule(
            num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
            reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
        )

        # Final head consumes the last OPE iteration; aux heads supervise the
        # earlier iterations during training.
        self.regression_head = DensityMapRegressor(emb_dim, reduction)
        self.aux_heads = nn.ModuleList([
            DensityMapRegressor(emb_dim, reduction)
            for _ in range(num_ope_iterative_steps - 1)
        ])

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

        # NOTE(review): the (64, 64) spatial LayerNorms hard-code the feature
        # map resolution, i.e. they assume image_size / reduction == 64 —
        # confirm against the training configuration.
        self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.fuse = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        # BUGFIX: forward_reg dispatches to self.fuse1 when the fused input
        # has 322 channels, but this module was commented out, so that path
        # raised AttributeError at runtime. Registering it restores the path.
        # NOTE(review): this adds new parameters; older checkpoints need
        # load_state_dict(strict=False).
        self.fuse1 = nn.Sequential(
            nn.Conv2d(322, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )

    def forward_before_reg(self, x, bboxes):
        """Compute prototype/feature response maps, without regression.

        Args:
            x: input image batch.
            bboxes: exemplar boxes of shape (bs, num_exemplars, 4); the
                exemplar count is taken from dim 1 unless zero-shot, in which
                case the configured ``num_objects`` is used.

        Returns:
            dict with:
              "feature_bf_regression": response map of the last OPE iteration,
              "aux_feature_bf_regression": list of response maps for the
                earlier iterations.
        """
        num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects

        # Backbone features, projected to the embedding dimension.
        backbone_features = self.backbone(x)
        src = self.input_proj(backbone_features)
        bs, c, h, w = src.size()

        # Flatten spatial dims to (h*w, bs, c) sequences for the encoder.
        pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
        src = src.flatten(2).permute(2, 0, 1)

        if self.num_encoder_layers > 0:
            image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
        else:
            image_features = src

        # Back to (bs, emb_dim, h, w) for the OPE module.
        f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)
        # One prototype set per OPE iteration, stacked along dim 0.
        all_prototypes = self.ope(f_e, pos_emb, bboxes)

        response_maps_list = []
        for i in range(all_prototypes.size(0)):
            # Reshape the i-th iteration's prototypes into depthwise conv
            # kernels: one (kernel_dim x kernel_dim) filter per channel per
            # exemplar.
            prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
                bs, num_objects, self.kernel_dim, self.kernel_dim, -1
            ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]

            # Grouped conv correlates every prototype with the image features;
            # the max over the exemplar axis keeps the strongest response.
            response_maps = F.conv2d(
                torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
                prototypes,
                bias=None,
                padding=self.kernel_dim // 2,
                groups=prototypes.size(0)
            ).view(
                bs, num_objects, self.emb_dim, h, w
            ).max(dim=1)[0]
            response_maps_list.append(response_maps)

        return {
            "feature_bf_regression": response_maps_list[-1],
            "aux_feature_bf_regression": response_maps_list[:-1],
        }

    def _flatten_maps(self, response_maps):
        """Order the response maps as [aux iterations..., final iteration]."""
        return (response_maps["aux_feature_bf_regression"]
                + [response_maps["feature_bf_regression"]])

    def _regress(self, maps):
        """Apply aux heads to all but the last map, the final head to the last.

        Returns {"pred": final density map, "aux_pred": auxiliary maps}.
        """
        outputs = []
        last = len(maps) - 1
        for i, response_map in enumerate(maps):
            head = self.regression_head if i == last else self.aux_heads[i]
            outputs.append(head(response_map))
        return {"pred": outputs[-1], "aux_pred": outputs[:-1]}

    def forward_reg(self, response_maps, attn_stack, unet_feature):
        """Fuse UNet features with an attention stack, then regress densities.

        The attention stack is normalized, concatenated onto the UNet feature,
        gated by the mean attention, and reduced back to emb_dim channels by a
        1x1 fuse block selected on the concatenated channel count.
        """
        attn_stack = self.attn_norm(attn_stack)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        unet_feature = torch.cat([unet_feature, attn_stack], dim=1)
        unet_feature = unet_feature * attn_stack_mean
        # Channel count decides which fuse block matches the input width.
        if unet_feature.shape[1] == 322:
            unet_feature = self.fuse1(unet_feature)
        else:
            unet_feature = self.fuse(unet_feature)
        maps = self._flatten_maps(response_maps)
        return self._regress([m + unet_feature for m in maps])

    def forward_reg1(self, response_maps, self_attn):
        """Regress densities with a self-attention map added to each response map."""
        maps = self._flatten_maps(response_maps)
        return self._regress([m + self_attn for m in maps])

    def forward_reg_without_unet(self, response_maps, attn_stack):
        """Regress densities with attention-mean gating but no UNet features.

        Each response map is residually scaled: m * mean(attn) * 0.5 + m.
        """
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        maps = self._flatten_maps(response_maps)
        return self._regress([m * attn_stack_mean * 0.5 + m for m in maps])
def build_model(args):
    """Construct a LOCA model from parsed CLI/config arguments.

    Args:
        args: namespace carrying the model hyperparameters (backbone,
            reduction, emb_dim, ...). The backbone is trainable iff
            ``args.backbone_lr > 0``.

    Returns:
        A LOCA instance.

    Raises:
        ValueError: if the backbone name or reduction factor is unsupported.
    """
    # Validate with explicit raises instead of assert, which is silently
    # stripped when Python runs with -O.
    if args.backbone not in ('resnet18', 'resnet50', 'resnet101'):
        raise ValueError(f"unsupported backbone: {args.backbone!r}")
    if args.reduction not in (4, 8, 16):
        raise ValueError(f"unsupported reduction: {args.reduction!r}")

    return LOCA(
        image_size=args.image_size,
        num_encoder_layers=args.num_enc_layers,
        num_ope_iterative_steps=args.num_ope_iterative_steps,
        num_objects=args.num_objects,
        zero_shot=args.zero_shot,
        emb_dim=args.emb_dim,
        num_heads=args.num_heads,
        kernel_dim=args.kernel_dim,
        backbone_name=args.backbone,
        swav_backbone=args.swav_backbone,
        train_backbone=args.backbone_lr > 0,
        reduction=args.reduction,
        dropout=args.dropout,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=args.pre_norm,
        activation=nn.GELU,
        norm=True,
    )