# Source: uploaded by Shengxiao0709 ("Upload 78 files", commit 8f72b1f, verified).
from .backbone import Backbone
from .transformer import TransformerEncoder
from .ope import OPEModule
from .positional_encoding import PositionalEncodingsFixed
from .regression_head import DensityMapRegressor
import torch
from torch import nn
from torch.nn import functional as F
class LOCA(nn.Module):
    """Density-map prediction network: backbone, optional encoder, OPE, heads.

    Inference is split into two explicit stages (this class defines no
    ``forward`` method):

    * :meth:`forward_before_reg` turns an image batch and exemplar boxes into
      one response map per OPE iterative step.
    * :meth:`forward_reg`, :meth:`forward_reg1` or
      :meth:`forward_reg_without_unet` combines those response maps with
      externally supplied features and runs the regression heads.
    """

    def __init__(
        self,
        image_size: int,
        num_encoder_layers: int,
        num_ope_iterative_steps: int,
        num_objects: int,
        emb_dim: int,
        num_heads: int,
        kernel_dim: int,
        backbone_name: str,
        swav_backbone: bool,
        train_backbone: bool,
        reduction: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):
        """Build all sub-modules.

        Args:
            image_size: Stored on the instance; not otherwise used here.
            num_encoder_layers: Number of transformer encoder layers. The
                encoder is only constructed (and used) when this is > 0.
            num_ope_iterative_steps: Forwarded to ``OPEModule``; one
                auxiliary regression head is created per step except the
                last, which uses ``regression_head``.
            num_objects: Number of exemplars assumed when ``zero_shot`` is
                true (otherwise the count is read from ``bboxes``).
            emb_dim: Channel width after the 1x1 input projection.
            num_heads: Forwarded to the encoder and ``OPEModule``.
            kernel_dim: Spatial side length of each prototype kernel.
            backbone_name: Forwarded to ``Backbone``.
            swav_backbone: Forwarded to ``Backbone`` as ``swav``.
            train_backbone: Forwarded to ``Backbone`` as ``requires_grad``.
            reduction: Forwarded to ``Backbone``, ``OPEModule`` and the
                density-map regressors.
            dropout: Forwarded to the encoder.
            layer_norm_eps: Forwarded to the encoder and ``OPEModule``.
            mlp_factor: Forwarded to the encoder and ``OPEModule``.
            norm_first: Forwarded to the encoder and ``OPEModule``.
            activation: Forwarded to the encoder and ``OPEModule``.
                NOTE(review): ``build_model`` passes the class ``nn.GELU``
                rather than an instance - confirm what they expect.
            norm: Forwarded to the encoder and ``OPEModule``.
            zero_shot: Forwarded to ``OPEModule``; also selects where the
                exemplar count comes from in :meth:`forward_before_reg`.
        """
        super().__init__()

        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers

        self.backbone = Backbone(
            backbone_name, pretrained=True, dilation=False, reduction=reduction,
            swav=swav_backbone, requires_grad=train_backbone
        )
        self.input_proj = nn.Conv2d(
            self.backbone.num_channels, emb_dim, kernel_size=1
        )

        # `self.encoder` only exists when at least one layer is requested;
        # forward_before_reg guards on the same condition.
        if num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, norm
            )

        self.ope = OPEModule(
            num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
            reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
        )

        self.regression_head = DensityMapRegressor(emb_dim, reduction)
        self.aux_heads = nn.ModuleList([
            DensityMapRegressor(emb_dim, reduction)
            for _ in range(num_ope_iterative_steps - 1)
        ])

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

        # The fusion layers below are hard-coded for 64x64 maps and a
        # 324-channel concatenated input projected to 256 channels.
        # NOTE(review): 256 presumably has to equal emb_dim for the addition
        # in forward_reg to work - confirm before changing emb_dim.
        self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.fuse = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        # NOTE: a second fusion layer `fuse1` (same structure, 322 input
        # channels) used to be defined here and is intentionally not built,
        # so it adds no parameters. forward_reg raises if that path is hit.

    def forward_before_reg(self, x, bboxes):
        """Compute one response map per OPE iterative step.

        Args:
            x: Input passed to the backbone.
            bboxes: Exemplar boxes passed to the OPE module. Its ``size(1)``
                is used as the exemplar count unless ``zero_shot`` is set.

        Returns:
            dict with ``"feature_bf_regression"`` (response map of the last
            step) and ``"aux_feature_bf_regression"`` (list of the response
            maps of all earlier steps, in order).
        """
        num_objects = self.num_objects if self.zero_shot else bboxes.size(1)

        # backbone
        backbone_features = self.backbone(x)

        # prepare the encoder input: (h*w, bs, channels) sequences
        src = self.input_proj(backbone_features)
        bs, _, h, w = src.size()
        pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
        src = src.flatten(2).permute(2, 0, 1)

        # push through the encoder (skipped when no layers were configured)
        if self.num_encoder_layers > 0:
            image_features = self.encoder(
                src, pos_emb, src_key_padding_mask=None, src_mask=None
            )
        else:
            image_features = src

        # prepare OPE input: back to (bs, emb_dim, h, w)
        f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)
        all_prototypes = self.ope(f_e, pos_emb, bboxes)  # e.g. [3, 27, 1, 256]

        # The image features are replicated once per exemplar and folded into
        # the channel axis so a single grouped convolution can correlate every
        # (sample, exemplar, channel) triple with its own kernel. This input
        # does not depend on the iterative step, so it is built once.
        replicated_features = torch.cat(
            [f_e] * num_objects, dim=1
        ).flatten(0, 1).unsqueeze(0)

        response_maps_list = []
        for step_prototypes in all_prototypes:
            # One (1, k, k) kernel per (sample, exemplar, channel).
            prototypes = step_prototypes.permute(1, 0, 2).reshape(
                bs, num_objects, self.kernel_dim, self.kernel_dim, -1
            ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]  # e.g. [768, 1, 3, 3]

            # Depth-wise correlation, then max over exemplars.
            response_maps = F.conv2d(
                replicated_features,
                prototypes,
                bias=None,
                padding=self.kernel_dim // 2,
                groups=prototypes.size(0)
            ).view(
                bs, num_objects, self.emb_dim, h, w
            ).max(dim=1)[0]

            # Regression is deferred to the forward_reg* methods.
            response_maps_list.append(response_maps)

        return {
            "feature_bf_regression": response_maps_list[-1],
            "aux_feature_bf_regression": response_maps_list[:-1],
        }

    @staticmethod
    def _ordered_response_maps(response_maps):
        """Return auxiliary response maps followed by the final one."""
        return (
            response_maps["aux_feature_bf_regression"]
            + [response_maps["feature_bf_regression"]]
        )

    def _predict_density_maps(self, fused_maps):
        """Run each fused map through its head.

        The last map goes through ``regression_head``; map ``i`` before it
        goes through ``aux_heads[i]``.
        """
        *aux_maps, final_map = fused_maps
        # Index explicitly (rather than zip) so that a missing auxiliary head
        # still raises instead of being silently skipped.
        aux_pred = [
            self.aux_heads[i](aux_map) for i, aux_map in enumerate(aux_maps)
        ]
        return {"pred": self.regression_head(final_map), "aux_pred": aux_pred}

    def forward_reg(self, response_maps, attn_stack, unet_feature):
        """Regress density maps from response maps fused with UNet features.

        ``attn_stack`` is normalised, concatenated to ``unet_feature`` along
        the channel axis, gated by the channel-mean of ``attn_stack``,
        projected by ``self.fuse`` and added to every response map.

        Args:
            response_maps: Output dict of :meth:`forward_before_reg`.
            attn_stack: Attention maps; normalised by ``self.attn_norm``.
            unet_feature: Features concatenated with ``attn_stack``.

        Returns:
            dict with ``"pred"`` (final prediction) and ``"aux_pred"`` (list
            of auxiliary predictions).

        Raises:
            AttributeError: If the concatenated features have 322 channels,
                because the matching ``fuse1`` layer is not constructed.
        """
        attn_stack = self.attn_norm(attn_stack)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # e.g. [1, 324, 64, 64]
        unet_feature = unet_feature * attn_stack_mean

        if unet_feature.shape[1] == 322:
            fuse1 = getattr(self, "fuse1", None)
            if fuse1 is None:
                # Same exception type as the implicit failure, but with a
                # message that explains what is missing.
                raise AttributeError(
                    "'LOCA' object has no attribute 'fuse1': the 322-channel "
                    "fusion layer is not constructed in __init__, so "
                    "forward_reg cannot handle a 322-channel fused input"
                )
            unet_feature = fuse1(unet_feature)
        else:
            unet_feature = self.fuse(unet_feature)

        fused_maps = [
            response_map + unet_feature
            for response_map in self._ordered_response_maps(response_maps)
        ]
        return self._predict_density_maps(fused_maps)

    def forward_reg1(self, response_maps, self_attn):
        """Regress density maps after adding ``self_attn`` to each response map.

        Args:
            response_maps: Output dict of :meth:`forward_before_reg`.
            self_attn: Tensor added to every response map.

        Returns:
            dict with ``"pred"`` and ``"aux_pred"``.
        """
        fused_maps = [
            response_map + self_attn
            for response_map in self._ordered_response_maps(response_maps)
        ]
        return self._predict_density_maps(fused_maps)

    def forward_reg_without_unet(self, response_maps, attn_stack):
        """Regress density maps using only an attention-based residual gate.

        Each response map ``r`` becomes ``r * mean(attn_stack) * 0.5 + r``,
        where the mean is taken over the channel axis. ``attn_stack`` is used
        as given (no normalisation is applied here).

        Args:
            response_maps: Output dict of :meth:`forward_before_reg`.
            attn_stack: Attention maps averaged over ``dim=1``.

        Returns:
            dict with ``"pred"`` and ``"aux_pred"``.
        """
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        fused_maps = [
            response_map * attn_stack_mean * 0.5 + response_map
            for response_map in self._ordered_response_maps(response_maps)
        ]
        return self._predict_density_maps(fused_maps)
def build_model(args):
    """Construct a :class:`LOCA` model from a parsed-arguments object.

    Args:
        args: Object exposing the attributes ``image_size``,
            ``num_enc_layers``, ``num_ope_iterative_steps``, ``num_objects``,
            ``zero_shot``, ``emb_dim``, ``num_heads``, ``kernel_dim``,
            ``backbone``, ``swav_backbone``, ``backbone_lr``, ``reduction``,
            ``dropout`` and ``pre_norm``.

    Returns:
        A new ``LOCA`` instance. The backbone is trainable only when
        ``args.backbone_lr > 0``. ``layer_norm_eps`` (1e-5), ``mlp_factor``
        (8), ``activation`` (``nn.GELU``) and ``norm`` (True) are fixed here.

    Raises:
        ValueError: If ``args.backbone`` or ``args.reduction`` is not one of
            the supported values.
    """
    supported_backbones = ('resnet18', 'resnet50', 'resnet101')
    supported_reductions = (4, 8, 16)

    # Explicit raises instead of `assert`, so the validation still runs when
    # Python is started with -O.
    if args.backbone not in supported_backbones:
        raise ValueError(
            f"Unsupported backbone {args.backbone!r}; "
            f"expected one of {supported_backbones}"
        )
    if args.reduction not in supported_reductions:
        raise ValueError(
            f"Unsupported reduction {args.reduction!r}; "
            f"expected one of {supported_reductions}"
        )

    return LOCA(
        image_size=args.image_size,
        num_encoder_layers=args.num_enc_layers,
        num_ope_iterative_steps=args.num_ope_iterative_steps,
        num_objects=args.num_objects,
        zero_shot=args.zero_shot,
        emb_dim=args.emb_dim,
        num_heads=args.num_heads,
        kernel_dim=args.kernel_dim,
        backbone_name=args.backbone,
        swav_backbone=args.swav_backbone,
        train_backbone=args.backbone_lr > 0,
        reduction=args.reduction,
        dropout=args.dropout,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=args.pre_norm,
        activation=nn.GELU,
        norm=True,
    )