| | import timm |
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| |
|
| | from .utils import activations, get_activation, Transpose |
| |
|
| |
|
def forward_levit(pretrained, x):
    """Run a LeViT forward pass and return the three hooked feature maps.

    Args:
        pretrained: backbone wrapper produced by ``_make_levit_backbone`` —
            carries the timm model, the hook-populated ``activations`` dict,
            and the three ``act_postprocess*`` reshaping modules.
        x: input batch forwarded through the backbone.

    Returns:
        Tuple of three postprocessed activation tensors (stages 1-3).
    """
    # The return value of forward_features is unused; the registered forward
    # hooks fill pretrained.activations as a side effect of this call.
    pretrained.model.forward_features(x)

    stages = (
        ("1", pretrained.act_postprocess1),
        ("2", pretrained.act_postprocess2),
        ("3", pretrained.act_postprocess3),
    )
    # Reshape each hooked activation via its matching postprocess module.
    return tuple(post(pretrained.activations[key]) for key, post in stages)
| |
|
| |
|
def _make_levit_backbone(
    model,
    hooks=None,
    patch_grid=None,
):
    """Wrap a timm LeViT model so three intermediate stages can be read out.

    Registers forward hooks on the given transformer blocks and attaches
    postprocessing modules that turn the (B, N, C) token sequences back into
    (B, C, H, W)-style feature maps.

    Args:
        model: a timm LeViT model exposing ``model.blocks``.
        hooks: indices of the three blocks to hook; defaults to [3, 11, 21].
        patch_grid: (H, W) token grid of the first stage; defaults to [14, 14].

    Returns:
        nn.Module with attributes ``model``, ``activations`` and
        ``act_postprocess1..3``, as consumed by ``forward_levit``.
    """
    # Use None sentinels instead of mutable default arguments, which would be
    # shared across calls.
    hooks = [3, 11, 21] if hooks is None else hooks
    patch_grid = [14, 14] if patch_grid is None else patch_grid

    pretrained = nn.Module()

    pretrained.model = model
    # Capture the output of the three chosen blocks into the shared
    # `activations` dict under keys "1", "2", "3".
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))

    pretrained.activations = activations

    patch_grid_size = np.array(patch_grid, dtype=int)

    # Hooked tokens are (B, N, C): transpose to (B, C, N), then unflatten the
    # token dimension back into the 2-D patch grid.
    pretrained.act_postprocess1 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size(patch_grid_size.tolist())),
    )
    # Later stages run at 1/2 and 1/4 of the patch-grid resolution; ceil keeps
    # odd grid sizes consistent with LeViT's downsampling.
    pretrained.act_postprocess2 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())),
    )
    pretrained.act_postprocess3 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())),
    )

    return pretrained
| |
|
| |
|
class ConvTransposeNorm(nn.Sequential):
    """
    Modification of
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
    such that ConvTranspose2d is used instead of Conv2d.
    """

    def __init__(
            self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
            groups=1, bn_weight_init=1):
        # Submodules are named 'c' (transposed conv) and 'bn' (batch norm) to
        # mirror timm's ConvNorm; fuse() below relies on this insertion order.
        super().__init__()
        self.add_module('c',
                        nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(out_chs))

        # Conv carries no bias; BN supplies the affine part. bn_weight_init=0
        # would start the branch as an identity-like zero mapping.
        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm into the transposed conv, returning a single
        nn.ConvTranspose2d for inference.

        NOTE(review): this is a direct port of timm's Conv2d fusion. For
        ConvTranspose2d the weight layout is (in_chs, out_chs/groups, kH, kW),
        so scaling with ``w[:, None, None, None]`` applies the per-output-channel
        BN factors along dim 0 (input channels), and the fused module is built
        with ``w.size(1)``/``w.size(0)`` as in/out channels — verify against a
        non-square, grouped configuration before relying on fuse().
        """
        # Relies on dict insertion order: 'c' first, then 'bn'.
        c, bn = self._modules.values()
        # Per-channel scale: gamma / sqrt(running_var + eps).
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        # Fused bias: beta - running_mean * scale.
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = nn.ConvTranspose2d(
            w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
            padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
| |
|
| |
|
def stem_b4_transpose(in_chs, out_chs, activation):
    """Build a 2-layer upsampling stem (total stride 4) of transposed convs.

    Modification of timm's levit.py ``stem_b16`` such that ConvTranspose2d is
    used instead of Conv2d and the stem is reduced to half its depth.

    Args:
        in_chs: input channel count of the first transposed conv.
        out_chs: channel count after the first layer; the second layer halves it.
        activation: activation class (not instance), instantiated per layer.

    Returns:
        nn.Sequential of ConvTransposeNorm/activation pairs.
    """
    layers = [
        ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
        activation(),
        ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
        activation(),
    ]
    return nn.Sequential(*layers)
| |
|
| |
|
def _make_pretrained_levit_384(pretrained, hooks=None):
    """Create a hooked LeViT-384 backbone.

    Args:
        pretrained: whether timm should load ImageNet weights.
        hooks: block indices to hook; defaults to [3, 11, 21].

    Returns:
        Backbone wrapper from ``_make_levit_backbone``.
    """
    model = timm.create_model("levit_384", pretrained=pretrained)

    # Identity comparison with None (PEP 8), not equality.
    hooks = [3, 11, 21] if hooks is None else hooks
    return _make_levit_backbone(
        model,
        hooks=hooks,
    )
| |
|