|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
from einops import rearrange |
|
|
|
|
|
from typing import Tuple, Union, Dict, Optional, List |
|
|
from functools import partial |
|
|
|
|
|
from .cannet import _cannet, _cannet_bn |
|
|
from .csrnet import _csrnet, _csrnet_bn |
|
|
from .vgg import _vgg_encoder_decoder, _vgg_encoder |
|
|
from .vit import _vit, supported_vit_backbones |
|
|
from .timm_models import _timm_model |
|
|
from .timm_models import regular_models as timm_regular_models, heavy_models as timm_heavy_models, light_models as timm_light_models, lighter_models as timm_lighter_models |
|
|
from .hrnet import _hrnet, available_hrnets |
|
|
|
|
|
from ..utils import conv1x1 |
|
|
|
|
|
|
|
|
# Registry of backbone names accepted by EBC. `supported_models` (bottom) is the
# list that EBC.__init__ validates `model_name` against.

# Depth / batch-norm combinations used to spell out the VGG names below.
_VGG_DEPTHS = (11, 13, 16, 19)
_VGG_BN_SUFFIXES = ("", "_bn")

regular_models = [
    "csrnet", "csrnet_bn",
    "cannet", "cannet_bn",
    # vgg11, vgg11_bn, vgg13, vgg13_bn, ... (plain VGG names)
    *[f"vgg{depth}{bn}" for depth in _VGG_DEPTHS for bn in _VGG_BN_SUFFIXES],
    # vgg11_ae, vgg11_bn_ae, ... (the same names with an "_ae" suffix)
    *[f"vgg{depth}{bn}_ae" for depth in _VGG_DEPTHS for bn in _VGG_BN_SUFFIXES],
    *timm_regular_models,
    *available_hrnets,
]

# The remaining groups are re-exported unchanged from the timm / ViT modules.
heavy_models = timm_heavy_models
light_models = timm_light_models
lighter_models = timm_lighter_models
transformer_models = supported_vit_backbones

# Every model name EBC will accept, in group order.
supported_models = (
    regular_models
    + heavy_models
    + light_models
    + lighter_models
    + transformer_models
)
|
|
|
|
|
|
|
|
|
|
|
class EBC(nn.Module):
    """Backbone followed by per-location classification over count bins.

    The backbone output is passed to a ``bin_head`` that produces one logit
    per bin; the softmax over those logits is combined with ``bin_centers``
    to give a single-channel map of expected values. When ``zero_inflated``
    is true, a second 2-channel ``pi_head`` is added and the first bin is
    excluded from ``bin_head`` (presumably the first bin is the zero-count
    bin -- confirm against the training code).

    Args:
        model_name: Backbone identifier; must be in ``supported_models``.
        block_size: Forwarded to the backbone constructor.
        bins: ``(low, high)`` pair per bin; values are converted to float.
        bin_centers: One representative value per bin, each inside its bin.
        zero_inflated: Whether to build the extra ``pi_head`` (see above).
        num_vpt: Forwarded to the ViT backbone constructor only.
        vpt_drop: Forwarded to the ViT backbone constructor only.
        input_size: ``None``, a single int (expanded to a square pair), or a
            two-element sequence of positive values. Forwarded to the ViT
            backbone constructor only.
        norm: Forwarded to the backbone constructor.
        act: Forwarded to the backbone constructor.

    Raises:
        AssertionError: If ``model_name`` is unsupported, ``input_size`` is
            malformed, or ``bins`` / ``bin_centers`` are inconsistent.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
        bins: List[Tuple[float, float]],
        bin_centers: List[float],
        zero_inflated: bool = True,
        num_vpt: Optional[int] = None,
        vpt_drop: Optional[float] = None,
        input_size: Optional[Union[int, Tuple[int, int]]] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in supported_models, f"Model name should be one of {supported_models}, but got {model_name}."
        self.model_name = model_name

        # Normalise a scalar size to a (size, size) pair; sequences pass through.
        if input_size is not None:
            input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
            assert len(input_size) == 2 and input_size[0] > 0 and input_size[1] > 0, f"Expected input_size to be a tuple of two positive integers, got {input_size}"
        self.input_size = input_size

        assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
        assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
        assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
        bins = [(float(b[0]), float(b[1])) for b in bins]
        assert all(low <= center <= high for (low, high), center in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"

        self.block_size = block_size
        self.bins = bins
        # Stored as (1, num_bins, 1, 1) so it broadcasts against the channel
        # dimension of the head output in forward().
        self.register_buffer(
            "bin_centers",
            torch.tensor(bin_centers, dtype=torch.float32).view(1, -1, 1, 1),
        )

        self.zero_inflated = zero_inflated
        self.num_vpt = num_vpt
        self.vpt_drop = vpt_drop

        self.norm = norm
        self.act = act

        # The head reads `self.backbone.decoder_channels`, so the backbone
        # must be built first.
        self._build_backbone()
        self._build_head()

    def _build_backbone(self) -> None:
        """Instantiate ``self.backbone`` for ``self.model_name``.

        Lookup order: the explicitly named CSRNet/CANNet/VGG variants, then
        ViT backbones, then HRNets, and finally the timm factory as fallback.
        """
        name = self.model_name
        common_args = (self.block_size, self.norm, self.act)

        # Builders that take only (block_size, norm, act).
        standalone_builders = {
            "csrnet": _csrnet,
            "csrnet_bn": _csrnet_bn,
            "cannet": _cannet,
            "cannet_bn": _cannet_bn,
        }
        # vgg11, vgg11_bn, vgg13, ... vgg19_bn
        vgg_variants = {
            f"vgg{depth}{bn}" for depth in (11, 13, 16, 19) for bn in ("", "_bn")
        }
        ae_suffix = "_ae"

        if name in standalone_builders:
            self.backbone = standalone_builders[name](*common_args)
        elif name in vgg_variants:
            self.backbone = _vgg_encoder(name, *common_args)
        elif name.endswith(ae_suffix) and name[:-len(ae_suffix)] in vgg_variants:
            # "<vgg variant>_ae" selects the encoder-decoder builder, which is
            # given the variant name without the suffix.
            self.backbone = _vgg_encoder_decoder(name[:-len(ae_suffix)], *common_args)
        elif name in supported_vit_backbones:
            self.backbone = _vit(
                name,
                block_size=self.block_size,
                num_vpt=self.num_vpt,
                vpt_drop=self.vpt_drop,
                input_size=self.input_size,
                norm=self.norm,
                act=self.act,
            )
        elif name in available_hrnets:
            self.backbone = _hrnet(name, block_size=self.block_size, norm=self.norm, act=self.act)
        else:
            self.backbone = _timm_model(name, *common_args)

    def _build_head(self) -> None:
        """Create the output head(s) on top of ``self.backbone``.

        Heads are built with the project's ``conv1x1`` helper and take
        ``self.backbone.decoder_channels`` input channels.
        """
        channels = self.backbone.decoder_channels
        if self.zero_inflated:
            # The first bin is not predicted by bin_head in this mode; the
            # 2-channel pi_head is used alongside it (see forward()).
            self.bin_head = conv1x1(
                in_channels=channels,
                out_channels=len(self.bins) - 1,
            )
            self.pi_head = conv1x1(
                in_channels=channels,
                out_channels=2,
            )
        else:
            self.bin_head = conv1x1(
                in_channels=channels,
                out_channels=len(self.bins),
            )

    def forward(
        self, x: Tensor
    ) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]:
        """Run the backbone and head(s).

        Args:
            x: Input passed straight to the backbone (assumed to yield a
                4-D feature map with channels on dim 1 -- TODO confirm).

        Returns:
            In eval mode, ``den_maps`` only. In training mode:
            ``(logit_pi_maps, logit_maps, lambda_maps, den_maps)`` when
            ``zero_inflated`` is true, otherwise ``(logit_maps, den_maps)``.
        """
        features = self.backbone(x)

        if not self.zero_inflated:
            logit_maps = self.bin_head(features)
            # Expected value over all bins.
            den_maps = (logit_maps.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)
            if self.training:
                return logit_maps, den_maps
            return den_maps

        logit_pi_maps = self.pi_head(features)
        logit_maps = self.bin_head(features)
        # Expected value over every bin except the first one.
        lambda_maps = (logit_maps.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True)

        # Scale by channel 1 of the pi softmax (presumably the probability of
        # a non-zero count -- confirm against the loss definition).
        den_maps = logit_pi_maps.softmax(dim=1)[:, 1:] * lambda_maps

        if self.training:
            return logit_pi_maps, logit_maps, lambda_maps, den_maps
        return den_maps
|
|
|
|
|
|
|
|
def _ebc(
    model_name: str,
    block_size: int,
    bins: List[Tuple[float, float]],
    bin_centers: List[float],
    zero_inflated: bool = True,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none"
) -> EBC:
    """Construct an ``EBC`` model, forwarding every argument unchanged.

    All parameters have the same names, defaults and meaning as the
    corresponding ``EBC.__init__`` parameters.
    """
    # Gather the arguments once so the constructor call stays a single line.
    ebc_kwargs = {
        "model_name": model_name,
        "block_size": block_size,
        "bins": bins,
        "bin_centers": bin_centers,
        "zero_inflated": zero_inflated,
        "num_vpt": num_vpt,
        "vpt_drop": vpt_drop,
        "input_size": input_size,
        "norm": norm,
        "act": act,
    }
    return EBC(**ebc_kwargs)
|
|
|