""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a CLIP model.
"""

import logging
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # old timm imports < 0.8.1
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # new timm imports >= 0.8.1
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d

class TimmModel(nn.Module):
    """ timm model adapter
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")
        self.image_size = to_2tuple(image_size)

        # setup kwargs that may not be common across all models
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        custom_pool = pool in ('abs_attn', 'rot_attn')
        if not proj and not custom_pool:
            # use network classifier head as projection if no proj specified and no custom pooling used
            self.trunk = timm.create_model(
                model_name,
                num_classes=embed_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                assert feature_ndim == 2
                # if attn pooling used, remove both classifier and default pool
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # reset global pool if pool config set, otherwise leave as network default
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # Add custom pooling to head
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
        else:
            assert not proj, f'Unknown projection type {proj}.'

        self.head = nn.Sequential(head_layers)
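
    # A minimal construction sketch (an illustration, not part of the original
    # module; assumes timm is installed and 'resnet50' is a valid timm model
    # name):
    #
    #   tower = TimmModel('resnet50', embed_dim=512, pool='avg', proj='linear')
    #   feats = tower(torch.randn(2, 3, 224, 224))  # -> shape (2, 512)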

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)
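
    # Usage sketch (an illustration, not part of the original module):
    # `model.lock(unlocked_groups=1)` freezes every parameter group except the
    # last one reported by timm's group_matcher (typically the final stage).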

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        x = self.trunk(x)
        x = self.head(x)
        return x

    @staticmethod
    def _denormalize_boxes(normed_boxes, x):
        h, w = x.shape[-2:]
        denormed_boxes = []
        for boxes in normed_boxes:
            new_boxes = boxes.clone()   # FIXME: do not change the value in normed_boxes!
            new_boxes[:, [0, 2]] *= w
            new_boxes[:, [1, 3]] *= h
            denormed_boxes.append(new_boxes)
        return denormed_boxes
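
    # Example (an illustration, not part of the original module): for a
    # feature map `x` of spatial size 7x7, a normalized xyxy box
    # [0.0, 0.0, 0.5, 0.5] becomes [0.0, 0.0, 3.5, 3.5] in absolute feature
    # coordinates, which is the convention torchvision's roi_align expects.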

    def _extract_roi_features_v1(self, x, normed_boxes, **kwargs):
        h, w = x.shape[-2:]
        x = self.trunk.forward_features(x)
        h_f, w_f = x.shape[-2:]
        tar_h = (self.image_size[0] * h_f) // h
        tar_w = (self.image_size[1] * w_f) // w
        x = roi_align(x, self._denormalize_boxes(normed_boxes, x), (tar_h, tar_w),
                      1.0, -1, True)
        x = self.trunk.forward_head(x)
        x = self.head(x)
        return x

    def encode_dense(self, x, **kwargs):
        x = self.trunk.forward_features(x)
        x = self.dense_trunk_head(x)
        x = self.head(x)
        x = x.permute(0, 3, 1, 2)
        return x

    def dense_trunk_head(self, x):
        # apply the trunk's head norm/drop without pooling, keeping the
        # spatial dimensions so that `self.head` projects every location
        # (assumes the trunk's head exposes `norm` and `drop` submodules)
        x = self.trunk.head.norm(x)
        x = x.permute(0, 2, 3, 1)
        x = self.trunk.head.drop(x)
        # x = x.permute(0, 3, 1, 2)
        return x

    def mask_pool(self, image, masks):
        feature_map = self.encode_dense(image)
        feature_map = F.normalize(feature_map, dim=1)   # remember to normalize!
        feature_map = feature_map.flatten(-2, -1)   # bs, c, h*w
        num_masks_per_image = [len(masks_per_image) for masks_per_image in masks]
        masks = torch.cat(masks).float().flatten(-2, -1)    # bs, h*w
        feature_map = torch.repeat_interleave(
            feature_map, torch.tensor(num_masks_per_image, device=feature_map.device), dim=0)
        features = (feature_map * masks[:, None]).sum(-1) / (masks.sum(1, keepdim=True) + 1e-12)
        return features
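
    # Usage sketch (shapes are assumptions for illustration, not part of the
    # original module): `masks` is a list with one (n_i, H', W') tensor per
    # image, where H'xW' matches the dense feature map's spatial size; the
    # result stacks the sum(n_i) mask-averaged embeddings:
    #
    #   feats = model.mask_pool(images, masks)  # -> (sum(n_i), C)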

    def extract_roi_features(self, x, normed_boxes, extract_type='v1'):
        if extract_type == 'v1':
            return self._extract_roi_features_v1(x, normed_boxes)
        else:
            assert extract_type == 'v2'
            return self._extract_roi_features_v2(x, normed_boxes)

    def _extract_roi_features_v2(self, x, normed_boxes, **kwargs):
        x = self.encode_dense(x)
        x = F.normalize(x, dim=1)   # remember to normalize!
        roi_feats = roi_align(x, self._denormalize_boxes(normed_boxes, x), (1, 1),
                              1.0, -1, True)[..., 0, 0]
        return roi_feats

    def encode_rois_and_image(self, x, normed_boxes, **kwargs):
        h, w = x.shape[-2:]
        x = self.trunk.forward_features(x)
        h_f, w_f = x.shape[-2:]
        tar_h = (self.image_size[0] * h_f) // h
        tar_w = (self.image_size[1] * w_f) // w
        x_image = x
        x_rois = roi_align(x, self._denormalize_boxes(normed_boxes, x), (tar_h, tar_w),
                           1.0, -1, True)

        x_rois = self.trunk.forward_head(x_rois)
        x_rois = self.head(x_rois)
        x_rois = F.normalize(x_rois, dim=-1)

        x_image = self.trunk.forward_head(x_image)
        x_image = self.head(x_image)
        x_image = F.normalize(x_image, dim=-1)

        return x_rois, x_image
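

if __name__ == "__main__":
    # Minimal smoke-test sketch (an addition for illustration, not part of the
    # original module; assumes timm is installed and 'resnet18' is available
    # as an untrained timm model). Runs a plain forward pass and v1 ROI
    # extraction on random data.
    model = TimmModel('resnet18', embed_dim=512, image_size=224, pool='avg', proj='linear')
    images = torch.randn(2, 3, 224, 224)
    print(model(images).shape)  # expected: torch.Size([2, 512])
    normed_boxes = [torch.tensor([[0.1, 0.1, 0.6, 0.6]]),
                    torch.tensor([[0.2, 0.2, 0.9, 0.9]])]
    print(model.extract_roi_features(images, normed_boxes, extract_type='v1').shape)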