DenseVLM / src /open_clip /timm_model.py
lyhisme's picture
Upload 151 files
c02d17f verified
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torchvision.ops import roi_align
import torch.nn.functional as F
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
    """ timm model adapter

    Wraps a timm model (``self.trunk``) followed by an optional pooling / projection
    head (``self.head``) so that it can serve as the vision tower of a CLIP model.
    In addition to the plain forward pass, the class offers dense-feature, RoI-pooled
    and mask-pooled variants of the embedding.
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        """ Build the timm trunk and the projection head.

        Args:
            model_name: model name forwarded to ``timm.create_model``.
            embed_dim: output feature size of the head (or of the timm classifier when
                it is used as the projection).
            image_size: int or pair, stored as a 2-tuple in ``self.image_size``.
            pool: global pool type forwarded to timm, or ``'abs_attn'`` / ``'rot_attn'``
                to put an attention pooling layer into the head instead.
            proj: ``'linear'``, ``'mlp'`` or a falsy value for no extra projection.
            proj_bias: whether the final projection layer has a bias.
            drop: dropout rate used inside the projection head.
            drop_path: if not None, forwarded to timm as ``drop_path_rate``.
            patch_drop: if not None, forwarded to timm as ``patch_drop_rate``.
            pretrained: forwarded to ``timm.create_model``.

        Raises:
            RuntimeError: if timm could not be imported.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")
        self.image_size = to_2tuple(image_size)

        # setup kwargs that may not be common across all models
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        custom_pool = pool in ('abs_attn', 'rot_attn')
        if not proj and not custom_pool:
            # use network classifier head as projection if no proj specified and no custom pooling used
            self.trunk = timm.create_model(
                model_name,
                num_classes=embed_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            # a 'pool_size' entry in the default cfg is taken as the sign of a 2D feature map
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                assert feature_ndim == 2
                # if attn pooling used, remove both classifier and default pool
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # reset global pool if pool config set, otherwise leave as network default
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # Add custom pooling to head
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
        else:
            assert not proj, f'Unknown projection type {proj}.'

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules

        Only parameters of ``self.trunk`` are frozen; ``self.head`` is not touched.

        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): additionally pass the locked part of the trunk
                through ``freeze_batch_norm_2d`` (default: False)

        Raises:
            RuntimeError: if a partial freeze is requested and the timm grouping
                helpers cannot be imported.
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError as err:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`'
                ) from err
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            # every group with an index up to (and including) this one gets frozen
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """ Toggle gradient checkpointing on the trunk (best effort).

        Any failure raised by the trunk is logged as a warning and otherwise ignored.
        """
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        """ Run the trunk and then the head on ``x``. """
        x = self.trunk(x)
        x = self.head(x)
        return x

    @staticmethod
    def _denormalize_boxes(normed_boxes, x):
        """ Scale normalized boxes to the spatial size of ``x``.

        Args:
            normed_boxes: iterable of 2D box tensors; columns 0 and 2 are multiplied by
                the width of ``x``, columns 1 and 3 by its height.
            x: tensor whose last two dims are (height, width).

        Returns:
            list of new tensors, one per entry of ``normed_boxes``.
        """
        h, w = x.shape[-2:]
        denormed_boxes = []
        for boxes in normed_boxes:
            # work on a copy so the caller's normalized boxes are not modified in place
            new_boxes = boxes.clone()
            new_boxes[:, [0, 2]] *= w
            new_boxes[:, [1, 3]] *= h
            denormed_boxes.append(new_boxes)
        return denormed_boxes

    def _roi_align_trunk_features(self, x, normed_boxes):
        """ Compute the trunk feature map of ``x`` and RoI-align it for every box.

        Returns:
            tuple ``(feature_map, roi_feature_maps)``.
        """
        img_h, img_w = x.shape[-2:]
        feats = self.trunk.forward_features(x)
        feat_h, feat_w = feats.shape[-2:]
        # RoI output size: self.image_size scaled by the trunk's feature-to-input resolution ratio
        out_size = (
            (self.image_size[0] * feat_h) // img_h,
            (self.image_size[1] * feat_w) // img_w,
        )
        # boxes are denormalized to feature-map coordinates, hence spatial_scale=1.0
        rois = roi_align(
            feats,
            self._denormalize_boxes(normed_boxes, feats),
            out_size,
            spatial_scale=1.0,
            sampling_ratio=-1,
            aligned=True,
        )
        return feats, rois

    def _extract_roi_features_v1(self, x, normed_boxes, **kwargs):
        """ RoI features: crop the trunk feature map per box, then apply trunk head and ``self.head``.

        Extra keyword arguments are accepted and ignored.
        """
        _, x = self._roi_align_trunk_features(x, normed_boxes)
        x = self.trunk.forward_head(x)
        x = self.head(x)
        return x

    def encode_dense(self, x, **kwargs):
        """ Per-location embeddings: trunk features -> trunk head norm/drop -> ``self.head``.

        The head runs on the channels-last layout produced by ``dense_trunk_head``; the
        result is permuted back with ``(0, 3, 1, 2)`` before being returned.
        Extra keyword arguments are accepted and ignored.
        """
        x = self.trunk.forward_features(x)
        x = self.dense_trunk_head(x)
        x = self.head(x)
        x = x.permute(0, 3, 1, 2)
        return x

    def dense_trunk_head(self, x):
        """ Apply ``trunk.head.norm`` and ``trunk.head.drop`` without any spatial pooling.

        NOTE(review): requires a trunk whose ``head`` exposes ``norm`` and ``drop``
        (presumably ConvNeXt-style timm models) - confirm against the configs in use.
        """
        x = self.trunk.head.norm(x)
        x = x.permute(0, 2, 3, 1)
        x = self.trunk.head.drop(x)
        # stay channels-last here; encode_dense permutes back after self.head
        return x

    def mask_pool(self, image, masks):
        """ Average the normalized dense features of ``image`` under each mask.

        Args:
            image: input batch passed to ``encode_dense``.
            masks: sequence with one mask tensor per image; they are concatenated along
                dim 0 and their last two dims are flattened, so those dims must match the
                spatial size of the dense feature map for the product to broadcast.

        Returns:
            tensor with one pooled feature vector per mask.
        """
        feature_map = self.encode_dense(image)
        feature_map = F.normalize(feature_map, dim=1)  # remember to normalize!
        feature_map = feature_map.flatten(-2, -1)  # bs, c, h*w

        num_masks_per_image = [len(masks_per_image) for masks_per_image in masks]
        masks = torch.cat(masks).float().flatten(-2, -1)  # num_masks, h*w
        # repeat each image's feature map once per mask that belongs to it
        feature_map = torch.repeat_interleave(
            feature_map, torch.tensor(num_masks_per_image, device=feature_map.device), dim=0)
        # masked mean; the epsilon guards against empty masks
        features = (feature_map * masks[:, None]).sum(-1) / (masks.sum(1, keepdim=True) + 1e-12)
        return features

    def extract_roi_features(self, x, normed_boxes, extract_type='v1'):
        """ Extract one feature vector per box.

        Only ``extract_type='v1'`` is supported by this tower. The check is an explicit
        raise rather than an ``assert`` so that it still holds under ``python -O``.

        Raises:
            ValueError: if ``extract_type`` is anything other than ``'v1'``.
        """
        if extract_type != 'v1':
            raise ValueError(
                f"Unsupported extract_type {extract_type!r}: only 'v1' is supported by TimmModel.")
        return self._extract_roi_features_v1(x, normed_boxes)

    def _extract_roi_features_v2(self, x, normed_boxes, **kwargs):
        """ RoI features: RoI-align the normalized dense embedding map down to 1x1 per box.

        Not reachable through ``extract_roi_features``; extra keyword arguments are
        accepted and ignored.
        """
        x = self.encode_dense(x)
        x = F.normalize(x, dim=1)  # remember to normalize!
        roi_feats = roi_align(
            x,
            self._denormalize_boxes(normed_boxes, x),
            (1, 1),
            spatial_scale=1.0,
            sampling_ratio=-1,
            aligned=True,
        )[..., 0, 0]
        return roi_feats

    def encode_rois_and_image(self, x, normed_boxes, **kwargs):
        """ Embed both the boxes and the whole image from a single trunk pass.

        Returns:
            tuple ``(x_rois, x_image)``, both normalized along the last dim.
            Extra keyword arguments are accepted and ignored.
        """
        x_image, x_rois = self._roi_align_trunk_features(x, normed_boxes)

        x_rois = self.trunk.forward_head(x_rois)
        x_rois = self.head(x_rois)
        x_rois = F.normalize(x_rois, dim=-1)

        x_image = self.trunk.forward_head(x_image)
        x_image = self.head(x_image)
        x_image = F.normalize(x_image, dim=-1)

        return x_rois, x_image