import math
import logging
import warnings
from functools import partial
from math import pi

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import fvcore.nn.weight_init as weight_init
from einops import rearrange, repeat
from scipy import interpolate
from PIL import Image
from transformers import CLIPImageProcessor

# from ..utils.attention import FlashAttention, FlashMHA
# try:
#     import xformers.ops as xops
# except:
#     pass

logger = logging.getLogger(__name__)

BatchNorm2d = torch.nn.BatchNorm2d


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        #    later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            with warnings.catch_warnings(record=True):
                if x.numel() == 0 and self.training:
                    # https://github.com/pytorch/pytorch/issues/12013
                    assert not isinstance(
                        self.norm, torch.nn.SyncBatchNorm
                    ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and remove padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
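

# A minimal sanity-check sketch (not part of the original file): it only assumes the
# `window_partition` / `window_unpartition` helpers above and verifies that a feature map
# whose sides are not multiples of the window size is padded, split into windows, and
# then restored exactly.
def _window_roundtrip_sketch(window_size=7):
    x = torch.randn(2, 10, 13, 32)                               # B, H, W, C; H and W not divisible by 7
    windows, pad_hw = window_partition(x, window_size)           # (B * num_windows, 7, 7, 32), (Hp, Wp)
    y = window_unpartition(windows, window_size, pad_hw, (10, 13))
    assert torch.equal(x, y)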


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Interpolate rel pos.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # q = 1.13492
            q = 1.0903078
            dis = []
            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)

            all_rel_pos_bias = []
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        original_datatype = abs_pos.dtype
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(),  # bf16 is not implemented
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        ).to(original_datatype)
        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


class VisionRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        # print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], \
            f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)


class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        # print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
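

# A small usage sketch (not part of the original file) of the fast rotary embedding: for a
# 16x16 token grid and per-head dimension 64, the `dim` passed to the module is half of the
# head dimension (32), matching how `EVAViT` below constructs `rope_win` / `rope_glb`.
def _rope_usage_sketch():
    rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)      # freqs_cos/sin: (16 * 16, 64)
    q = torch.randn(1, 8, 16 * 16, 64)                           # (batch, heads, tokens, head_dim)
    q_rot = rope(q)                                              # rotation keeps the shape unchanged
    assert q_rot.shape == q.shape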


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silence the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and converts all BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": lambda channels: LayerNorm(channels),
        }[norm]
    return norm(out_channels)


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output


class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        norm_layer=nn.LayerNorm,
        rope=None,
        xattn=True,
        subln=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()

        if self.xattn:
            # NOTE: this path requires the `FlashAttention` import that is commented out
            # at the top of this file; `build_eva_vit` below constructs the model with xattn=False.
            factory_kwargs = {'device': 'cuda', 'dtype': torch.float16}
            self.inner_attn = FlashAttention(attention_dropout=0.0, **factory_kwargs)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        ## rope
        q = self.rope(q).type_as(v)
        k = self.rope(k).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            kv = torch.stack([k, v], dim=2)
            x, attn_weights = self.inner_attn(q, kv, key_padding_mask=None, causal=False)
            # x = xops.memory_efficient_attention(q, k, v)
            x = x.reshape(B, N, -1)
            x = self.inner_attn_ln(x)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn.softmax(dim=-1).type_as(x)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
            x = self.inner_attn_ln(x)

        x = self.proj(x)
        x = x.view(B, H, W, C)
        return x
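

# A shape-check sketch (not part of the original file) for the plain (non-flash) attention
# path; the sizes below are illustrative only and are not tied to any EVA-02 configuration.
def _attention_shape_sketch():
    dim, num_heads, side = 64, 4, 14
    rope = VisionRotaryEmbeddingFast(dim=(dim // num_heads) // 2, pt_seq_len=16, ft_seq_len=side)
    attn = Attention(dim, num_heads=num_heads, rope=rope, xattn=False)
    x = torch.randn(2, side, side, dim)                          # B, H, W, C
    assert attn(x).shape == (2, side, side, dim)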


class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer block with support for window attention and residual propagation blocks."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
        subln=False,
        # with_cp=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If it equals 0,
                window attention is not used.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            rope (nn.Module): Rotary position embedding applied to queries and keys.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
            subln=subln,
        )
        # self.with_cp = with_cp
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

    def _forward(self, x):
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x

    def forward(self, x, with_cp=False):
        # if self.with_cp and self.training:
        if with_cp:
            x = cp.checkpoint(self._forward, x)
        else:
            x = self._forward(x)
        return x


# @BACKBONES.register_module()
class EVAViT(nn.Module):
    """
    This module implements the Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        # sim_fpn=None,
        rope=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        global_window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        subln=False,
        xattn=True,
        # with_cp=True,
        frozen=False,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.frozen = frozen
        self.gradient_checkpointing = False

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else global_window_size,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn,
                subln=subln,
                # with_cp=with_cp,
            )
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        # if self.pos_embed is not None:
        #     nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.pos_embed is not None:
            nn.init.normal_(self.pos_embed, std=0.02)

        # MIN SHI: I disable the weight initialization since the weights will be loaded automatically.
        # **However, they will cause problems (deepspeed + bf16)**
        # self.apply(self._init_weights)

        self._freeze_stages()

    # def _init_weights(self, m):
    #     if isinstance(m, nn.Linear):
    #         nn.init.trunc_normal_(m.weight, std=0.02)
    #         if isinstance(m, nn.Linear) and m.bias is not None:
    #             nn.init.constant_(m.bias, 0)
    #     elif isinstance(m, nn.LayerNorm):
    #         nn.init.constant_(m.bias, 0)
    #         nn.init.constant_(m.weight, 1.0)

    def _freeze_stages(self):
        if self.frozen:
            self.eval()
            for m in self.parameters():
                m.requires_grad = False

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x, with_cp=self.gradient_checkpointing)  # b, h, w, c

        x = x.permute(0, 3, 1, 2)  # b, c, h, w
        # if self.adapter is not None:
        #     outputs = self.adapter(x)
        # else:
        #     outputs = [x, ]
        # return outputs
        return x


'''
EVA ViT vision encoder for LLaVA
'''
class EVAVITVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer  # NOTE: not implemented yet, this parameter has no effect
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.args = args

        self.vision_tower, vision_tower_config = build_eva_vit(
            args=args,
            model_name=vision_tower,
            image_size=args.input_image_size,
        )
        self.input_image_size = args.input_image_size
        self.vision_tower.config = vision_tower_config
        self.freeze_vision = args.freeze_vision

        if not self.is_loaded:
            self.load_model()
        # if not delay_load:
        #     self.load_model()
        # else:
        #     self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        if self.is_loaded:
            return

        # self.args.vision_tower_input_size = 224  # hardcode
        self.image_processor = CLIPImageProcessor(
            crop_size={"height": self.args.input_image_size, "width": self.args.input_image_size},
            size={'shortest_edge': self.args.input_image_size},
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
        )

        # load weights
        if self.args.vision_tower_pretrained_from is None:
            self.args.vision_tower_pretrained_from = "/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth"

        # pretrained_params = torch.load(self.args.vision_tower_pretrained_from)
        # if 'ema_state' in pretrained_params:
        #     pretrained_params = pretrained_params['ema_state']
        # elif 'module' in pretrained_params:
        #     pretrained_params = pretrained_params['module']
        # from collections import OrderedDict
        # new_params = OrderedDict()
        # kw = ""
        # if "det" in self.args.vision_tower_pretrained_from.lower():
        #     kw = "backbone.net."
        # elif "clip" in self.args.vision_tower_pretrained_from.lower():
        #     kw = "visual."
        # for k, v in pretrained_params.items():
        #     if len(kw) > 0:
        #         if kw in k and ("rope" not in k):
        #             new_params[k.replace(kw, "")] = v
        #     else:
        #         if "rope" not in k:
        #             new_params[k] = v
        # incompatiblekeys = self.vision_tower.load_state_dict(new_params, strict=False)
        # for k in incompatiblekeys[0]:
        #     if "rope" not in k:
        #         warnings.warn(f"Find incompatible keys {k} in state dict.")
        # print(f"EVA-02 ckpt loaded from {self.args.vision_tower_pretrained_from}")

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    # @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = image_forward_out.flatten(2, 3).transpose(1, 2)  # b, n, c
                image_features.append(image_feature)
            return image_features
        else:
            image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            return image_forward_out

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        # if self.is_loaded:
        #     return self.vision_tower.config
        # else:
        #     return self.cfg_only
        # TODO
        return self.vision_tower.config

    @property
    def hidden_size(self):
        # return self.config.hidden_size
        return self.config['hidden_dim']

    @property
    def num_patches(self):
        # return (self.config.image_size // self.config.patch_size) ** 2
        return self.config['num_patches']


def build_eva_vit(args,
                  model_name=None,
                  image_size=224,
                  window_attn=True):
    if "336" in args.vision_tower_pretrained_from:
        pretrained_image_size = 336
    else:
        pretrained_image_size = 224

    if "clip" in args.vision_tower_pretrained_from.lower():
        subln = True
    else:
        subln = False

    if model_name == 'eva02-l-16':
        # Shilong suggested using this checkpoint:
        # https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
        if window_attn:
            window_block_indexes = (
                list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11))
                + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23))
            )
        else:
            window_block_indexes = ()
        model = EVAViT(
            img_size=image_size,
            patch_size=16,
            window_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4*2/3,
            window_block_indexes=window_block_indexes,
            qkv_bias=True,
            drop_path_rate=0.0,
            xattn=False,
            # with_cp=False,
            # frozen=True,
        )
        # image_size = 224  # HARDCODE
        eva_config = dict(image_size=image_size,
                          patch_size=16,
                          window_size=16,
                          hidden_dim=1024,
                          depth=24,
                          num_heads=16,
                          window_block_indexes=window_block_indexes,
                          num_patches=image_size ** 2 // 16 ** 2,
                          pretrained_from=args.vision_tower_pretrained_from)
    elif model_name == 'eva02-l-14':
        # Shilong suggested using this checkpoint:
        # https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
        if window_attn:
            window_block_indexes = (
                list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11))
                + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23))
            )
        else:
            window_block_indexes = ()
        model = EVAViT(
            img_size=image_size,
            pretrain_img_size=pretrained_image_size,
            patch_size=14,
            window_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4*2/3,
            window_block_indexes=window_block_indexes,
            qkv_bias=True,
            drop_path_rate=0.0,
            xattn=False,
            # with_cp=False,
            subln=subln,
            # frozen=True,
        )
        # image_size = 224  # HARDCODE
        eva_config = dict(image_size=image_size,
                          patch_size=14,
                          window_size=16,
                          hidden_dim=1024,
                          depth=24,
                          num_heads=16,
                          window_block_indexes=window_block_indexes,
                          num_patches=image_size ** 2 // 14 ** 2,
                          pretrained_from=args.vision_tower_pretrained_from)
    else:
        raise NotImplementedError

    return model, eva_config
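

# A hypothetical end-to-end sketch (not part of the original file). The SimpleNamespace below
# only mocks the single `args` field that `build_eva_vit` reads, and the checkpoint filename is
# a placeholder; in LLaVA these values come from the training arguments.
def _build_eva_vit_sketch():
    from types import SimpleNamespace
    args = SimpleNamespace(vision_tower_pretrained_from="eva02_L_coco_det_sys_o365.pth")
    model, config = build_eva_vit(args, model_name='eva02-l-14', image_size=224)
    feats = model(torch.randn(1, 3, 224, 224))                   # (1, 1024, 16, 16): b, c, h, w
    return feats.shape, config['num_patches']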