| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| import os |
| import io |
| import sys |
| import math |
| import random |
| import collections |
| import collections.abc |
| import re |
| from itertools import repeat |
| from pathlib import Path |
| from typing import Optional, Tuple, Union, List, Dict |
|
|
| import csv |
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| from tqdm import trange, tqdm |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn.init import _calculate_fan_in_and_fan_out |
| import torch.utils.checkpoint as checkpoint |
|
|
| import torchvision |
| from torchvision.transforms import v2 |
| from torch.utils.tensorboard import SummaryWriter |
| |
|
|
| |
# Select the compute device once at import time: the first CUDA GPU when
# torch reports one as available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
| import torchaudio |
| import torchaudio.transforms as T |
| from torchlibrosa.stft import Spectrogram, LogmelFilterBank |
| from torchlibrosa.augmentation import SpecAugmentation |
|
|
| from transformers import AutoModel, AutoTokenizer, logging |
| from huggingface_hub.file_download import hf_hub_download |
| from huggingface_hub.file_download import hf_hub_download |
| from peft import get_peft_config, get_peft_model |
| from transformers import CLIPVisionModel, AutoProcessor |
|
|
| from watermark import watermark |
# Print the report string returned by `watermark(...)` for reproducibility.
# NOTE(review): the meaning of each flag is defined by the `watermark`
# package; the names suggest date/time, Python/conda, host/machine,
# imported-package versions (taken from `globals_`) and GPU details - confirm
# against the package documentation.
print(watermark(
    author='Ashish',

    current_date=True,
    datename=True,
    current_time=True,
    iso8601=True,
    timezone=True,
    updated=True,
    custom_time=None,
    python=True,

    conda=True,
    hostname=True,
    machine=True,
    watermark=False,
    iversions=True,
    gpu=True,
    globals_=globals()
))
|
|
|
|
| |
| |
| |
class HTSATConfig:
    """Static configuration namespace for the HTS-AT model and experiments.

    All settings are plain class attributes, so they can be read directly
    from the class (``HTSATConfig.mel_bins``) or from an instance.
    """

    # ---- experiment identity and filesystem locations ----
    exp_name = "exp_htsat_pretrain"
    workspace = "/home/kechen/Research/HTSAT"
    dataset_path = "/home/Research/audioset"
    desed_folder = "/home/Research/DESED"

    # ---- dataset selection ----
    dataset_type = "audioset"
    index_type = "full_train"
    balanced_data = True

    # The model returns raw clip-wise outputs when this is "clip_ce" and
    # sigmoid-activated outputs for any other value.
    loss_type = "clip_bce"

    # ---- checkpoint / fold selection ----
    resume_checkpoint = None
    esc_fold = 0

    debug = False

    # ---- optimisation ----
    random_seed = 970131
    batch_size = 32 * 4
    learning_rate = 1e-3
    max_epoch = 100
    num_workers = 3

    # NOTE(review): presumably epoch boundaries and the matching LR factors
    # of a step schedule - confirm against the training script.
    lr_scheduler_epoch = [10, 20, 30]
    lr_rate = [0.02, 0.05, 0.1]

    # ---- data handling switches ----
    enable_token_label = False
    class_map_path = "class_hier_map.npy"
    class_filter = None
    retrieval_index = [
        15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238,
        5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, 9840, 11318,
        8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418,
        2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260,
        4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336,
        8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710,
        10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535,
        7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274,
        4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229,
        2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900,
    ]
    token_label_range = [0.2, 0.6]
    enable_time_shift = False
    enable_label_enhance = False
    enable_repeat_mode = False

    # When True the model builds its ``tscam_conv`` head instead of using a
    # plain linear classifier on the pooled features.
    enable_tscam = True

    # ---- audio front-end (fed to the spectrogram / log-mel extractors) ----
    sample_rate = 32000
    clip_samples = sample_rate * 10
    window_size = 1024
    hop_size = 320
    mel_bins = 64
    fmin = 50
    fmax = 14000
    shift_max = int(clip_samples * 0.5)

    # ---- task shape ----
    classes_num = 527
    patch_size = (25, 4)
    crop_size = None

    # ---- HTS-AT / Swin hyper-parameters ----
    htsat_window_size = 8
    htsat_spec_size = 256
    htsat_patch_size = 4
    htsat_stride = (4, 4)
    htsat_num_head = [4, 8, 16, 32]
    htsat_dim = 96
    htsat_depth = [2, 2, 6, 2]

    swin_pretrain_path = None

    # ---- output variants ----
    htsat_attn_heatmap = False
    htsat_hier_output = False
    htsat_use_max = False

    # ---- ensembling ----
    ensemble_checkpoints = []
    ensemble_strides = []

    # ---- weight averaging ----
    wa_folder = "/home/version_0/checkpoints/"
    wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"

    esm_model_pathes = [
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt",
    ]

    # ---- localisation / heatmap evaluation ----
    heatmap_dir = "/home/Research/heatmap_output"
    test_file = "htsat-test-ensemble"
    fl_local = False
    fl_dataset = "/home/Research/desed/desedim_embval.npy"
    fl_class_num = [
        "Speech", "Frying", "Dishes", "Running_water",
        "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
        "Cat", "Dog", "Vacuum_cleaner",
    ]

    # One list of indices per entry of ``fl_class_num`` (same order).
    # NOTE(review): presumably AudioSet class indices - confirm.
    fl_audioset_mapping = [
        list(range(0, 8)),
        list(range(366, 369)),
        [364],
        list(range(288, 298)),
        [369],
        [382],
        [310] + list(range(388, 403)),
        list(range(81, 86)),
        list(range(74, 80)),
        [377],
    ]
|
|
|
|
|
|
def _ntuple(n):
    """Build a converter that expands a single value into an ``n``-tuple.

    Args:
        n: length of the tuple produced for non-iterable inputs.
    Returns:
        A function ``parse(x)`` that returns ``x`` unchanged when it is
        already an iterable (tuple, list, ...), and ``(x,) * n`` otherwise.
    """
    def parse(x):
        # str/bytes are iterable but represent one value here; without this
        # exclusion to_2tuple("ab") would come back as "ab", not a 2-tuple.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, (str, bytes)):
            return x
        return tuple(repeat(x, n))
    return parse


# Ready-made converters for the common arities.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
|
def do_mixup(x, mixup_lambda):
    """Blend each even-indexed sample with the odd-indexed sample after it.

    Args:
        x: (batch_size * 2, ...)
        mixup_lambda: (batch_size * 2,) per-sample mixing weights
    Returns:
        out: (batch_size, ...) where
            out[i] = x[2i] * mixup_lambda[2i] + x[2i + 1] * mixup_lambda[2i + 1]
    """
    # Moving the batch axis last lets the 1-D weights broadcast over it.
    even_part = x[0::2].transpose(0, -1) * mixup_lambda[0::2]
    odd_part = x[1::2].transpose(0, -1) * mixup_lambda[1::2]
    mixed = even_part + odd_part
    return mixed.transpose(0, -1)
|
|
def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame ``ratio`` times.

    Used to compensate for the time-resolution reduction introduced by the
    network's downsampling.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, number of copies of each time step
    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    # Insert a length-1 axis after time, tile it, then fold it back into time.
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(n_batch, n_steps * ratio, n_classes)
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dim 0 is either zeroed entirely (with probability
    ``drop_prob``) or kept and rescaled by ``1 / (1 - drop_prob)``.

    Args:
        x: input tensor; dim 0 is treated as the sample axis.
        drop_prob: probability of dropping a sample, expected in [0, 1].
            ``None`` and ``0`` both disable dropping.
        training: dropping is only applied when True.
    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of
        the same shape.
    """
    # `not drop_prob` covers 0.0 as well as None; None previously crashed
    # on `1 - drop_prob` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: everything is dropped; skip the rescale, which would
    # otherwise divide by zero and produce NaN.
    return x * random_tensor
|
|
|
|
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    Intended for the main path of residual blocks; the drop is only active
    while the module is in training mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored as given and handed to `drop_path` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        rate = self.drop_prob
        is_training = self.training
        return drop_path(x, rate, is_training)
|
|
class PatchEmbed(nn.Module):
    """2D image to patch embedding using a strided convolution.

    Args:
        img_size (int | tuple[int]): expected input height/width.
        patch_size (int | tuple[int]): convolution kernel size.
        in_chans (int): number of input channels.
        embed_dim (int): number of output channels per patch.
        norm_layer (callable | None): called with ``embed_dim`` to build the
            normalisation; ``nn.Identity`` is used when falsy.
        flatten (bool): if True, return (B, num_patches, embed_dim) instead
            of the raw (B, embed_dim, H', W') feature map.
        patch_stride (int | tuple[int]): convolution stride.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.patch_stride = to_2tuple(patch_stride)

        grid_h = self.img_size[0] // self.patch_stride[0]
        grid_w = self.img_size[1] // self.patch_stride[1]
        self.grid_size = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Pad by half of the kernel/stride difference on each axis.
        pad_h = (self.patch_size[0] - self.patch_stride[0]) // 2
        pad_w = (self.patch_size[1] - self.patch_stride[1]) // 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
            padding=(pad_h, pad_w),
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W); H and W must equal ``img_size``."""
        _, _, height, width = x.shape
        want_h, want_w = self.img_size[0], self.img_size[1]
        assert height == want_h and width == want_w, \
            f"Input image size ({height}*{width}) doesn't match model ({want_h}*{want_w})."
        patches = self.proj(x)
        if self.flatten:
            # (B, E, H', W') -> (B, H'*W', E)
            patches = patches.flatten(2).transpose(1, 2)
        return self.norm(patches)
|
|
class Mlp(nn.Module):
    """Two-layer MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Computes ``drop(fc2(drop(act(fc1(x)))))``; the same dropout module is
    applied after the activation and after the second linear layer.

    Args:
        in_features (int): input width.
        hidden_features (int | None): hidden width; falls back to ``in_features``.
        out_features (int | None): output width; falls back to ``in_features``.
        act_layer (callable): zero-argument factory for the activation module.
        drop (float): dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples, without grad.

    Samples are produced by inverse-CDF sampling: draw uniformly between the
    CDF values of the two cut-offs, map back through ``erfinv``, then scale,
    shift and clamp to ``[a, b]``. Returns ``tensor``.
    """
    sqrt_two = math.sqrt(2.)

    def std_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / sqrt_two)) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the lower and upper cut-offs.
        cdf_low = std_normal_cdf((a - mean) / std)
        cdf_high = std_normal_cdf((b - mean) / std)

        # Uniform in [2*cdf_low - 1, 2*cdf_high - 1], then the inverse error
        # function turns it into a truncated standard normal.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()

        # Scale/shift to the requested std and mean.
        tensor.mul_(std * sqrt_two).add_(mean)

        # Clamp so rounding cannot leave values outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    The values are effectively drawn from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Returns:
        The same ``tensor`` object, now filled.
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialise ``tensor`` in place with variance ``scale / fan``.

    Args:
        tensor: tensor (or parameter) to fill; its fan-in/fan-out are taken
            from ``torch.nn.init._calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'`` - which fan is
            used as the denominator.
        distribution: ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.
    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values
            listed above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and failed later with an
        # unbound `denom`; fail early with a clear message instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # The constant rescales std to compensate for the truncation
        # (trunc_normal_ defaults to cut-offs a=-2, b=2).
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        # In-place ops on a leaf that requires grad are rejected by autograd,
        # so initialise under no_grad (trunc_normal_ already does the same).
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # Uniform on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
|
|
|
|
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place via ``variance_scaling_``.

    Uses the fan-in mode with the truncated-normal distribution and the
    default scale.
    """
    variance_scaling_(tensor, distribution='truncated_normal', mode='fan_in')
|
|
|
|
| |
| |
|
|
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # Split H and W into (grid index, offset inside the window).
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two grid axes together so each window becomes contiguous.
    grouped = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, window_size, window_size, channels)
|
|
|
|
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a full feature map (inverse of window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Number of windows per image, then the batch size they imply.
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows = H // window_size
    cols = W // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave the grid axes with the in-window offsets again.
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(batch, H, W, -1)
|
|
|
|
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One learnable bias per head for every possible relative offset:
        # ((2*Wh-1) * (2*Ww-1), num_heads).
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Pair-wise relative position index for each token inside the window.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # "ij" (row/column) indexing is what the arithmetic below relies on;
        # passing it explicitly avoids the deprecated implicit default.
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major flattening of (dh, dw)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            (x, attn): the projected output of shape (num_windows*B, N, C)
            and the attention weights of shape (num_windows*B, num_heads, N, N)
            (after attention dropout).
        """
        B_, N, C = x.shape
        # (3, B_, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Look up the per-head bias for every (query, key) pair in the window.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # heads, N, N
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Apply the same per-window mask to every image in the batch.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

    def extra_repr(self):
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
|
|
|
| |
class _TokenBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm1d over the channel axis of channel-last tokens (B, L, C)."""

    def forward(self, x):
        # BatchNorm1d expects (B, C, L); swap in and back out.
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer used before attention. Default: nn.LayerNorm
        norm_before_mlp (str, optional): 'ln' (LayerNorm) or 'bn' (BatchNorm1d)
            for the normalisation applied before the MLP. Default: 'ln'
    Raises:
        NotImplementedError: if ``norm_before_mlp`` is neither 'ln' nor 'bn'.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # If the window covers the whole input, do not partition or shift.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.norm_before_mlp == 'ln':
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == 'bn':
            # Register ONE BatchNorm module. The previous lambda built a new
            # nn.BatchNorm1d on every call, so its affine parameters were
            # never trained or saved, it never followed .to(device), and no
            # running statistics were kept for eval mode.
            self.norm2 = _TokenBatchNorm1d(dim)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # Build the attention mask for shifted windows: label each of the
            # 3x3 regions created by the cyclic shift, then forbid attention
            # between tokens whose labels differ.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, ws, ws, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # -100 is added to the logits of disallowed pairs before softmax.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Run the block on ``x`` of shape (B, H*W, C).

        Returns:
            (x, attn): the block output with the input's shape, and the
            attention weights returned by the window attention module.
        """
        H, W = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift (SW-MSA) when configured.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition into windows: (nW*B, ws*ws, C).
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA / SW-MSA.
        attn_windows, attn = self.attn(x_windows, mask=self.attn_mask)

        # Merge windows back into (B, H, W, C).
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # Residual connections around attention and around the MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
|
|
|
|
|
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates every 2x2 neighbourhood of tokens along the channel axis
    (4*C), normalises, and linearly projects to 2*C, halving H and W.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        merged_dim = 4 * dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        returns: B, (H/2)*(W/2), 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # (row offset, column offset) of the four tokens of each 2x2 cell,
        # in the order they are concatenated along the channel axis.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        corners = [grid[:, dh::2, dw::2, :] for dh, dw in offsets]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)

        return self.reduction(self.norm(merged))

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
|
|
|
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate,
            either one value for all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        norm_before_mlp (str): Forwarded to every block ('ln' or 'bn'). Default: 'ln'
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 norm_before_mlp='ln'):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Build the blocks; odd-indexed blocks use shifted windows.
        # A tuple of per-block rates is indexed just like a list (the
        # docstring allows tuples, but only lists were handled before).
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                                 norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
            for i in range(depth)])

        # Optional patch-merging (downsample) layer applied after the blocks.
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks, then the optional downsample.

        Returns:
            (x, attn): the stage output and an attention tensor. In eval
            mode ``attn`` is the mean over all blocks' attention weights; in
            training mode it is the last block's attention weights.
        """
        attns = []
        for blk in self.blocks:
            if self.use_checkpoint:
                # Each block returns (x, attn); the checkpointed call must be
                # unpacked the same way as the direct call.
                x, attn = checkpoint.checkpoint(blk, x)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training:
            attn = torch.cat(attns, dim = 0)
            attn = torch.mean(attn, dim = 0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
|
|
|
| |
| class HTSAT_Swin_Transformer(nn.Module): |
| r"""HTSAT based on the Swin Transformer |
| Args: |
| spec_size (int | tuple(int)): Input Spectrogram size. Default 256 |
| patch_size (int | tuple(int)): Patch size. Default: 4 |
| path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 |
| in_chans (int): Number of input image channels. Default: 1 (mono) |
| num_classes (int): Number of classes for classification head. Default: 527 |
| embed_dim (int): Patch embedding dimension. Default: 96 |
| depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. |
| num_heads (tuple(int)): Number of attention heads in different layers. |
| window_size (int): Window size. Default: 8 |
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 |
| qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True |
| qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None |
| drop_rate (float): Dropout rate. Default: 0 |
| attn_drop_rate (float): Attention dropout rate. Default: 0 |
| drop_path_rate (float): Stochastic depth rate. Default: 0.1 |
| norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. |
| ape (bool): If True, add absolute position embedding to the patch embedding. Default: False |
| patch_norm (bool): If True, add normalization after patch embedding. Default: True |
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False |
| config (module): The configuration Module from config.py (HTSATConfig Class) |
| """ |
|
|
| def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), |
| in_chans=1, num_classes=527, |
| embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], |
| window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, |
| drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, |
| norm_layer=nn.LayerNorm, |
| ape=False, patch_norm=True, |
| use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs): |
| super(HTSAT_Swin_Transformer, self).__init__() |
|
|
| self.config = config |
| self.spec_size = spec_size |
| self.patch_stride = patch_stride |
| self.patch_size = patch_size |
| self.window_size = window_size |
| self.embed_dim = embed_dim |
| self.depths = depths |
| self.ape = ape |
| self.in_chans = in_chans |
| self.num_classes = num_classes |
| self.num_heads = num_heads |
| self.num_layers = len(self.depths) |
| self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) |
| |
| self.drop_rate = drop_rate |
| self.attn_drop_rate = attn_drop_rate |
| self.drop_path_rate = drop_path_rate |
|
|
| self.qkv_bias = qkv_bias |
| self.qk_scale = None |
|
|
| self.patch_norm = patch_norm |
| self.norm_layer = norm_layer if self.patch_norm else None |
| self.norm_before_mlp = norm_before_mlp |
| self.mlp_ratio = mlp_ratio |
|
|
| self.use_checkpoint = use_checkpoint |
|
|
| |
| self.freq_ratio = self.spec_size // self.config.mel_bins |
| window = 'hann' |
| center = True |
| pad_mode = 'reflect' |
| ref = 1.0 |
| amin = 1e-10 |
| top_db = None |
| self.interpolate_ratio = 32 |
| |
| self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, |
| win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, |
| freeze_parameters=True) |
| |
| self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, |
| n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, |
| freeze_parameters=True) |
| |
| self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, |
| freq_drop_width=8, freq_stripes_num=2) |
| self.bn0 = nn.BatchNorm2d(self.config.mel_bins) |
|
|
|
|
| |
| self.patch_embed = PatchEmbed( |
| img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, |
| embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride) |
|
|
| num_patches = self.patch_embed.num_patches |
| patches_resolution = self.patch_embed.grid_size |
| self.patches_resolution = patches_resolution |
|
|
| |
| if self.ape: |
| self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) |
| trunc_normal_(self.absolute_pos_embed, std=.02) |
|
|
| self.pos_drop = nn.Dropout(p=self.drop_rate) |
|
|
| |
| dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] |
|
|
| |
| self.layers = nn.ModuleList() |
| for i_layer in range(self.num_layers): |
| layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer), |
| input_resolution=(patches_resolution[0] // (2 ** i_layer), |
| patches_resolution[1] // (2 ** i_layer)), |
| depth=self.depths[i_layer], |
| num_heads=self.num_heads[i_layer], |
| window_size=self.window_size, |
| mlp_ratio=self.mlp_ratio, |
| qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, |
| drop=self.drop_rate, attn_drop=self.attn_drop_rate, |
| drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], |
| norm_layer=self.norm_layer, |
| downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, |
| use_checkpoint=use_checkpoint, |
| norm_before_mlp=self.norm_before_mlp) |
| self.layers.append(layer) |
|
|
| self.norm = self.norm_layer(self.num_features) |
| self.avgpool = nn.AdaptiveAvgPool1d(1) |
| self.maxpool = nn.AdaptiveMaxPool1d(1) |
|
|
| if self.config.enable_tscam: |
| SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio |
| self.tscam_conv = nn.Conv2d( |
| in_channels = self.num_features, |
| out_channels = self.num_classes, |
| kernel_size = (SF,3), |
| padding = (0,1) |
| ) |
| self.head = nn.Linear(num_classes, num_classes) |
| else: |
| self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
|
|
| self.apply(self._init_weights) |
|
|
| def _init_weights(self, m): |
| if isinstance(m, nn.Linear): |
| trunc_normal_(m.weight, std=.02) |
| if isinstance(m, nn.Linear) and m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.LayerNorm): |
| nn.init.constant_(m.bias, 0) |
| nn.init.constant_(m.weight, 1.0) |
|
|
| @torch.jit.ignore |
| def no_weight_decay(self): |
| return {'absolute_pos_embed'} |
|
|
| @torch.jit.ignore |
| def no_weight_decay_keywords(self): |
| return {'relative_position_bias_table'} |
|
|
def forward_features(self, x):
    """Run the transformer stages on an image-like input and build the output dict.

    The input is patch-embedded, optionally offset by the absolute position
    embedding, passed through every stage in ``self.layers`` and normalised.
    What happens next depends on ``self.config.enable_tscam``:

    * enabled: the token sequence is folded back into a 2-D map, regrouped by
      ``self.freq_ratio``, pooled into ``latent_output`` and sent through
      ``self.tscam_conv``.  The returned dict has ``framewise_output``,
      ``clipwise_output`` and ``latent_output``.
    * disabled: the tokens are pooled and passed through ``self.head``.  The
      returned dict has only ``framewise_output`` and ``clipwise_output``.

    Args:
        x: input whose dimension 2 gives ``frames_num``.

    Returns:
        dict of tensors as described above.
    """
    frames_num = x.shape[2]
    x = self.patch_embed(x)
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    # Only the attention returned by the last stage is still bound after the loop.
    for stage in self.layers:
        x, attn = stage(x)

    # Both branches start by normalising and reading the token layout.
    x = self.norm(x)
    batch, _, channels = x.shape
    num_stages = len(self.depths)
    time_stride = self.patch_stride[1]

    if not self.config.enable_tscam:
        side = frames_num // (2 ** (num_stages + 1))
        fpx = x.permute(0, 2, 1).contiguous().reshape(batch, channels, side, side)
        batch, channels, n_freq, n_time = fpx.shape
        bins_per_group = n_freq // self.freq_ratio
        fpx = fpx.reshape(batch, channels, n_freq // bins_per_group, bins_per_group, n_time)
        fpx = fpx.permute(0, 1, 3, 2, 4).contiguous().reshape(batch, channels, bins_per_group, -1)
        fpx = torch.sum(fpx, dim=2)
        fpx = interpolate(fpx.permute(0, 2, 1).contiguous(), 8 * time_stride)
        x = torch.flatten(self.avgpool(x.transpose(1, 2)), 1)
        if self.num_classes > 0:
            x = self.head(x)
            fpx = self.head(fpx)
        return {
            'framewise_output': torch.sigmoid(fpx),
            'clipwise_output': torch.sigmoid(x),
        }

    # ---- token-semantic (tscam) branch ----
    downsample = 2 ** (num_stages - 1)
    SF = frames_num // downsample // self.patch_stride[0]
    ST = frames_num // downsample // time_stride
    x = x.permute(0, 2, 1).contiguous().reshape(batch, channels, SF, ST)
    batch, channels, n_freq, n_time = x.shape

    # Regroup the first spatial axis into freq_ratio groups laid out along the last axis.
    bins_per_group = n_freq // self.freq_ratio
    x = x.reshape(batch, channels, n_freq // bins_per_group, bins_per_group, n_time)
    x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(batch, channels, bins_per_group, -1)

    # Pooled feature taken before the classification convolution.
    latent_output = torch.flatten(self.avgpool(torch.flatten(x, 2)), 1)

    use_heatmap = self.config.htsat_attn_heatmap
    if use_heatmap:
        # Average the last stage's attention twice over dim 1, then fold it
        # the same way as the features and rescale it per sample.
        attn = torch.mean(attn, dim=1)
        attn = torch.mean(attn, dim=1)
        attn = attn.reshape(batch, SF, ST)
        attn_bins = SF // self.freq_ratio
        attn = attn.reshape(batch, SF // attn_bins, attn_bins, ST)
        attn = attn.permute(0, 2, 1, 3).contiguous().reshape(batch, attn_bins, -1)
        attn = attn.mean(dim=1)
        attn_max = torch.max(attn, dim=1, keepdim=True)[0]
        attn_min = torch.min(attn, dim=1, keepdim=True)[0]
        attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
        attn = attn.unsqueeze(dim=2)

    x = self.tscam_conv(x)
    x = torch.flatten(x, 2)

    frame_scores = torch.sigmoid(x).permute(0, 2, 1).contiguous()
    if use_heatmap:
        frame_scores = frame_scores * attn
    fpx = interpolate(frame_scores, 8 * time_stride)

    x = torch.flatten(self.avgpool(x), 1)

    # With the "clip_ce" loss the clip-level scores are returned without a sigmoid.
    clipwise = x if self.config.loss_type == "clip_ce" else torch.sigmoid(x)
    return {
        'framewise_output': fpx,  # sigmoid already applied above
        'clipwise_output': clipwise,
        'latent_output': latent_output,
    }
|
|
def crop_wav(self, x, crop_size, spe_pos=None):
    """Cut a window of ``crop_size`` steps out of dimension 2 of every sample.

    Args:
        x: 4-D tensor; dimension 2 is the axis that is cropped.
        crop_size: length of the window to keep.
        spe_pos: fixed start index used for every sample.  When ``None`` a
            start index is drawn independently for each sample.

    Returns:
        A new zero-initialised tensor on ``x``'s device with dimension 2
        shortened to ``crop_size``.  Only channel 0 is copied from ``x``; any
        further channels stay zero.
    """
    time_steps = x.shape[2]
    # random.randint is inclusive at both ends, so the last valid start is
    # exactly time_steps - crop_size.  The former "- 1" skipped that start and
    # raised ValueError when time_steps == crop_size.
    last_start = time_steps - crop_size
    tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3], device=x.device)
    for i in range(x.shape[0]):
        crop_pos = random.randint(0, last_start) if spe_pos is None else spe_pos
        tx[i, 0] = x[i, 0, crop_pos:crop_pos + crop_size, :]
    return tx
|
|
| |
def reshape_wav2img(self, x):
    """Fold a 4-D input into a square-ish image for the transformer.

    Dimension 2 is stretched (bicubic) to ``spec_size * freq_ratio`` and
    dimension 3 to ``spec_size // freq_ratio`` when they are shorter.  The
    long axis is then cut into ``freq_ratio`` equal chunks that are stacked
    along the other axis, giving a tensor whose last two dimensions are
    ``(freq_ratio * target_F, target_T // freq_ratio)``.

    Raises:
        AssertionError: if either input dimension exceeds its target size.
    """
    _, _, n_frames, n_bins = x.shape
    target_T = int(self.spec_size * self.freq_ratio)
    target_F = self.spec_size // self.freq_ratio
    assert n_frames <= target_T and n_bins <= target_F, "the wav size should less than or equal to the swin input size"

    # Stretch each axis separately, dimension 2 first.
    if n_frames < target_T:
        x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
    if n_bins < target_F:
        x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)

    x = x.permute(0, 1, 3, 2).contiguous()
    batch, channels, freq, time = x.shape
    chunk_len = time // self.freq_ratio
    x = x.reshape(batch, channels, freq, self.freq_ratio, chunk_len)
    # Bring the chunk index in front of the frequency axis, then merge the two.
    x = x.permute(0, 1, 3, 2, 4).contiguous()
    return x.reshape(batch, channels, self.freq_ratio * freq, chunk_len)
| |
| |
def repeat_wat2img(self, x, cur_pos):
    """Build an image by tiling one ``spec_size``-long window of the input.

    Dimension 2 is stretched (bicubic) to ``spec_size * freq_ratio`` and
    dimension 3 to ``spec_size // freq_ratio`` when they are shorter.  After
    swapping the last two axes, the window ``[cur_pos, cur_pos + spec_size)``
    is cut from the long axis and the result is repeated ``freq_ratio`` times
    along the short axis.

    The repeat count used to be the literal 4, which only reproduces a
    ``spec_size``-tall image when ``freq_ratio == 4``; using ``freq_ratio``
    gives the same result in that case and keeps other ratios consistent.

    Args:
        x: 4-D tensor.
        cur_pos: start index of the window on the long axis.

    Raises:
        AssertionError: if either input dimension exceeds its target size.
    """
    _, _, n_frames, n_bins = x.shape
    target_T = int(self.spec_size * self.freq_ratio)
    target_F = self.spec_size // self.freq_ratio
    assert n_frames <= target_T and n_bins <= target_F, "the wav size should less than or equal to the swin input size"

    # Stretch each axis separately, dimension 2 first.
    if n_frames < target_T:
        x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
    if n_bins < target_F:
        x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)

    x = x.permute(0, 1, 3, 2).contiguous()
    x = x[:, :, :, cur_pos:cur_pos + self.spec_size]
    return x.repeat(repeats=(1, 1, self.freq_ratio, 1))
|
|
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):
    """Turn the input into a log-mel image and run ``forward_features`` on it.

    Pre-processing: spectrogram -> log-mel -> ``bn0`` (applied with
    dimensions 1 and 3 swapped) -> SpecAugment and optional mixup while
    training.  The image is then built in one of three ways:

    * ``infer_mode``: the input is repeated along dimension 2 until it
      (nearly) fills ``spec_size * freq_ratio`` and folded with
      ``reshape_wav2img``.
    * ``config.enable_repeat_mode``: windows are tiled with
      ``repeat_wat2img``; one random window while training, all
      ``spec_size``-spaced windows averaged otherwise.
    * default: inputs no longer than ``freq_ratio * spec_size`` are folded
      directly; longer ones are randomly cropped while training, or split
      into fixed overlapping crops whose outputs are averaged otherwise.

    Args:
        x: input passed to ``self.spectrogram_extractor``.
        mixup_lambda: mixup coefficients; only used while training.
        infer_mode: select the repeat-to-fill path described above.

    Returns:
        The dict produced by ``forward_features``; on the averaging paths,
        every key of that dict averaged over the crops.
    """

    def _mean_outputs(output_dicts):
        # Average each entry across crops, accumulating in float32.  All keys
        # of the first dict are kept, so 'latent_output' survives whenever
        # forward_features provides it (the repeat-mode path used to drop it,
        # and the long-audio path raised KeyError when it was absent).
        averaged = {}
        for key, first in output_dicts[0].items():
            total = torch.zeros_like(first).float().to(x.device)
            for single in output_dicts:
                total += single[key]
            averaged[key] = total / len(output_dicts)
        return averaged

    x = self.spectrogram_extractor(x)
    x = self.logmel_extractor(x)

    # bn0 is applied with dimensions 1 and 3 swapped.
    x = x.transpose(1, 3)
    x = self.bn0(x)
    x = x.transpose(1, 3)
    if self.training:
        x = self.spec_augmenter(x)
        if mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

    if infer_mode:
        # Repeat the clip along dimension 2 so it fills the input as far as possible.
        frame_num = x.shape[2]
        target_T = int(self.spec_size * self.freq_ratio)
        repeat_ratio = math.floor(target_T / frame_num)
        x = x.repeat(repeats=(1, 1, repeat_ratio, 1))
        return self.forward_features(self.reshape_wav2img(x))

    if self.config.enable_repeat_mode:
        last_start = (self.freq_ratio - 1) * self.spec_size
        if self.training:
            # randint is inclusive: last_start itself is a valid window start
            # (the former "- 1" excluded it and failed for freq_ratio == 1).
            cur_pos = random.randint(0, last_start)
            return self.forward_features(self.repeat_wat2img(x, cur_pos))
        output_dicts = [
            self.forward_features(self.repeat_wat2img(x.clone(), cur_pos))
            for cur_pos in range(0, last_start + 1, self.spec_size)
        ]
        return _mean_outputs(output_dicts)

    max_frames = self.freq_ratio * self.spec_size
    if x.shape[2] <= max_frames:
        return self.forward_features(self.reshape_wav2img(x))

    if self.training:
        x = self.crop_wav(x, crop_size=max_frames)
        return self.forward_features(self.reshape_wav2img(x))

    # Evaluation on long inputs: fixed-size crops at a fixed hop, averaged.
    overlap_size = 344
    crop_size = 689
    output_dicts = []
    for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
        tx = self.crop_wav(x, crop_size=crop_size, spe_pos=cur_pos)
        output_dicts.append(self.forward_features(self.reshape_wav2img(tx)))
    return _mean_outputs(output_dicts)
|
|
class HTSATWrapper(nn.Module):
    """Wrap ``HTSAT_Swin_Transformer`` and expose its latent output as ``'embedding'``."""

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, out_emb):
        super().__init__()
        # None of the constructor arguments is used: the backbone is always
        # built from a default-constructed HTSATConfig.
        self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())

    def forward(self, x):
        """Run the backbone and alias ``latent_output`` under ``'embedding'``."""
        outputs = self.htsat(x)
        outputs.update(embedding=outputs['latent_output'])
        return outputs
|
|
|
|
def get_audio_encoder(name: str):
    """Return the audio-encoder class registered under ``name``.

    Only ``"HTSAT"`` is supported.

    Raises:
        ValueError: for any other name.  ``ValueError`` is a subclass of
            ``Exception``, so callers that caught the previous bare
            ``Exception`` keep working.
    """
    if name == "HTSAT":
        return HTSATWrapper
    raise ValueError(f'The audio encoder name {name} is incorrect or not supported')
|
|
class Projection(nn.Module):
    """Two bias-free linear layers with a residual connection and LayerNorm.

    Args:
        dim_imgn: input feature size.
        d_out: output feature size.
        p: dropout probability applied to the second linear layer's output.
    """

    def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(linear1(x) + drop(linear2(gelu(linear1(x)))))``."""
        skip = self.linear1(x)
        refined = self.linear2(F.gelu(skip))
        refined = self.drop(refined)
        return self.layer_norm(skip + refined)
|
|
class AudioEncoder(nn.Module):
    """Audio branch: a backbone selected by name followed by a ``Projection`` head."""

    def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
                 hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()

        # Look up the backbone class, then instantiate it; dim_imgn is handed
        # over as the backbone's last positional argument.
        backbone_cls = get_audio_encoder(audioenc_name)
        self.base = backbone_cls(
            sample_rate, window_size,
            hop_size, mel_bins, fmin, fmax,
            classes_num, dim_imgn)

        self.projection = Projection(dim_imgn, d_out)

    def forward(self, x):
        """Return ``(projected embedding, clipwise output)`` for the input ``x``."""
        backbone_out = self.base(x)
        embedding = backbone_out['embedding']
        audio_classification_output = backbone_out['clipwise_output']
        return self.projection(embedding), audio_classification_output
|
|
class TextEncoder(nn.Module):
    """Text branch: a pretrained ``AutoModel`` followed by a ``Projection`` head.

    The way a sentence vector is read from the base model depends on
    substrings of the model name (``'clip'``, ``'gpt'``, or anything else).
    """

    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.text_model = text_model
        self.base = AutoModel.from_pretrained(text_model)

        if 'clip' in text_model:
            # Keep the text projection and replace the base by its text sub-model.
            self.clip_text_projection = self.base.text_projection
            self.base = self.base.text_model
            if 'base' in text_model:
                # CLIP "base" checkpoints override the configured width with 512.
                transformer_embed_dim = 512

        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        """Encode the tokenised batch ``x`` (a mapping unpacked into the base model)."""
        model_id = self.text_model
        if 'clip' in model_id:
            # Element 1 of the base output is fed through the CLIP text projection.
            out = self.clip_text_projection(self.base(**x)[1])
        elif 'gpt' in model_id:
            input_ids = x['input_ids']
            hidden_states = self.base(**x)[0]
            # Index of the last token whose id is non-zero, per sequence.
            last_token = torch.ne(input_ids, 0).sum(-1) - 1
            rows = torch.arange(input_ids.shape[0], device=hidden_states.device)
            out = hidden_states[rows, last_token]
        else:
            # Take position 0 of element 0 of the base output.
            out = self.base(**x)[0][:, 0, :]

        return self.projection(out)
|
|
class CLAP(nn.Module):
    """Pair an ``AudioEncoder`` and a ``TextEncoder`` that project to ``d_proj`` dimensions."""

    def __init__(self,
                 # audio
                 audioenc_name: str,
                 sample_rate: int,
                 window_size: int,
                 hop_size: int,
                 mel_bins: int,
                 fmin: int,
                 fmax: int,
                 classes_num: int,
                 out_emb: int,
                 # text
                 text_model: str,
                 transformer_embed_dim: int,
                 # common
                 d_proj: int,
                 ):
        super().__init__()

        self.audio_encoder = AudioEncoder(
            audioenc_name, out_emb, d_proj,
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)

        self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)

        # Learnable scalar stored in log space, initialised to log(1 / 0.07).
        initial_log_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_log_scale)

    def forward(self, audio, text):
        """Return ``(caption_embed, audio_embed, exp(logit_scale))``."""
        audio_embed, _unused_clipwise = self.audio_encoder(audio)
        caption_embed = self.caption_encoder(text)
        scale = self.logit_scale.exp()
        return caption_embed, audio_embed, scale
| |
| |
| |
| |
| |
| |
def read_audio(audio_path, resample=True):
    r"""Load an audio file with torchaudio and optionally resample it.

    Args:
        audio_path: path handed to ``torchaudio.load``.
        resample: when true and the file's rate differs from
            ``clapConfig.sample_rate``, the waveform is resampled to that rate.

    Returns:
        ``(waveform, clapConfig.sample_rate)``.  Note that the configured rate
        is returned even when ``resample`` is false and the waveform is still
        at its native rate.
    """
    waveform, native_rate = torchaudio.load(audio_path)

    target_rate = clapConfig.sample_rate
    needs_resampling = resample and target_rate != native_rate
    if needs_resampling:
        waveform = T.Resample(native_rate, target_rate)(waveform)
    return waveform, target_rate
|
|
def load_audio_into_tensor(audio_path, audio_duration, resample=False):
    r"""Load a file and return exactly ``audio_duration * sample_rate`` samples.

    The waveform returned by ``read_audio`` is flattened to one dimension.
    Clips that are not longer than the target are repeated and truncated;
    longer clips are cut at a random start position.

    Returns:
        ``torch.FloatTensor`` built from the fixed-length waveform.
    """
    waveform, sample_rate = read_audio(audio_path, resample)
    waveform = waveform.reshape(-1)

    target_len = audio_duration * sample_rate
    available = waveform.shape[0]

    if target_len >= available:
        # Too short (or exact): tile the clip, then trim the overshoot.
        repeat_factor = int(np.ceil(target_len / available))
        waveform = waveform.repeat(repeat_factor)[0:target_len]
    else:
        # Too long: keep one randomly positioned window.
        start_index = random.randrange(available - target_len)
        waveform = waveform[start_index:start_index + target_len]
    return torch.FloatTensor(waveform)
|
|
# Matches numpy dtype strings containing one of the characters S, a, U or O.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Merge a list of samples into a batch, dispatching on the first sample's type.

    Tensors are stacked along a new leading dimension; numpy arrays and
    scalars are converted to tensors; floats become a float64 tensor; ints a
    tensor; strings are returned unchanged; mappings, namedtuples and
    sequences are collated field by field.

    Raises:
        TypeError: for unsupported element types, including numpy arrays
            whose dtype string matches ``np_str_obj_array_pattern``.
        RuntimeError: if sequence elements differ in length.
    """
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a DataLoader worker: stack directly into a shared-memory
            # storage.
            numel = sum([item.numel() for item in batch])
            storage = first.storage()._new_shared(numel)
            out = first.new(storage)
        return torch.stack(batch, 0, out=out)

    type_name = first_type.__name__
    if first_type.__module__ == 'numpy' and type_name not in ('str_', 'string_'):
        if type_name in ('ndarray', 'memmap'):
            # Arrays whose dtype string matches the pattern are rejected.
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(first.dtype))
            return default_collate([torch.as_tensor(sample) for sample in batch])
        if first.shape == ():
            # numpy scalars
            return torch.as_tensor(batch)
        raise TypeError(default_collate_err_msg_format.format(first_type))

    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(first, int):
        return torch.tensor(batch)
    if isinstance(first, str):
        return batch
    if isinstance(first, collections.abc.Mapping):
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # namedtuple: rebuild the same type from per-field batches
        return first_type(*(default_collate(field) for field in zip(*batch)))
    if isinstance(first, collections.abc.Sequence):
        # Every element must have the same length before transposing.
        remaining = iter(batch)
        expected_len = len(next(remaining))
        if not all(len(sample) == expected_len for sample in remaining):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        return [default_collate(column) for column in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
|
|
def preprocess_audio(audio_files, resample):
    r"""Load every file as a ``(1, samples)`` tensor and collate them into one batch.

    Each file is loaded with ``load_audio_into_tensor`` using
    ``clapConfig.duration``; the list is then merged with ``default_collate``.
    """
    audio_tensors = [
        load_audio_into_tensor(path, clapConfig.duration, resample).reshape(1, -1)
        for path in audio_files
    ]
    return default_collate(audio_tensors)
|
|
|
|
|
|
| |
| |
| |
def CLAPAudioProcessor(audio_files: List[str], resample=True):
    """Load the given files and return them as a ``(batch, samples)`` tensor.

    ``preprocess_audio`` yields a 3-D batch; dimension 1 is dropped by
    reshaping to ``(shape[0], shape[2])``.
    """
    batch = preprocess_audio(audio_files, resample)
    n_files, n_samples = batch.shape[0], batch.shape[2]
    return batch.reshape(n_files, n_samples)
|
|
def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
    """Load a list of audio files and return their embeddings.

    The files are pre-processed with ``CLAPAudioProcessor`` and passed to
    ``audio_encoder`` with gradients disabled; element 0 of the encoder's
    result is returned.
    """
    with torch.no_grad():
        waveforms = CLAPAudioProcessor(audio_files, resample)
        encoder_outputs = audio_encoder(waveforms)
        return encoder_outputs[0]
|
|
|
|
| |
| |
| |
class ClapConfig:
    """Plain container of class-level settings for the CLAP model and audio loading."""

    # --- text encoder ---
    text_model = "gpt2"
    text_len = 77
    transformer_embed_dim = 768
    freeze_text_encoder_weights = True

    # --- audio encoder / audio loading ---
    audioenc_name = "HTSAT"
    out_emb = 768
    sample_rate = 44100   # target rate used by read_audio
    duration = 7          # multiplied by the sample rate to get the clip length
    fmin = 50
    fmax = 8000
    n_fft = 1024
    hop_size = 320
    mel_bins = 64
    window_size = 1024

    # --- joint projection ---
    d_proj = 1024
    temperature = 0.003

    # --- misc ---
    num_classes = 527
    batch_size = 1024
    demo = False
| |
|
|
# Build the CLAP model from the configuration defaults.
clapConfig = ClapConfig()
clap = CLAP(
    # audio branch
    audioenc_name=clapConfig.audioenc_name,
    sample_rate=clapConfig.sample_rate,
    window_size=clapConfig.window_size,
    hop_size=clapConfig.hop_size,
    mel_bins=clapConfig.mel_bins,
    fmin=clapConfig.fmin,
    fmax=clapConfig.fmax,
    classes_num=clapConfig.num_classes,
    out_emb=clapConfig.out_emb,
    # text branch
    text_model=clapConfig.text_model,
    transformer_embed_dim=clapConfig.transformer_embed_dim,
    # shared projection size
    d_proj=clapConfig.d_proj,
)

# Checkpoint file names in the Hugging Face repository, keyed by release.
model_repo = "microsoft/msclap"
model_name = {
    "2022": "CLAP_weights_2022.pth",
    "2023": "CLAP_weights_2023.pth",
    "clapcap": "clapcap_weights_2023.pth",
}

version = "2023"
model_fp = hf_hub_download(model_repo, model_name[version])

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.  The weights live under the 'model' key of the checkpoint.
model_state_dict = torch.load(model_fp, map_location=torch.device("cpu"))["model"]
# strict=False: missing or unexpected keys do not raise.
clap.load_state_dict(model_state_dict, strict=False)

# Only the audio branch is moved to the compute device and used below.
clap_audio_encoder = clap.audio_encoder.to(device)
|
|
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
# LoRA adapter settings handed to PEFT; target_modules lists the sub-module
# names that receive adapters.
LoRAconfig = dict(
    peft_type="LORA",
    task_type="FEATURE_EXTRACTION",
    inference_mode=False,
    r=16,
    target_modules=["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
    lora_alpha=32,
    lora_dropout=0.05,
    fan_in_fan_out=False,
    bias="all",
)
peft_config = get_peft_config(LoRAconfig)

# Wrap the CLAP audio encoder with the adapters and report the trainable share.
peft_model = get_peft_model(clap_audio_encoder, peft_config)
peft_model.print_trainable_parameters()

# Keep a handle on the wrapped base model.
peft_clap_audio_encoder = peft_model.base_model
| |
| |
|
|
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import io |
| import sys |
| import math |
| import random |
| import collections |
| import collections.abc |
| import re |
| from itertools import repeat |
| from pathlib import Path |
| from typing import Optional, Tuple, Union, List, Dict |
|
|
| import csv |
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| from tqdm import trange, tqdm |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn.init import _calculate_fan_in_and_fan_out |
| import torch.utils.checkpoint as checkpoint |
|
|
| import torchvision |
| from torchvision.transforms import v2 |
|
|
| |
| |
| |
|
|
| import torchaudio |
| import torchaudio.transforms as T |
| from torchlibrosa.stft import Spectrogram, LogmelFilterBank |
| from torchlibrosa.augmentation import SpecAugmentation |
|
|
| from transformers import AutoModel, AutoTokenizer, logging |
| from huggingface_hub.file_download import hf_hub_download |
| from huggingface_hub.file_download import hf_hub_download |
| from peft import get_peft_config, get_peft_model |
|
|
| from typing import Any, Dict, Optional, Tuple, Union |
| import numbers |
| import random |
| import warnings |
| from dataclasses import dataclass, asdict |
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
|
|
| import torch |
| import torchvision.transforms.functional as F |
| from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ |
| CenterCrop, ColorJitter, Grayscale |
|
|
# Normalisation constants as 3-tuples (presumably one value per RGB channel --
# the consumers are not visible in this part of the file).
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)


# File names used for open_clip weights and configuration.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
HF_CONFIG_NAME = "open_clip_config.json"
|
|
|
|
| import collections.abc |
| from itertools import repeat |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| from torch import nn as nn |
| from torch import _assert |
| from torchvision.ops.misc import FrozenBatchNorm2d |
|
|
|
|
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty or None).
            The default is now ``None`` instead of a shared mutable ``{}``.
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # With no filter every batch-norm layer matches; otherwise only listed names.
    is_match = name in module_match if module_match else True

    batch_norm_types = (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
    if is_match and isinstance(module, batch_norm_types):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        return res

    # Not a (matching) batch-norm layer: recurse and swap converted children in place.
    for child_name, child in module.named_children():
        full_child_name = '.'.join([name, child_name]) if name else child_name
        new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
        if new_child is not child:
            module.add_module(child_name, new_child)
    return module
|
|
|
|
| |
def _ntuple(n):
    """Build a converter that expands a non-iterable value into an ``n``-tuple.

    Iterable inputs (strings included) are returned unchanged.
    """
    def parse(x):
        already_iterable = isinstance(x, collections.abc.Iterable)
        return x if already_iterable else tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Apply the ``n``-tuple converter to ``x`` in one call."""
    return _ntuple(n)(x)
|
|
| |
| |
def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    """Recursively swap selected ``torch.nn.Linear`` children for another layer type.

    Args:
        model: module whose children are walked; modified in place.
        linear_replacement: callable invoked as
            ``linear_replacement(in_features, out_features, has_bias)``.
        include_modules: child names eligible for replacement.  The default
            is now an immutable tuple instead of a shared list; any container
            supporting ``in`` is accepted.
        copy_weights: copy the old weight (and bias, when both layers have
            one) into the replacement.

    Returns:
        ``model`` itself.
    """
    for name, module in model.named_children():
        # Descend first so nested containers are handled too.
        if list(module.children()):
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if not (isinstance(module, torch.nn.Linear) and name in include_modules):
            continue

        replacement = linear_replacement(
            module.in_features,
            module.out_features,
            module.bias is not None,
        )
        # Replacing the value of an existing key is safe while iterating.
        model._modules[name] = replacement
        if copy_weights:
            replacement.weight.data.copy_(module.weight.data)
            # Guard both sides: a replacement may create a bias the source lacks.
            if replacement.bias is not None and module.bias is not None:
                replacement.bias.data.copy_(module.bias)

    return model
|
|
def convert_int8_model_to_inference_mode(model):
    """Call ``prepare_for_eval()`` on every submodule that defines it.

    The dtype of the submodule's ``weight`` is read *before* the call and
    stored afterwards on the submodule as ``int8_original_dtype``.
    """
    for submodule in model.modules():
        if not hasattr(submodule, 'prepare_for_eval'):
            continue
        dtype_before = submodule.weight.dtype
        submodule.prepare_for_eval()
        submodule.int8_original_dtype = dtype_before
|
|
|
|
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Resolve which absolute feature indices to 'take'.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
          None -> select all
          int -> select last n
          list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index
    """
    if indices is None:
        indices = num_features

    if isinstance(indices, int):
        # "last n" form: the final `indices` positions, in ascending order
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        first_taken = num_features - indices
        take_indices = [first_taken + offset for offset in range(indices)]
    else:
        take_indices: List[int] = []
        for raw in indices:
            # negative entries count from the end
            idx = raw if raw >= 0 else num_features + raw
            _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
            take_indices.append(idx)

    highest = max(take_indices)
    if as_set and not torch.jit.is_scripting():
        return set(take_indices), highest

    return take_indices, highest
|
|
|
|
| def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: |
| if isinstance(x, int): |
| |
| return tuple(range(-x, 0)) |
| return tuple(x) |
|
|
|
|
|
|
| import copy |
| import copy |
| import hashlib |
| import os |
| import urllib |
| import warnings |
| from functools import partial |
| from typing import Dict, Iterable, Optional, Union |
|
|
| from tqdm import tqdm |
|
|
|
|
| try: |
| import safetensors.torch |
| _has_safetensors = True |
| except ImportError: |
| _has_safetensors = False |
|
|
| __version__ = '2.32.0' |
|
|
|
|
| """ CLIP Model |
| |
| Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. |
| """ |
| import copy |
| import logging |
| import math |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.utils.checkpoint import checkpoint |
| from functools import partial |
|
|
| |
| |
| from collections import OrderedDict |
| import math |
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| import numpy as np |
|
|
| import torch |
|
|
| |
| |
| |
| |
| |
| |
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sine/cosine position embedding for a square grid.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid receives the width coordinates first, then the height coordinates
    grid = np.stack(np.meshgrid(axis, axis), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    # Prepend an all-zero row for the class token.
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode ``grid[0]`` and ``grid[1]`` with half of ``embed_dim`` each and concatenate.

    Returns an array with one row per grid position and ``embed_dim`` columns.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    emb_h, emb_w = (
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)
    )
    return np.concatenate([emb_h, emb_w], axis=1)
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a flat list of positions.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) -- sine features in the first D/2 columns, cosine in the rest
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents

    # Outer product of positions and frequencies: (M, D/2).
    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
|
|
|
| |
| |
| |
| |
| |
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to fit ``model``'s patch grid.

    Nothing happens when the key is missing or the (square) grid sizes
    already agree.  Otherwise the leading extra tokens are kept as they are
    and the remaining position tokens are bicubically interpolated to the
    new grid size.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens in the model's embedding that are not patches.
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # Side length of the checkpoint grid and of the model grid.
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # Only the position tokens are interpolated.
    grid_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    grid_tokens = grid_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, grid_tokens), dim=1)
|
|
|
|
|
|
| from collections import OrderedDict |
| from typing import Dict, List, Optional, Union |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| |
|
|
|
|
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, optional average pool, 1x1 conv.

    The output has ``planes * expansion`` channels.  Down-sampling is done by
    an average pool placed after the second convolution; every convolution
    itself has stride 1.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Shortcut: an avgpool, then a stride-1 1x1 convolution and batch norm.
            shortcut_layers = OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes)),
            ])
            self.downsample = nn.Sequential(shortcut_layers)

    def forward(self, x: torch.Tensor):
        """Return ``act3(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.conv1(x)
        out = self.act1(self.bn1(out))
        out = self.conv2(out)
        out = self.act2(self.bn2(out))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.act3(out)
|
|
|
|
class AttentionPool2d(nn.Module):
    """Pool a 2-D feature map with one multi-head attention step.

    The spatial positions are flattened into a token sequence, their mean is
    prepended as an extra token, a learned positional embedding is added and
    the attention output at the prepended token is returned.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        n_tokens = spacial_dim ** 2 + 1  # one per position plus the mean token
        self.positional_embedding = nn.Parameter(torch.randn(n_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # (N, C, H, W) -> (H*W, N, C)
        tokens = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
        # Prepend the mean over positions: (H*W + 1, N, C)
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        # The separate q/k/v projections share one concatenated bias vector.
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Output at the prepended mean token.
        return attended[0]
|
|
|
|
| class ModifiedResNet(nn.Module): |
| """ |
| A ResNet class that is similar to torchvision's but contains the following changes: |
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. |
| - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 |
| - The final pooling layer is a QKV attention instead of an average pool |
| """ |
|
|
def __init__(
        self,
        layers: List[int],
        output_dim: int,
        heads: int,
        image_size: int = 224,
        width: int = 64,
):
    """Build the 3-conv stem, four residual stages and the attention pool.

    Args:
        layers: number of ``Bottleneck`` blocks in each of the four stages.
        output_dim: output size of the attention pool.
        heads: number of attention heads in the pool.
        image_size: input resolution; the pool is sized for ``image_size // 32``.
        width: channel width of the first stage.
    """
    super().__init__()
    self.output_dim = output_dim
    self.image_size = image_size

    # the 3-layer stem
    stem_width = width // 2
    self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(stem_width)
    self.act1 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(stem_width, stem_width, kernel_size=3, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(stem_width)
    self.act2 = nn.ReLU(inplace=True)
    self.conv3 = nn.Conv2d(stem_width, width, kernel_size=3, padding=1, bias=False)
    self.bn3 = nn.BatchNorm2d(width)
    self.act3 = nn.ReLU(inplace=True)
    self.avgpool = nn.AvgPool2d(2)

    # residual layers
    self._inplanes = width  # this is a *mutable* variable used during construction
    self.layer1 = self._make_layer(width, layers[0])
    self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
    self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
    self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

    embed_dim = width * 32  # the ResNet feature dimension
    pooled_side = image_size // 32
    self.attnpool = AttentionPool2d(pooled_side, embed_dim, heads, output_dim)

    self.init_parameters()
|
|
| def _make_layer(self, planes, blocks, stride=1): |
| layers = [Bottleneck(self._inplanes, planes, stride)] |
|
|
| self._inplanes = planes * Bottleneck.expansion |
| for _ in range(1, blocks): |
| layers.append(Bottleneck(self._inplanes, planes)) |
|
|
| return nn.Sequential(*layers) |
|
|
| def init_parameters(self): |
| if self.attnpool is not None: |
| std = self.attnpool.c_proj.in_features ** -0.5 |
| nn.init.normal_(self.attnpool.q_proj.weight, std=std) |
| nn.init.normal_(self.attnpool.k_proj.weight, std=std) |
| nn.init.normal_(self.attnpool.v_proj.weight, std=std) |
| nn.init.normal_(self.attnpool.c_proj.weight, std=std) |
|
|
| for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: |
| for name, param in resnet_block.named_parameters(): |
| if name.endswith("bn3.weight"): |
| nn.init.zeros_(param) |
|
|
| def lock(self, unlocked_groups=0, freeze_bn_stats=False): |
| assert unlocked_groups == 0, 'partial locking not currently supported for this model' |
| for param in self.parameters(): |
| param.requires_grad = False |
| if freeze_bn_stats: |
| freeze_batch_norm_2d(self) |
|
|
| @torch.jit.ignore |
| def set_grad_checkpointing(self, enable=True): |
| |
| pass |
|
|
| def stem(self, x): |
| x = self.act1(self.bn1(self.conv1(x))) |
| x = self.act2(self.bn2(self.conv2(x))) |
| x = self.act3(self.bn3(self.conv3(x))) |
| x = self.avgpool(x) |
| return x |
|
|
| def forward_intermediates( |
| self, |
| x: torch.Tensor, |
| indices: Optional[Union[int, List[int]]] = None, |
| stop_early: bool = False, |
| normalize_intermediates: bool = False, |
| intermediates_only: bool = False, |
| output_fmt: str = 'NCHW', |
| output_extra_tokens: bool = False, |
| ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: |
| """ Forward features that returns intermediates. |
| |
| Args: |
| x: Input image tensor |
| indices: Take last n blocks if int, all if None, select matching indices if sequence |
| stop_early: Stop iterating over blocks when last desired intermediate hit |
| normalize_intermediates: Apply final norm layer to all intermediates |
| intermediates_only: Only return intermediate features |
| output_fmt: Shape of intermediate feature outputs |
| output_extra_tokens: Return both extra class, eot tokens |
| Returns: |
| |
| """ |
| assert output_fmt in ('NCHW',), 'Output format must be == NCHW.' |
| |
| take_indices, max_index = feature_take_indices(5, indices) |
|
|
| output = {} |
| intermediates = [] |
| blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4] |
| if torch.jit.is_scripting() or not stop_early: |
| blocks = blocks[:max_index + 1] |
| for i, blk in enumerate(blocks): |
| x = blk(x) |
| if i in take_indices: |
| intermediates.append(x) |
|
|
| output['image_intermediates'] = intermediates |
|
|
| if intermediates_only: |
| return output |
|
|
| x = self.attnpool(x) |
| output['image_features'] = x |
|
|
| return output |
|
|
| def forward(self, x): |
| x = self.stem(x) |
| x = self.layer1(x) |
| x = self.layer2(x) |
| x = self.layer3(x) |
| x = self.layer4(x) |
| x = self.attnpool(x) |
|
|
| return x |
|
|
|
|
| """ huggingface model adapter |
| |
| Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. |
| """ |
| import re |
|
|
| import torch |
| import torch.nn as nn |
| from torch import TensorType |
|
|
try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPooling,
        BaseModelOutputWithPoolingAndCrossAttentions,
    )
except ImportError:
    # transformers is unavailable: ``transformers = None`` is what
    # HFTextEncoder checks before raising its install hint. Placeholder
    # classes keep annotations and isinstance checks in this section
    # resolvable.
    transformers = None

    class BaseModelOutput:
        pass

    # ClsPooler.forward names both pooling output types in an isinstance
    # check, so they need placeholders as well (previously a NameError).
    class BaseModelOutputWithPooling(BaseModelOutput):
        pass

    class BaseModelOutputWithPoolingAndCrossAttentions(BaseModelOutput):
        pass

    class PretrainedConfig:
        pass
|
|
| |
| |
# Per-architecture lookup table keyed by the HuggingFace ``config.model_type``.
#   "config_names": maps the generic names used in this file to the attribute
#       names of that architecture's config / modules.
#       - "layer_attr": attribute of the encoder holding its list of layers.
#       - "token_embeddings_attr": attribute of the model holding its embeddings.
#       Both are read by ``HFTextEncoder.lock`` when partially unlocking.
#   "pooler": default key into the pooler registry.
arch_dict = {
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings",
        },
        "pooler": "mean_pooler",
    },
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings",
        },
        "pooler": "mean_pooler",
    },
    "mt5": {
        "config_names": {
            # Left empty: no config attribute is mapped for the context length here.
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens",
        },
        "pooler": "mean_pooler",
    },
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            # Added so partial unlocking does not raise KeyError for BERT
            # (BertModel.encoder.layer / BertModel.embeddings).
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings",
        },
        "pooler": "cls_pooler",
    },
    "m2m_100": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "encoder_attention_heads",
            "layers": "encoder_layers",
            # Added so partial unlocking does not raise KeyError; the wrapped
            # module is the encoder itself (M2M100Encoder.layers / .embed_tokens).
            "layer_attr": "layers",
            "token_embeddings_attr": "embed_tokens",
        },
        "pooler": "cls_pooler",
    },
}
|
|
|
|
|
|
| |
| def _camel2snake(s): |
| return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower() |
|
|
|
|
| |
# Registry of pooler classes, keyed by the snake_case form of the class name.
_POOLERS = {}


def register_pooler(cls):
    """Class decorator: store ``cls`` in ``_POOLERS`` and return it unchanged."""
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
|
|
|
|
@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling of ``last_hidden_state`` weighted by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Sum the mask-weighted hidden states over dim 1 and divide by the mask sum."""
        weights = attention_mask.unsqueeze(-1)
        weighted_sum = (x.last_hidden_state * weights).sum(dim=1)
        token_counts = attention_mask.sum(-1, keepdim=True)
        return weighted_sum / token_counts
|
|
|
|
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling of ``last_hidden_state`` over the positions kept by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the per-feature maximum over dim 1, ignoring masked-out positions.

        ``attention_mask`` is treated as non-zero for positions to keep, which
        is how ``HFTextEncoder.forward`` builds it.
        """
        # Fix: masked_fill needs a boolean mask and must blank the *padding*
        # positions. The original passed the integer mask un-negated, which
        # is not a bool mask and would have hidden the real tokens instead.
        padding_positions = (attention_mask == 0).unsqueeze(-1)
        masked_output = x.last_hidden_state.masked_fill(padding_positions, -torch.inf)
        return masked_output.max(1).values
|
|
|
|
@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        """
        Args:
            use_pooler_output: prefer the model's own ``pooler_output`` when it is available.
        """
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return ``pooler_output`` when usable, else the hidden state at ``cls_token_position``.

        ``attention_mask`` is accepted for a uniform pooler signature and is not used.
        """
        pooled_types = (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)
        can_use_pooled = (
            self.use_pooler_output
            and isinstance(x, pooled_types)
            and x.pooler_output is not None
        )
        if can_use_pooled:
            return x.pooler_output

        hidden = x.last_hidden_state
        return hidden[:, self.cls_token_position, :]
|
|
|
|
@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the hidden state at ``cls_token_position``; ``attention_mask`` is not used."""
        hidden = x.last_hidden_state
        return hidden[:, self.cls_token_position, :]
|
|
|
|
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Wraps a HuggingFace transformer, pools its output with one of the
    registered poolers and applies an optional projection to ``output_dim``.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj_type: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        """
        Args:
            model_name_or_path: identifier handed to ``AutoConfig`` / ``AutoModel``
                when ``config`` is None.
            output_dim: width of the projected output.
            config: explicit config; when given the model is built with
                ``AutoModel.from_config(config)``.
            pooler_type: key into the pooler registry; defaults to the
                ``arch_dict`` entry for the config's ``model_type``.
            proj_type: ``None``, ``'linear'`` or ``'mlp'``.
            pretrained: only consulted when ``config`` is None; selects
                ``from_pretrained`` over ``from_config``.
            output_tokens: make ``forward`` also return per-token hidden states.

        Raises:
            RuntimeError: if the ``transformers`` package could not be imported.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # Evaluated on the caller-supplied pooler_type, i.e. before the
        # arch_dict default below is applied.
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            if pretrained:
                create_func, model_args = AutoModel.from_pretrained, model_name_or_path
            else:
                create_func, model_args = AutoModel.from_config, self.config
            if getattr(self.config, "is_encoder_decoder", False):
                # Encoder-decoder models: keep only the encoder stack.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:
            pooler_type = arch_dict[self.config.model_type]["pooler"]

        # Fall back to 0 when the config does not expose these attributes.
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        # NOTE: ``self.proj`` is only created for the three combinations below;
        # any other combination leaves it unset.
        if (d_model == output_dim) and (proj_type is None):
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        """Encode token ids.

        Positions equal to ``config.pad_token_id`` are masked out.

        Returns:
            The projected pooled output, or ``(projected, tokens)`` when
            ``output_tokens`` is set. With a ``ClsPooler`` the CLS position is
            removed from ``tokens``.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        if not self.output_tokens:
            return projected

        # Per-token states are only assembled when actually returned.
        if isinstance(self.pooler, ClsPooler):
            seq_len = out.last_hidden_state.shape[1]
            keep = torch.arange(seq_len) != self.pooler.cls_token_position
            tokens = out.last_hidden_state[:, keep, :]
        else:
            tokens = out.last_hidden_state
        return projected, tokens

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze transformer parameters.

        Args:
            unlocked_layers: number of trailing entries of
                ``[embeddings, *layers]`` left untouched; 0 freezes everything.
            freeze_layer_norm: when False, parameters whose dotted name has a
                ``LayerNorm`` component get ``requires_grad=True`` instead.
        """
        if not unlocked_layers:
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing on the wrapped transformer."""
        # Fix: ``enable=False`` used to switch checkpointing on as well.
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def init_parameters(self):
        """No-op kept for interface parity with the other towers."""
        pass
|
|
|
|
| """ timm model adapter |
| |
| Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. |
| """ |
| import logging |
| from collections import OrderedDict |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| try: |
| import timm |
| from timm.layers import RotAttentionPool2d |
| from timm.layers import AttentionPool2d as AbsAttentionPool2d |
| from timm.layers import Mlp, to_2tuple |
| except ImportError: |
| timm = None |
|
|
|
|
|
|
class TimmModel(nn.Module):
    """ timm model adapter
    """

    def __init__(
            self,
            model_name: str,
            embed_dim: int,
            image_size: Union[int, Tuple[int, int]] = 224,
            pool: str = 'avg',
            proj: str = 'linear',
            proj_bias: bool = False,
            drop: float = 0.,
            drop_path: Optional[float] = None,
            patch_drop: Optional[float] = None,
            pretrained: bool = False,
    ):
        """Create the timm trunk plus an optional pooling / projection head.

        Args:
            model_name: name passed to ``timm.create_model``.
            embed_dim: output width of the head.
            image_size: stored (as a 2-tuple) on ``self.image_size``.
            pool: ``'abs_attn'`` / ``'rot_attn'`` add an attention pool to the
                head; other values are forwarded to timm as ``global_pool``.
            proj: ``'linear'``, ``'mlp'``, ``'none'`` or a falsy value.
            proj_bias: bias flag of the final projection layer.
            drop: dropout rate used in the head.
            drop_path: forwarded as ``drop_path_rate`` when not None.
            patch_drop: forwarded as ``patch_drop_rate`` when not None.
            pretrained: forwarded to ``timm.create_model``.

        Raises:
            RuntimeError: if timm could not be imported.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.")
        self.image_size = to_2tuple(image_size)

        # Only forward the optional rates that were explicitly provided.
        optional_rates = {'drop_path_rate': drop_path, 'patch_drop_rate': patch_drop}
        timm_kwargs = {key: rate for key, rate in optional_rates.items() if rate is not None}

        custom_pool = pool in ('abs_attn', 'rot_attn')
        if proj:
            assert proj in ("linear", "mlp", "none")
        extra_proj = proj in ("linear", "mlp")

        if extra_proj or custom_pool:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            feature_ndim = 2 if feat_size else 1
            if custom_pool:
                assert feature_ndim == 2
                # Attention pooling replaces both the classifier and the default pool.
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # Override the global pool only when one was requested.
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features
        else:
            # No extra projection and no custom pool: the trunk's own classifier
            # acts as the projection (zero classes when proj == 'none').
            proj_dim = 0 if proj == 'none' else embed_dim
            self.trunk = timm.create_model(
                model_name,
                num_classes=proj_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim

        head_layers = OrderedDict()

        # Optional attention pooling in front of the projection.
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # Optional projection.
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also freeze batch-norm statistics of the locked part
        """
        if not unlocked_groups:
            # Lock the whole trunk.
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
            return

        # Partial lock: requires timm's parameter grouping helpers.
        try:
            from timm.models.helpers import group_parameters, group_modules
        except ImportError:
            raise RuntimeError(
                'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
        matcher = self.trunk.group_matcher()
        gparams = group_parameters(self.trunk, matcher)
        max_layer_id = max(gparams.keys()) - unlocked_groups
        for group_idx in range(max_layer_id + 1):
            for param in gparams[group_idx]:
                self.trunk.get_parameter(param).requires_grad = False
        if freeze_bn_stats:
            gmodules = group_modules(self.trunk, matcher, reverse=True)
            gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
            freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Forward the request to the trunk; log a warning if the trunk raises."""
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            output_fmt: str = 'NCHW',
            output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize_intermediates: Apply norm layer to all intermediates
            intermediates_only: Only return intermediate features
            output_fmt: Shape of intermediate feature outputs
            output_extra_tokens: Return both prefix and spatial intermediate tokens
        Returns:
            Dict with 'image_intermediates', optionally
            'image_intermediates_prefix', and (unless ``intermediates_only``)
            'image_features'.
        """
        extra_args = {'return_prefix_tokens': True} if output_extra_tokens else {}
        trunk_output = self.trunk.forward_intermediates(
            x,
            indices=indices,
            intermediates_only=intermediates_only,
            norm=normalize_intermediates,
            stop_early=stop_early,
            output_fmt=output_fmt,
            **extra_args,
        )

        return_dict = {}
        if intermediates_only:
            intermediates = trunk_output
        else:
            intermediates = trunk_output[1]

        # With prefix tokens requested, entries may come back as tuples; split them.
        if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):
            return_dict['image_intermediates_prefix'] = [pair[1] for pair in intermediates]
            intermediates = [pair[0] for pair in intermediates]

        return_dict['image_intermediates'] = intermediates
        if intermediates_only:
            return return_dict

        pooled = self.trunk.forward_head(trunk_output[0])
        return_dict['image_features'] = self.head(pooled)
        return return_dict

    def forward(self, x):
        """Run the trunk followed by the head."""
        features = self.trunk(x)
        return self.head(features)
|
|
|
|
class LayerNormFp32(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""

    def forward(self, x: torch.Tensor):
        """Normalise a float32 copy of ``x`` and cast the result back to ``x``'s dtype."""
        input_dtype = x.dtype
        as_fp32 = x.to(torch.float32)
        normed = F.layer_norm(as_fp32, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor):
        """Apply layer norm and return the result in the dtype ``x`` arrived with."""
        input_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
|
|
|
|
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
|
|
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma``."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        """
        Args:
            dim: size passed to ``torch.ones`` for ``gamma``.
            init_values: initial value of every ``gamma`` entry.
            inplace: scale the input tensor in place instead of allocating a new one.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794
    """

    def __init__(
            self,
            prob: float = 0.5,
            exclude_first_token: bool = True
    ):
        """
        Args:
            prob: fraction of tokens to drop; must satisfy ``0 <= prob < 1``.
            exclude_first_token: keep the token at index 0 out of the dropping
                and re-attach it in front of the kept tokens.
        """
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token

    def forward(self, x):
        """Randomly keep a subset of tokens along dim 1 (training mode only).

        In eval mode, or with ``prob == 0``, the input is returned unchanged.
        """
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Annotated so both branches define ``cls_tokens`` with a tensor type.
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]

        # Column vector of row indices so it broadcasts against the kept indices.
        row_index = torch.arange(batch)[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # The top-k of random scores gives a random subset per batch row.
        scores = torch.randn(batch, num_tokens)
        kept_index = scores.topk(num_patches_keep, dim=-1).indices

        x = x[row_index, kept_index]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with optional scaled-cosine logits and per-head output scaling."""

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            scaled_cosine: bool = False,
            scale_heads: bool = False,
            logit_scale_max: float = math.log(1. / 0.01),
            batch_first: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        """
        Args:
            dim: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            qkv_bias: add a bias to the fused q/k/v input projection.
            scaled_cosine: use cosine-similarity logits with a learnable,
                clamped per-head scale.
            scale_heads: multiply each head's output by a learnable scalar.
            logit_scale_max: upper clamp of the (log-space) logit scale.
            batch_first: swap the first two axes of the input on entry and of
                the result on exit.
            attn_drop: dropout on the attention weights.
            proj_drop: dropout after the output projection.
        """
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max
        self.batch_first = batch_first
        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')

        # Fused q/k/v projection stored as raw parameters (rows ordered q, k, v).
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to a 3-D input.

        Args:
            x: 3-D tensor; its first two axes are swapped on entry and exit
                when ``batch_first`` is set.
            attn_mask: optional mask added to the attention logits. A boolean
                mask is converted to an additive one with ``-inf`` where True.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.batch_first:
            x = x.transpose(0, 1)

        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # Fold the heads into the batch axis: (L, N, C) -> (N * heads, L, head_dim).
        q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)

        if attn_mask is not None and attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask

        if self.logit_scale is not None:
            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn.view(N, self.num_heads, L, L) * logit_scale
            attn = attn.view(-1, L, L)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.bmm(attn, v)
        else:
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn += attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.bmm(attn, v)

        if self.head_scale is not None:
            # Fix: at this point x is (N * heads, L, head_dim). The original
            # viewed it with the full width C as the last axis, which cannot
            # succeed for more than one head.
            x = x.reshape(N, self.num_heads, L, self.head_dim) * self.head_scale
            x = x.reshape(-1, L, self.head_dim)

        # Unfold the heads: (N * heads, L, head_dim) -> (L, N, C).
        x = x.transpose(0, 1).reshape(L, N, C)

        if self.batch_first:
            x = x.transpose(0, 1)

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
|
|
|
|
class AttentionalPooler(nn.Module):
    """Pool a context sequence into ``n_queries`` outputs via learned queries and cross-attention."""

    def __init__(
            self,
            d_model: int,
            context_dim: int,
            n_head: int = 8,
            n_queries: int = 256,
            norm_layer: Callable = LayerNorm,
    ):
        """
        Args:
            d_model: width of the learned queries and of the attention output.
            context_dim: width of the keys/values (the input to ``forward``).
            n_head: number of attention heads.
            n_queries: number of learned query vectors.
            norm_layer: factory for the query and key normalisation layers.
        """
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        """Attend from the learned queries to the normalised input; returns the attention output."""
        batch_size = x.shape[0]
        context = self.ln_k(x)
        # Share the same learned queries across the batch without copying.
        queries = self.ln_q(self.query).unsqueeze(0).expand(batch_size, -1, -1)
        pooled = self.attn(queries, context, context, need_weights=False)[0]
        return pooled
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and an MLP, each with a residual connection."""

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
            batch_first: bool = True,
    ):
        """
        Args:
            d_model: model width.
            n_head: number of attention heads.
            mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
            ls_init_value: when not None, both residual branches get a
                ``LayerScale`` initialised to this value.
            act_layer: factory for the MLP activation.
            norm_layer: factory for the normalisation layers.
            is_cross_attention: add a separate norm (``ln_1_kv``) for keys/values.
            batch_first: forwarded to ``nn.MultiheadAttention``.
        """
        super().__init__()
        use_layer_scale = ls_init_value is not None

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        self.ls_1 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, mlp_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(mlp_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ls_2 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run the attention module; missing keys/values default to the queries."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x

        if attn_mask is not None:
            attn_mask = attn_mask.to(q_x.dtype)
        attn_output = self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)
        return attn_output[0]

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the attention and MLP residual branches.

        ``k_x`` / ``v_x`` are only used when the block has ``ln_1_kv``;
        otherwise they are replaced by None and the block self-attends.
        """
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None

        attended = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        x = q_x + self.ls_1(attended)
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
|
|
|
class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block built on the file's own ``Attention`` module."""

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            batch_first: bool = True,
    ):
        """
        Args:
            d_model: model width.
            n_head: number of attention heads.
            mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
            ls_init_value: when not None, both residual branches get a
                ``LayerScale`` initialised to this value.
            act_layer: factory for the MLP activation.
            norm_layer: factory for the normalisation layers.
            scale_cosine_attn: forwarded to ``Attention`` as ``scaled_cosine``.
            scale_heads: forwarded to ``Attention``.
            scale_attn: add a norm layer after the attention output.
            scale_fc: add a norm layer inside the MLP, after the activation.
            batch_first: forwarded to ``Attention``.
        """
        super().__init__()
        use_layer_scale = ls_init_value is not None

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            batch_first=batch_first,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, mlp_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers['ln'] = norm_layer(mlp_width) if scale_fc else nn.Identity()
        mlp_layers["c_proj"] = nn.Linear(mlp_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ls_2 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

    def get_reference_weight(self):
        """Return the weight of the first MLP layer (``mlp.c_fc``)."""
        return self.mlp.c_fc.weight

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply the attention and MLP residual branches."""
        attended = self.attn(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.ls_1(self.ln_attn(attended))
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.ls_2(mlp_out)
|
|
|
|
class CustomTransformer(nn.Module):
    """ A custom transformer that can use different block types. """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            batch_first: bool = True,
            block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
    ):
        """
        Args:
            width: model width.
            layers: number of blocks.
            heads: attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of ``width``.
            ls_init_value: layer-scale init value forwarded to every block.
            act_layer: activation factory forwarded to every block.
            norm_layer: normalisation factory forwarded to every block.
            batch_first: when False, the first two input axes are swapped
                before and after the blocks.
            block_types: one block type name for all layers, or one per layer.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        if isinstance(block_types, str):
            block_types = [block_types] * layers
        assert len(block_types) == layers

        def _create_block(bt: str):
            if bt == 'CustomResidualAttentionBlock':
                return CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )
            # Same exception type as the original ``assert False``, but with a
            # message and not stripped under ``python -O``.
            raise AssertionError(f"Unsupported block type: {bt!r}")

        self.resblocks = nn.ModuleList([
            _create_block(bt)
            for bt in block_types
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's reference weight (or its recorded int8 original)."""
        weight = self.resblocks[0].get_reference_weight()
        if hasattr(weight, 'int8_original_dtype'):
            return weight.int8_original_dtype
        return weight.dtype

    def _run_block(self, blk, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Run one block, through activation checkpointing when enabled."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Fix 1: these blocks take (x, attn_mask); the original passed
            # (x, None, None, attn_mask), which does not match that signature.
            # Fix 2: call the function by its qualified name - the top of this
            # file binds the bare name ``checkpoint`` to the module
            # ``torch.utils.checkpoint``, which is not callable.
            return torch.utils.checkpoint.checkpoint(blk, x, attn_mask, use_reentrant=False)
        return blk(x, attn_mask=attn_mask)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the blocks and collect the outputs of the selected ones.

        Args:
            x: input tensor.
            attn_mask: optional attention mask passed to every block.
            indices: block selection, resolved by ``feature_take_indices``.
            stop_early: stop after the last selected block (ignored when scripting).

        Returns:
            ``(x, intermediates)``: the output of the last executed block and
            the list of selected block outputs.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = self._run_block(blk, x, attn_mask)

            if i in take_indices:
                # Undo the axis swap so intermediates match the input layout.
                intermediates.append(x.transpose(0, 1) if not self.batch_first else x)

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run every block in sequence and return the final output."""
        if not self.batch_first:
            x = x.transpose(0, 1)

        for r in self.resblocks:
            x = self._run_block(r, x, attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)
        return x
|
|
|
|
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers."""

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            batch_first: bool = True,
    ):
        """
        Args:
            width: model width.
            layers: number of blocks.
            heads: attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of ``width``.
            ls_init_value: layer-scale init value forwarded to every block.
            act_layer: activation factory forwarded to every block.
            norm_layer: normalisation factory forwarded to every block.
            batch_first: when False, the first two input axes are swapped
                before and after the blocks.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.c_fc`` (or its recorded int8 original)."""
        if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
            return self.resblocks[0].mlp.c_fc.int8_original_dtype
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def _run_block(self, blk, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Run one block, through activation checkpointing when enabled."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Arguments are positional (q_x, k_x, v_x, attn_mask) because they
            # are forwarded to the block's forward as-is.
            # Fix: call the function by its qualified name - the top of this
            # file binds the bare name ``checkpoint`` to the module
            # ``torch.utils.checkpoint``, which is not callable.
            return torch.utils.checkpoint.checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)
        return blk(x, attn_mask=attn_mask)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the blocks and collect the outputs of the selected ones.

        Args:
            x: input tensor.
            attn_mask: optional attention mask passed to every block.
            indices: block selection, resolved by ``feature_take_indices``.
            stop_early: stop after the last selected block (ignored when scripting).

        Returns:
            ``(x, intermediates)``: the output of the last executed block and
            the list of selected block outputs.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = self._run_block(blk, x, attn_mask)

            if i in take_indices:
                # Undo the axis swap so intermediates match the input layout.
                intermediates.append(x.transpose(0, 1) if not self.batch_first else x)

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run every block in sequence and return the final output."""
        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        for r in self.resblocks:
            x = self._run_block(r, x, attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)
        return x
|
|
|
|
| def _expand_token(token, batch_size: int): |
| return token.view(1, 1, -1).expand(batch_size, -1, -1) |
|
|
|
|
class VisionTransformer(nn.Module):
    """Image tower: strided-conv patch embedding, ``Transformer``, pooling, projection.

    ``forward`` returns the pooled (and, if ``proj`` is set, projected)
    features; when ``output_tokens`` is True it returns ``(pooled, tokens)``.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        ls_init_value: float = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.,
        no_ln_pre: bool = False,
        pos_embed_type: str = 'learnable',
        pool_type: str = 'tok',
        final_ln_after_pool: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_tokens: bool = False,
    ):
        """Create the patch embedding, positional embedding, transformer and head.

        Args:
            image_size / patch_size: Passed through ``to_2tuple``; their integer
                ratio gives ``grid_size``.
            width, layers, heads, mlp_ratio, ls_init_value, act_layer,
                norm_layer: Forwarded to ``Transformer``.
            attentional_pool: Falsy for no attention pooling; ``'parallel'`` or
                ``'cascade'`` for two poolers; any other truthy non-string
                value for a single pooler.
            attn_pooler_queries / attn_pooler_heads: ``AttentionalPooler`` sizes.
            output_dim: Width of the final projection.
            patch_dropout: ``PatchDropout`` probability (disabled when 0).
            no_ln_pre: Replace the pre-transformer norm with ``nn.Identity``.
            pos_embed_type: ``'learnable'`` or ``'sin_cos_2d'``.
            pool_type: ``'tok'``, ``'avg'`` or ``'none'``.
            final_ln_after_pool: Apply ``ln_post`` after pooling instead of before.
            output_tokens: Make ``forward`` also return the token sequence.

        Raises:
            ValueError: If ``pos_embed_type`` is not recognised.
        """
        super().__init__()
        assert pool_type in ('tok', 'avg', 'none')
        self.output_tokens = output_tokens
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.final_ln_after_pool = final_ln_after_pool
        self.output_dim = output_dim

        # Non-overlapping patches: kernel and stride both equal the patch size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width ** -0.5
        num_positions = self.grid_size[0] * self.grid_size[1] + 1  # patches + class token
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        if pos_embed_type == 'learnable':
            self.positional_embedding = nn.Parameter(scale * torch.randn(num_positions, width))
        elif pos_embed_type == 'sin_cos_2d':
            assert self.grid_size[0] == self.grid_size[1],\
                'currently sin cos 2d pos embedding only supports square input'
            # Fixed (non-trainable) table filled from the sin/cos helper.
            self.positional_embedding = nn.Parameter(
                torch.zeros(num_positions, width), requires_grad=False)
            sincos_table = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
            self.positional_embedding.data.copy_(torch.from_numpy(sincos_table).float())
        else:
            raise ValueError

        if patch_dropout > 0.:
            self.patch_dropout = PatchDropout(patch_dropout)
        else:
            self.patch_dropout = nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        if not attentional_pool:
            self.attn_pool = None
            pool_dim = width
            self.pool_type = pool_type
        else:
            if isinstance(attentional_pool, str):
                # String modes use a token pooler plus a single-query contrastive pooler.
                self.attn_pool_type = attentional_pool
                self.pool_type = 'none'
                if attentional_pool in ('parallel', 'cascade'):
                    self.attn_pool = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=attn_pooler_queries,
                    )
                    self.attn_pool_contrastive = AttentionalPooler(
                        output_dim,
                        width,
                        n_head=attn_pooler_heads,
                        n_queries=1,
                    )
                else:
                    assert False
            else:
                self.attn_pool_type = ''
                self.pool_type = pool_type
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = None
            pool_dim = output_dim

        self.ln_post = norm_layer(pool_dim)
        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
        """Freeze all parameters, then re-enable grads for the last ``unlocked_groups`` groups.

        Groups, in order: the stem (conv, class/positional embeddings, ln_pre),
        each transformer block except the last, the last block together with
        ``ln_post``, and finally ``proj``. ``freeze_bn_stats`` is accepted but
        not used by this method.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups == 0:
            return

        stem = [
            self.conv1,
            self.class_embedding,
            self.positional_embedding,
            self.ln_pre,
        ]
        tail = [
            self.transformer.resblocks[-1],
            self.ln_post,
        ]
        groups = [stem, *self.transformer.resblocks[:-1], tail, self.proj]

        def _unlock(target):
            # Recurse through nested sequences; leaves are Parameters or modules.
            if isinstance(target, Sequence):
                for member in target:
                    _unlock(member)
                return
            if isinstance(target, torch.nn.Parameter):
                target.requires_grad = True
                return
            for p in target.parameters():
                p.requires_grad = True

        _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        """No-op hook; parameters keep the values assigned in ``__init__``."""
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Toggle gradient checkpointing on the inner transformer."""
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'positional_embedding', 'class_embedding'}

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``x`` into ``(pooled, tokens)`` according to ``self.pool_type``.

        ``'avg'`` averages positions 1.. along dim 1, ``'tok'`` takes position
        0; in both cases ``tokens`` is ``x[:, 1:]``. Any other pool type
        returns ``x`` for both.
        """
        if self.pool_type == 'avg':
            return x[:, 1:].mean(dim=1), x[:, 1:]
        if self.pool_type == 'tok':
            return x[:, 0], x[:, 1:]
        return x, x

    def _embeds(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an image batch into the transformer input sequence."""
        patches = self.conv1(x)
        # Collapse the spatial dims, then move channels last: (B, C, L) -> (B, L, C).
        seq = patches.flatten(2).transpose(1, 2)

        cls_tok = _expand_token(self.class_embedding, seq.shape[0]).to(seq.dtype)
        seq = torch.cat([cls_tok, seq], dim=1)
        seq = seq + self.positional_embedding.to(seq.dtype)

        seq = self.patch_dropout(seq)
        return self.ln_pre(seq)

    def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the configured pooling / final norm and return ``(pooled, tokens)``."""
        if self.attn_pool is None:
            if self.final_ln_after_pool:
                pooled, tokens = self._global_pool(x)
                return self.ln_post(pooled), tokens
            return self._global_pool(self.ln_post(x))

        if self.attn_pool_contrastive is None:
            # Single pooler: pool first, normalise, then the ordinary global pool.
            x = self.attn_pool(x)
            x = self.ln_post(x)
            return self._global_pool(x)

        # Two poolers: normalise first, then pool tokens and the contrastive vector.
        x = self.ln_post(x)
        tokens = self.attn_pool(x)
        if self.attn_pool_type == 'parallel':
            pooled = self.attn_pool_contrastive(x)
        else:
            assert self.attn_pool_type == 'cascade'
            pooled = self.attn_pool_contrastive(tokens)
        return pooled, tokens

    def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        output_fmt: str = 'NCHW',
        output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass that also returns per-block intermediate features.

        Args:
            x: Input image tensor (unpacked as ``B, _, height, width``).
            indices: Block selection passed to the transformer.
            stop_early: Stop running blocks once the last selected one is reached.
            normalize_intermediates: Apply ``ln_post`` to every intermediate.
            intermediates_only: Skip pooling/projection and return only intermediates.
            output_fmt: ``'NCHW'`` reshapes patch tokens to a feature map,
                ``'NLC'`` leaves them as sequences.
            output_extra_tokens: Also return the leading (class) token of each
                intermediate under ``'image_intermediates_prefix'``.

        Returns:
            Dict with ``'image_intermediates'``, optionally
            ``'image_intermediates_prefix'``, and ``'image_features'`` unless
            ``intermediates_only`` is set.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        as_feature_map = output_fmt == 'NCHW'

        B, _, height, width = x.shape
        x = self._embeds(x)
        x, intermediates = self.transformer.forward_intermediates(
            x,
            indices=indices,
            stop_early=stop_early,
        )

        if normalize_intermediates:
            intermediates = [self.ln_post(feat) for feat in intermediates]

        # Exactly one prefix (class) token is prepended by ``_embeds``.
        prefix_tokens = [feat[:, 0:1] for feat in intermediates]
        intermediates = [feat[:, 1:] for feat in intermediates]

        if as_feature_map:
            H, W = height // self.patch_size[0], width // self.patch_size[1]
            intermediates = [
                feat.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                for feat in intermediates
            ]

        output = {'image_intermediates': intermediates}
        if output_extra_tokens:
            output['image_intermediates_prefix'] = prefix_tokens

        if intermediates_only:
            return output

        pooled, _ = self._pool(x)
        if self.proj is not None:
            pooled = pooled @ self.proj
        output['image_features'] = pooled
        return output

    def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
    ):
        """Drop blocks not needed for ``indices``; optionally drop norm and head.

        ``prune_norm`` replaces ``ln_post`` with ``nn.Identity``; ``prune_head``
        sets ``proj`` to None. Returns the resolved indices.
        """
        take_indices = self.transformer.prune_intermediate_layers(indices)
        if prune_norm:
            self.ln_post = nn.Identity()
        if prune_head:
            self.proj = None
        return take_indices

    def forward(self, x: torch.Tensor):
        """Encode images; returns ``pooled`` or ``(pooled, tokens)``."""
        hidden = self.transformer(self._embeds(x))
        pooled, tokens = self._pool(hidden)

        if self.proj is not None:
            pooled = pooled @ self.proj

        if self.output_tokens:
            return pooled, tokens
        return pooled
|
|
|
|
def text_global_pool(
    x: torch.Tensor,
    text: Optional[torch.Tensor] = None,
    pool_type: str = 'argmax',
) -> torch.Tensor:
    """Reduce a sequence of token features to one vector per batch item.

    Args:
        x: Token features, indexed as ``x[batch, position, ...]``.
        text: Token ids; required for ``pool_type='argmax'``, where the
            position of the largest id in each row selects the pooled token.
        pool_type: ``'first'`` (position 0), ``'last'`` (position -1),
            ``'argmax'``, or anything else to return ``x`` unchanged.

    Returns:
        The pooled tensor, or ``x`` itself for unrecognised pool types.

    Raises:
        AssertionError: If ``pool_type='argmax'`` and ``text`` is None.
    """
    if pool_type == 'first':
        return x[:, 0]
    if pool_type == 'last':
        return x[:, -1]
    if pool_type == 'argmax':
        assert text is not None, "text token ids are required for 'argmax' pooling"
        # Build both index tensors on x's device so indexing never needs an
        # implicit host->device copy.
        batch_idx = torch.arange(x.shape[0], device=x.device)
        token_idx = text.argmax(dim=-1).to(x.device)
        return x[batch_idx, token_idx]
    return x
|
|
|
|
class TextTransformer(nn.Module):
    """Text tower: token + positional embeddings, ``Transformer``, pooling, projection.

    With ``embed_cls=True`` a learned class embedding is appended to the end of
    every sequence and pooled from the last position; otherwise pooling follows
    ``pool_type`` via ``text_global_pool``.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        output_dim: Optional[int] = 512,
        embed_cls: bool = False,
        no_causal_mask: bool = False,
        pad_id: int = 0,
        pool_type: str = 'argmax',
        proj_type: str = 'linear',
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_tokens: bool = False,
    ):
        """Create embeddings, the transformer, the final norm and the projection.

        Args:
            context_length: Number of token positions (one more position is
                allocated when ``embed_cls`` is True).
            vocab_size / width: Token embedding table size.
            heads, layers, mlp_ratio, ls_init_value, act_layer, norm_layer:
                Forwarded to ``Transformer``.
            output_dim: Projection width; a falsy value disables projection.
            embed_cls: Append a learned class embedding to each sequence.
            no_causal_mask: Do not register a causal attention mask.
            pad_id: Token id treated as padding by ``build_cls_mask``.
            pool_type: ``'first'``, ``'last'``, ``'argmax'`` or ``'none'``.
            proj_type: ``'none'`` disables projection.
            proj_bias: Use an ``nn.Linear`` projection (with bias) instead of
                a bare weight matrix.
            output_tokens: Make ``forward`` also return the token sequence.
        """
        super().__init__()
        assert pool_type in ('first', 'last', 'argmax', 'none')
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(vocab_size, width)
        if embed_cls:
            self.cls_emb = nn.Parameter(torch.empty(width))
            self.num_pos += 1  # extra position for the appended class token
        else:
            self.cls_emb = None
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.ln_final = norm_layer(width)

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)

        if proj_type == 'none' or not output_dim:
            self.text_projection = None
        elif proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        # The torch.empty parameters above are filled here.
        self.init_parameters()

    def init_parameters(self):
        """Initialise embeddings, attention/MLP weights and the projection."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if self.cls_emb is not None:
            nn.init.normal_(self.cls_emb, std=0.01)

        width = self.transformer.width
        proj_std = (width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                nn.init.normal_(self.text_projection.weight, std=width ** -0.5)
                if self.text_projection.bias is not None:
                    nn.init.zeros_(self.text_projection.bias)
            else:
                nn.init.normal_(self.text_projection, std=width ** -0.5)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the inner transformer."""
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        no_wd = {'positional_embedding'}
        if self.cls_emb is not None:
            no_wd.add('cls_emb')
        return no_wd

    def build_causal_mask(self):
        """Return a ``(num_pos, num_pos)`` additive mask: ``-inf`` strictly
        above the diagonal, zero elsewhere."""
        mask = torch.full((self.num_pos, self.num_pos), float("-inf"))
        mask.triu_(1)
        return mask

    def build_cls_mask(self, text, cast_dtype: torch.dtype):
        """Build an additive mask from the non-pad positions of ``text``.

        The boolean ``text != pad_id`` mask is padded with ``True`` (one column
        on the left, ``seq_len`` rows on top), converted to 0 / ``-inf`` in
        ``cast_dtype`` and repeated ``self.heads`` times along dim 0.
        """
        cls_mask = (text != self.pad_id).unsqueeze(1)
        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
        additive_mask = torch.zeros(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
        additive_mask.masked_fill_(~cls_mask, float("-inf"))
        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
        return additive_mask

    def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed token ids and return ``(x, attn_mask)`` for the transformer."""
        cast_dtype = self.transformer.get_cast_dtype()
        seq_len = text.shape[1]
        x = self.token_embedding(text).to(cast_dtype)
        attn_mask = self.attn_mask
        if self.cls_emb is not None:
            seq_len += 1
            x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
            cls_mask = self.build_cls_mask(text, cast_dtype)
            # The padding mask is only combined when a causal mask exists.
            if attn_mask is not None:
                attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
        return x, attn_mask

    def _pool_and_project(self, x: torch.Tensor, text) -> Tuple[torch.Tensor, torch.Tensor]:
        """Shared tail of ``forward`` / ``forward_intermediates``.

        Returns ``(pooled, tokens)`` where ``pooled`` has been passed through
        ``text_projection`` when one is configured.
        """
        if self.cls_emb is not None:
            # Class token sits at the last position; normalise only the pooled vector.
            pooled = text_global_pool(x, pool_type='last')
            pooled = self.ln_final(pooled)
            tokens = x[:, :-1]
        else:
            x = self.ln_final(x)
            pooled = text_global_pool(x, text, pool_type=self.pool_type)
            tokens = x

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        return pooled, tokens

    def forward_intermediates(
        self,
        text: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        output_fmt: str = 'NLC',
        output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass that also returns per-block intermediate features.

        Args:
            text: Input token ids.
            indices: Block selection passed to the transformer.
            stop_early: Stop running blocks once the last selected one is reached.
            normalize_intermediates: Apply ``ln_final`` to every intermediate.
            intermediates_only: Skip pooling/projection and return only intermediates.
            output_fmt: Must be ``'NLC'``. The default was previously
                ``'NCHW'``, which this method's own assertion rejects.
            output_extra_tokens: With ``embed_cls``, also return the trailing
                class token of each intermediate under
                ``'text_intermediates_suffix'``.

        Returns:
            Dict with ``'text_intermediates'``, optionally
            ``'text_intermediates_suffix'``, and ``'text_features'`` unless
            ``intermediates_only`` is set.
        """
        assert output_fmt in ('NLC',), 'Output format must be NLC.'

        x, attn_mask = self._embeds(text)
        x, intermediates = self.transformer.forward_intermediates(
            x,
            attn_mask=attn_mask,
            indices=indices,
            stop_early=stop_early,
        )

        if normalize_intermediates:
            intermediates = [self.ln_final(xi) for xi in intermediates]

        output = {}

        if self.cls_emb is not None:
            # Separate the appended class token from the real sequence tokens.
            if output_extra_tokens:
                output['text_intermediates_suffix'] = [xi[:, -1:] for xi in intermediates]
            intermediates = [xi[:, :-1] for xi in intermediates]
        output['text_intermediates'] = intermediates

        if intermediates_only:
            return output

        pooled, _ = self._pool_and_project(x, text)
        output['text_features'] = pooled
        return output

    def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
    ):
        """Drop blocks not needed for ``indices``; optionally drop norm and head.

        ``prune_norm`` replaces ``ln_final`` with ``nn.Identity``;
        ``prune_head`` sets ``text_projection`` to None. Returns the resolved
        indices.
        """
        take_indices = self.transformer.prune_intermediate_layers(indices)
        if prune_norm:
            self.ln_final = nn.Identity()
        if prune_head:
            self.text_projection = None
        return take_indices

    def forward(self, text):
        """Encode token ids; returns ``pooled`` or ``(pooled, tokens)``."""
        x, attn_mask = self._embeds(text)
        x = self.transformer(x, attn_mask=attn_mask)
        pooled, tokens = self._pool_and_project(x, text)

        if self.output_tokens:
            return pooled, tokens
        return pooled
|
|
|
|
class MultimodalTransformer(Transformer):
    """Transformer whose self-attention blocks are interleaved with cross-attention.

    ``forward`` runs each self-attention block on the text embeddings (with a
    causal mask) and follows it with a cross-attention block that uses the
    image embeddings as keys and values.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        context_length: int = 77,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_dim: int = 512,
        batch_first: bool = True,
    ):
        """Build the self-attention stack, a parallel cross-attention stack,
        the causal mask, the final norm and the output projection."""
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))
        # torch.empty leaves arbitrary memory in the parameter; give it a
        # defined starting value (same std that init_parameters uses).
        nn.init.normal_(self.text_projection, std=width ** -0.5)

    def init_parameters(self):
        """Re-initialise all self-/cross-attention blocks and the projection.

        Not called by ``__init__``; call it explicitly to apply this scheme.
        This class *is* the transformer, so widths, layer count and block
        lists are read from ``self`` directly.
        """
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block_list in (self.resblocks, self.cross_attn):
            for block in block_list:
                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive mask:
        ``-inf`` strictly above the diagonal, zero elsewhere."""
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward_intermediates(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
    ):
        """Unsupported for this class.

        Raises:
            AssertionError: Always. Raised explicitly so the guard is not
                stripped when Python runs with ``-O``.
        """
        raise AssertionError("Not currently implemented for MultimodalTransformer w/ xattn")

    def forward(self, image_embs, text_embs):
        """Fuse image and text embeddings and project the result.

        Args:
            image_embs: Image token embeddings used as cross-attention keys/values.
            text_embs: Text token embeddings; dim 1 is taken as the sequence length.

        Returns:
            ``ln_final`` of the fused text embeddings, multiplied by
            ``text_projection`` when it is not None.
        """
        seq_len = text_embs.shape[1]
        if not self.batch_first:
            image_embs = image_embs.permute(1, 0, 2)
            text_embs = text_embs.permute(1, 0, 2)

        # The causal mask slice is the same for every layer.
        causal_mask = self.attn_mask[:seq_len, :seq_len]
        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Positional args fill (q_x, k_x, v_x, attn_mask).
                text_embs = checkpoint(
                    resblock, text_embs, None, None, causal_mask, use_reentrant=False)
                text_embs = checkpoint(
                    cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)
            else:
                text_embs = resblock(text_embs, attn_mask=causal_mask)
                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

        if not self.batch_first:
            text_embs = text_embs.permute(1, 0, 2)

        out = self.ln_final(text_embs)
        if self.text_projection is not None:
            out = out @ self.text_projection

        return out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing for both block stacks."""
        self.grad_checkpointing = enable
|
|
|
|
|
|
@dataclass
class CLIPVisionCfg:
    """Configuration consumed by ``_build_vision_tower``."""

    # --- core geometry ---
    # An int selects the VisionTransformer path; a tuple/list selects ModifiedResNet.
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64  # head count is derived from width and head_width
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    # --- VisionTransformer options (forwarded to its constructor) ---
    ls_init_value: Optional[float] = None
    patch_dropout: float = 0.
    attentional_pool: bool = False
    attn_pooler_queries: int = 256
    attn_pooler_heads: int = 8
    no_ln_pre: bool = False
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None   # extra kwargs bound onto the activation layer
    norm_kwargs: Optional[dict] = None  # extra kwargs bound onto the norm layer

    # --- timm backbone options (used only when timm_model_name is set) ---
    timm_model_name: Optional[str] = None
    timm_model_pretrained: bool = False
    timm_pool: str = 'avg'
    timm_proj: str = 'linear'
    timm_proj_bias: bool = False
    timm_drop: float = 0.
    timm_drop_path: Optional[float] = None
|
|
|
|
@dataclass
class CLIPTextCfg:
    """Configuration consumed by ``_build_text_tower``."""

    # --- tokenizer / vocabulary ---
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    # --- TextTransformer options (forwarded to its constructor) ---
    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False
    final_ln_after_pool: bool = False
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'
    output_tokens: bool = False
    # Both default to None, so the annotation must be Optional (matches CLIPVisionCfg).
    act_kwargs: Optional[dict] = None   # extra kwargs bound onto the activation layer
    norm_kwargs: Optional[dict] = None  # extra kwargs bound onto the norm layer

    # --- HuggingFace text encoder options (used only when hf_model_name is set) ---
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'
|
|
|
|
def get_cast_dtype(precision: str):
    """Map a precision name to the dtype to cast to.

    Returns ``torch.bfloat16`` for ``'bf16'``, ``torch.float16`` for
    ``'fp16'`` and ``None`` for any other value.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
|
|
|
|
def get_input_dtype(precision: str):
    """Map a precision name to the dtype inputs should use.

    Returns ``torch.bfloat16`` for ``'bf16'`` / ``'pure_bf16'``,
    ``torch.float16`` for ``'fp16'`` / ``'pure_fp16'`` and ``None`` otherwise.
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
|
|
|
|
def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the vision encoder described by ``vision_cfg``.

    Selection order: a ``TimmModel`` when ``timm_model_name`` is set, a
    ``ModifiedResNet`` when ``layers`` is a tuple/list, otherwise a
    ``VisionTransformer``.

    Args:
        embed_dim: Output embedding width of the tower.
        vision_cfg: A ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (ViT path only).
        cast_dtype: When fp16/bf16, ``LayerNormFp32`` is used as the ViT norm layer.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if vision_cfg.timm_model_name:
        patch_drop = vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None
        return TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=patch_drop,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )

    if isinstance(vision_cfg.layers, (tuple, list)):
        resnet_heads = vision_cfg.width * 32 // vision_cfg.head_width
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )

    # Plain ViT path.
    vit_heads = vision_cfg.width // vision_cfg.head_width

    act_layer = QuickGELU if quick_gelu else nn.GELU
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm
    # Bind any extra constructor kwargs from the config onto the layer factories.
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vit_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        attentional_pool=vision_cfg.attentional_pool,
        attn_pooler_queries=vision_cfg.attn_pooler_queries,
        attn_pooler_heads=vision_cfg.attn_pooler_heads,
        pos_embed_type=vision_cfg.pos_embed_type,
        no_ln_pre=vision_cfg.no_ln_pre,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        pool_type=vision_cfg.pool_type,
        output_tokens=vision_cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
|
|
|
|
def _build_text_tower(
    embed_dim: int,
    text_cfg: CLIPTextCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    Returns an ``HFTextEncoder`` when ``hf_model_name`` is set, otherwise a
    ``TextTransformer``.

    Args:
        embed_dim: Output embedding width of the tower.
        text_cfg: A ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (TextTransformer only).
        cast_dtype: When fp16/bf16, ``LayerNormFp32`` is used as the norm layer.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        return HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )

    act_layer = QuickGELU if quick_gelu else nn.GELU
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm
    # Bind any extra constructor kwargs from the config onto the layer factories.
    if text_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
    if text_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **text_cfg.act_kwargs)

    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        mlp_ratio=text_cfg.mlp_ratio,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        no_causal_mask=text_cfg.no_causal_mask,
        pad_id=text_cfg.pad_id,
        pool_type=text_cfg.pool_type,
        proj_type=text_cfg.proj_type,
        proj_bias=text_cfg.proj_bias,
        output_tokens=text_cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
|
|
|
|
| class CLIP(nn.Module): |
| output_dict: torch.jit.Final[bool] |
|
|
def __init__(
    self,
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    text_cfg: CLIPTextCfg,
    quick_gelu: bool = False,
    init_logit_scale: float = np.log(1 / 0.07),
    init_logit_bias: Optional[float] = None,
    nonscalar_logit_scale: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    output_dict: bool = False,
):
    """Constructor of ``CLIP``: build both towers and the logit scale/bias.

    The text tower is built and then unpacked: its transformer, embeddings,
    final norm, projection, pool type and attention mask are attached
    directly to this module rather than kept as one sub-module.

    Args:
        embed_dim: Shared embedding width of both towers.
        vision_cfg / text_cfg: Tower configurations (dataclass or dict).
        quick_gelu: Forwarded to both tower builders.
        init_logit_scale: Initial value of the ``logit_scale`` parameter.
        init_logit_bias: Initial ``logit_bias``; None disables the bias.
        nonscalar_logit_scale: Store scale/bias with shape ``[1]`` instead of ``[]``.
        cast_dtype: Forwarded to both tower builders.
        output_dict: Stored as ``self.output_dict``.
    """
    super().__init__()
    self.output_dict = output_dict

    self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

    text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
    self.transformer = text_tower.transformer
    self.context_length = text_tower.context_length
    self.vocab_size = text_tower.vocab_size
    self.token_embedding = text_tower.token_embedding
    self.positional_embedding = text_tower.positional_embedding
    self.ln_final = text_tower.ln_final
    self.text_projection = text_tower.text_projection
    self.text_pool_type = text_tower.pool_type
    self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

    scale_shape = [1] if nonscalar_logit_scale else []
    self.logit_scale = nn.Parameter(torch.ones(scale_shape) * init_logit_scale)
    if init_logit_bias is None:
        self.logit_bias = None
    else:
        self.logit_bias = nn.Parameter(torch.ones(scale_shape) * init_logit_bias)
|
|
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    """``CLIP`` method: freeze the vision tower by delegating to ``self.visual.lock``."""
    lock_kwargs = dict(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
    self.visual.lock(**lock_kwargs)
|
|
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """``CLIP`` method: toggle gradient checkpointing on both towers."""
    # The text transformer is attached directly, so its flag is set here.
    self.transformer.grad_checkpointing = enable
    self.visual.set_grad_checkpointing(enable)
|
|
@torch.jit.ignore
def no_weight_decay(self):
    """``CLIP`` method: names of parameters to exclude from weight decay.

    Always contains ``'positional_embedding'``; names reported by
    ``self.visual.no_weight_decay()`` (when that method exists) are added
    with a ``'visual.'`` prefix.
    """
    excluded = {'positional_embedding'}
    if hasattr(self.visual, 'no_weight_decay'):
        excluded.update('visual.' + name for name in self.visual.no_weight_decay())
    return excluded
|
|
def encode_image(self, image, normalize: bool = False):
    """``CLIP`` method: run the vision tower, optionally L2-normalising the last dim."""
    features = self.visual(image)
    if normalize:
        return F.normalize(features, dim=-1)
    return features
|
|
def encode_text(self, text, normalize: bool = False):
    """``CLIP`` method: encode token ids into text features.

    Embeds the tokens, adds the full positional embedding, runs the
    transformer with ``self.attn_mask``, applies ``ln_final``, pools with
    ``text_global_pool`` and finally projects when ``text_projection`` is set.
    With ``normalize=True`` the result is L2-normalised over the last dim.
    """
    cast_dtype = self.transformer.get_cast_dtype()

    hidden = self.token_embedding(text).to(cast_dtype)
    hidden = hidden + self.positional_embedding.to(cast_dtype)
    hidden = self.transformer(hidden, attn_mask=self.attn_mask)
    hidden = self.ln_final(hidden)

    pooled = text_global_pool(hidden, text, self.text_pool_type)
    projection = self.text_projection
    if projection is not None:
        if isinstance(projection, nn.Linear):
            pooled = projection(pooled)
        else:
            pooled = pooled @ projection

    if normalize:
        return F.normalize(pooled, dim=-1)
    return pooled
|
|
def get_logits(self, image, text):
    """``CLIP`` method: similarity logits between images and texts.

    Both inputs are encoded with ``normalize=True``. Returns
    ``(image_logits, text_logits)`` where ``text_logits`` is the transpose of
    ``image_logits``; ``logit_bias`` is added in place when it is not None.
    """
    img_feats = self.encode_image(image, normalize=True)
    txt_feats = self.encode_text(text, normalize=True)
    scale = self.logit_scale.exp()
    image_logits = scale * img_feats @ txt_feats.T
    if self.logit_bias is not None:
        image_logits += self.logit_bias
    return image_logits, image_logits.T
|
|
def forward_intermediates(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        image_indices: Optional[Union[int, List[int]]] = None,
        text_indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize: bool = True,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        image_output_fmt: str = 'NCHW',
        image_output_extra_tokens: bool = False,
        text_output_fmt: str = 'NLC',
        text_output_extra_tokens: bool = False,
        output_logits: bool = False,
        output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """Run either or both towers and collect intermediate block outputs.

    Args:
        image: Input image tensor; the image tower is skipped when None.
        text: Input text (token) tensor; the text tower is skipped when None.
        image_indices: For the image tower: take last n blocks if int, all if None,
            select matching indices if sequence.
        text_indices: Same selection rule, applied to the text transformer.
        stop_early: Stop iterating over blocks when the last desired intermediate is hit
            (forwarded to the image tower only).
        normalize: L2-normalize the final image / text features.
        normalize_intermediates: Apply the final norm layer to all intermediates.
        intermediates_only: Only return intermediate features, not final features.
            Forces ``normalize`` and ``output_logits`` off.
        image_output_fmt: Shape of intermediate image feature outputs.
        image_output_extra_tokens: Return both prefix and spatial intermediate tokens.
        text_output_fmt: Ignored for this model.
        text_output_extra_tokens: Ignored for this model.
        output_logits: Include ``image_logits`` / ``text_logits`` in the output
            (requires both ``image`` and ``text``).
        output_logit_scale_bias: Include ``logit_scale`` (and ``logit_bias`` when the
            model has one) in the output.

    Returns:
        Dict holding whatever the image tower's ``forward_intermediates`` returned,
        plus ``text_intermediates`` and (unless ``intermediates_only``) ``text_features``
        when text is given, plus the optional logit entries described above.
    """
    if intermediates_only:
        # Without final features there is nothing to normalize and no logits to form.
        normalize = False
        output_logits = False
    if output_logits:
        assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

    results = {}

    if image is not None:
        visual_out = self.visual.forward_intermediates(
            image,
            indices=image_indices,
            stop_early=stop_early,
            normalize_intermediates=normalize_intermediates,
            intermediates_only=intermediates_only,
            output_fmt=image_output_fmt,
            output_extra_tokens=image_output_extra_tokens,
        )
        if normalize and "image_features" in visual_out:
            visual_out["image_features"] = F.normalize(visual_out["image_features"], dim=-1)
        results.update(visual_out)

    if text is not None:
        dtype = self.transformer.get_cast_dtype()
        embedded = self.token_embedding(text).to(dtype)
        embedded = embedded + self.positional_embedding.to(dtype)
        hidden, text_intermediates = self.transformer.forward_intermediates(
            embedded,
            attn_mask=self.attn_mask,
            indices=text_indices,
        )
        if normalize_intermediates:
            text_intermediates = [self.ln_final(block_out) for block_out in text_intermediates]

        # The text output format / extra-token options are not used by this model.
        results["text_intermediates"] = text_intermediates

        if not intermediates_only:
            pooled = self.ln_final(hidden)
            pooled = text_global_pool(pooled, text, self.text_pool_type)
            projection = self.text_projection
            if projection is not None:
                # The projection is either a Linear module or a raw weight matrix.
                if isinstance(projection, nn.Linear):
                    pooled = projection(pooled)
                else:
                    pooled = pooled @ projection
            if normalize:
                pooled = F.normalize(pooled, dim=-1)
            results["text_features"] = pooled

    need_scale = output_logits or output_logit_scale_bias
    logit_scale_exp = self.logit_scale.exp() if need_scale else None

    if output_logits:
        image_logits = logit_scale_exp * results["image_features"] @ results["text_features"].T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        results["image_logits"] = image_logits
        results["text_logits"] = image_logits.T

    if output_logit_scale_bias:
        results["logit_scale"] = logit_scale_exp
        if self.logit_bias is not None:
            results['logit_bias'] = self.logit_bias

    return results
|
|
def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
):
    """Encode whichever inputs are given and return them with the logit scale.

    Features are L2-normalized; a missing input yields ``None`` in its slot.
    Depending on ``self.output_dict`` the result is a dict
    (``image_features``, ``text_features``, ``logit_scale`` and, if the model has
    one, ``logit_bias``) or a tuple in that same order.
    """
    image_features = None if image is None else self.encode_image(image, normalize=True)
    text_features = None if text is None else self.encode_text(text, normalize=True)
    scale = self.logit_scale.exp()
    bias = self.logit_bias

    if self.output_dict:
        result = {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": scale,
        }
        if bias is not None:
            result['logit_bias'] = bias
        return result

    if bias is None:
        return image_features, text_features, scale
    return image_features, text_features, scale, bias
|
|
|
|
class CustomTextCLIP(nn.Module):
    """Dual-tower CLIP model whose text encoder lives in its own ``self.text`` module.

    The image tower is built by ``_build_vision_tower`` and the text tower by
    ``_build_text_tower``; a learnable ``logit_scale`` (stored in log space, it is
    exponentiated before use) and an optional ``logit_bias`` sit on top.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers and the logit scale / bias parameters.

        Args:
            embed_dim: Forwarded to both tower builders.
            vision_cfg: Configuration for the image tower.
            text_cfg: Configuration for the text tower.
            quick_gelu: Forwarded to both tower builders.
            init_logit_scale: Initial value of the ``logit_scale`` parameter.
            init_logit_bias: Initial value of ``logit_bias``; no bias parameter is
                created when None.
            nonscalar_logit_scale: Give scale / bias shape ``[1]`` instead of ``[]``.
            cast_dtype: Forwarded to both tower builders.
            output_dict: Make ``forward`` return a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to the image tower's ``lock`` with the given options."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate to the text tower's ``lock`` with the given options."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the towers' no-weight-decay parameter names, prefixed by tower.

        Names reported by ``self.visual.no_weight_decay()`` get a ``visual.`` prefix
        and those reported by ``self.text.no_weight_decay()`` a ``text.`` prefix.
        A tower without that method contributes nothing.
        """
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        if hasattr(self.text, 'no_weight_decay'):
            # Query the text tower itself; asking the visual tower here would
            # re-emit vision names under the wrong prefix and drop the text names.
            for n in self.text.no_weight_decay():
                no_wd.add('text.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalize along the last dim."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Return ``(image_logits, text_logits)``.

        Logits are the scaled product of the normalized image and text features
        (plus ``logit_bias`` if present); ``text_logits`` is the transpose.
        """
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging the outputs of each tower's ``forward_intermediates`` call,
            plus ``image_logits`` / ``text_logits`` and ``logit_scale`` / ``logit_bias``
            when requested.
        """
        output = {}
        if intermediates_only:
            # intermediates only disables final feature normalization, and include logits
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        # Only exponentiate the scale when some requested output needs it.
        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode the given inputs (normalized) and return them with the logit scale.

        Returns a dict when ``self.output_dict`` is set, otherwise a tuple
        ``(image_features, text_features, logit_scale[, logit_bias])``; a missing
        input yields ``None`` in its slot.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
|
|
|
|
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    Walks every submodule via ``model.apply`` and casts, in place:
      * weight / bias of Conv1d, Conv2d and Linear layers,
      * the projection weights and biases of attention modules,
      * raw-tensor ``text_projection`` / ``proj`` attributes of the CLIP,
        text-transformer and vision-transformer modules.

    Args:
        model: Module whose parameters are cast in place.
        dtype: Target dtype.
    """
    attention_attrs = (
        *(f"{s}_proj_weight" for s in ("in", "q", "k", "v")),
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _cast_raw_projection(module, name):
        # Projections may be stored either as a bare tensor/parameter or as an
        # nn.Linear. A Linear has no `.data` and is already handled when
        # `apply` visits it, so only bare tensors are cast here.
        tensor = getattr(module, name, None)
        if isinstance(tensor, torch.Tensor):
            tensor.data = tensor.data.to(dtype)

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in attention_attrs:
                # Default to None: an attention implementation is not guaranteed
                # to define every attribute nn.MultiheadAttention exposes.
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            _cast_raw_projection(l, "text_projection")

        if isinstance(l, VisionTransformer):
            _cast_raw_projection(l, "proj")

    model.apply(_convert_weights)


# Backwards-compatible alias.
convert_weights_to_fp16 = convert_weights_to_lp
|
|
|
|
| |
def convert_to_custom_text_state_dict(state_dict: dict):
    """Move top-level text-tower keys under a ``text.`` prefix.

    A state dict is treated as old-format when it has a top-level
    ``text_projection`` key. In that case a new dict is returned in which every
    key starting with one of the text-tower prefixes gains a ``text.`` prefix;
    all other keys are copied unchanged. Otherwise the input dict itself is
    returned.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + name if name.startswith(text_prefixes) else name): value
        for name, value in state_dict.items()
    }
|
|
|
|
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Reconstruct a ``CLIP`` model whose architecture is inferred from a state dict.

    The vision tower is treated as a ViT when ``visual.proj`` is present and as
    a ResNet-style tower otherwise; all sizes are read from tensor shapes and
    key counts. The metadata keys ``input_resolution``, ``context_length`` and
    ``vocab_size`` are popped from ``state_dict`` (the caller's dict is
    modified) before the weights are loaded.

    Returns:
        The model with fp16-converted weights, loaded and switched to eval mode.
    """
    if "visual.proj" in state_dict:
        patch_weight = state_dict["visual.conv1.weight"]
        vision_width = patch_weight.shape[0]
        vision_layers = sum(
            1 for key in state_dict
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_weight.shape[-1]
        # One position is subtracted before taking the square root of the count.
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((num_positions - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        vision_layers = tuple(
            len({key.split(".")[2] for key in state_dict if key.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    text_block_ids = {
        key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")
    }

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=state_dict["positional_embedding"].shape[0],
        vocab_size=state_dict["token_embedding.weight"].shape[0],
        width=transformer_width,
        heads=transformer_width // 64,
        layers=len(text_block_ids),
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,
        cast_dtype=cast_dtype,
    )

    # Drop metadata entries so they do not reach load_state_dict.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)
    convert_weights_to_fp16(model)
    model.load_state_dict(state_dict)
    return model.eval()
|
|
|
|
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with dummy inputs.

    The model is put in eval mode, traced with an all-ones image batch and an
    all-zeros int token batch, and ``visual.image_size`` is re-assigned on the
    traced module before it is returned.
    """
    model.eval()
    side = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, side, side), device=device)
    dummy_tokens = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    method_inputs = {
        'forward': (dummy_images, dummy_tokens),
        'encode_text': (dummy_tokens,),
        'encode_image': (dummy_images,),
    }
    model = torch.jit.trace_module(model, inputs=method_inputs)
    model.visual.image_size = side
    return model
|
|
|
|
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Interpolate ``visual.positional_embedding`` in ``state_dict`` to the model's grid.

    Does nothing when the key is absent, when ``model.visual`` has no
    ``grid_size``, or when the sequence length already matches. Otherwise the
    first row (a single prefix token is assumed) is kept untouched, the
    remaining rows are reshaped to a square grid, resized to the model's grid
    and written back into ``state_dict`` in place.
    """
    src_embed = state_dict.get('visual.positional_embedding', None)
    if src_embed is None or not hasattr(model.visual, 'grid_size'):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    num_prefix = 1  # exactly one leading non-spatial token is assumed
    num_target_cells = target_grid[0] * target_grid[1]
    if num_target_cells + num_prefix == src_embed.shape[0]:
        return

    prefix_embed = src_embed[:num_prefix]
    grid_embed = src_embed[num_prefix:]
    source_grid = to_2tuple(int(math.sqrt(len(grid_embed))))

    logging.info('Resizing position embedding grid-size from %s to %s', source_grid, target_grid)
    # (cells, dim) -> (1, dim, H, W), the layout F.interpolate expects.
    as_image = grid_embed.reshape(1, source_grid[0], source_grid[1], -1).permute(0, 3, 1, 2)
    as_image = F.interpolate(
        as_image,
        size=target_grid,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # Back to (cells, dim).
    grid_embed = as_image.permute(0, 2, 3, 1).reshape(1, num_target_cells, -1)[0]
    state_dict['visual.positional_embedding'] = torch.cat([prefix_embed, grid_embed], dim=0)
|
|
|
|
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Interpolate ``positional_embedding`` in ``state_dict`` to the model's length.

    The target length is read from ``model.positional_embedding`` or, failing
    that, ``model.text.positional_embedding``. Does nothing when the key is
    absent from ``state_dict`` or the lengths already match; the embedding
    width must be unchanged. The resized tensor is written back in place.
    """
    src_embed = state_dict.get('positional_embedding', None)
    if src_embed is None:
        return

    # The embedding lives either on the model itself or on its text tower.
    dst_embed = getattr(model, 'positional_embedding', None)
    if dst_embed is None:
        dst_embed = getattr(model.text, 'positional_embedding', None)

    src_len, src_width = src_embed.shape[0], src_embed.shape[1]
    dst_len, dst_width = dst_embed.shape[0], dst_embed.shape[1]
    assert src_width == dst_width, 'text pos_embed width changed!'
    if src_len == dst_len:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', src_len, dst_len)
    # (len, width) -> (1, width, len) so interpolation runs along the position axis.
    channels_first = src_embed.reshape(1, src_len, src_width).permute(0, 2, 1)
    resized = F.interpolate(
        channels_first,
        size=dst_len,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    state_dict['positional_embedding'] = resized.permute(0, 2, 1)[0]
|
|
|
|
def get_model_preprocess_cfg(model):
    """Return the preprocessing config stored on a model's image tower.

    Looks at ``model.visual`` when present, otherwise at ``model`` itself. If
    that module has a non-empty ``preprocess_cfg`` it is returned as-is.
    Otherwise the config is assembled from the legacy attributes
    ``image_size`` / ``image_mean`` / ``image_std`` (stored under ``size`` /
    ``mean`` / ``std``); any attribute that is missing or None is left out.
    """
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # Fall back to the separate legacy attributes. Each lookup defaults to
        # None so a module lacking one of them (image_size included) simply
        # omits that key instead of raising AttributeError.
        legacy_attrs = (
            ('size', 'image_size'),
            ('mean', 'image_mean'),
            ('std', 'image_std'),
        )
        for key, attr in legacy_attrs:
            value = getattr(module, attr, None)
            if value is not None:
                preprocess_cfg[key] = value
    return preprocess_cfg
|
|
|
|
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Store a preprocessing config on a model's image tower.

    Writes to ``model.visual`` when present, otherwise to ``model`` itself:
    ``image_mean`` and ``image_std`` come from the ``mean`` / ``std`` entries
    (both required), and ``preprocess_cfg`` receives a deep copy of the dict.
    """
    target = getattr(model, 'visual', model)
    for attr_name, cfg_key in (('image_mean', 'mean'), ('image_std', 'std')):
        setattr(target, attr_name, preprocess_cfg[cfg_key])
    # Deep copy so later edits to the caller's dict do not leak into the model.
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
|
|
|
|
def get_model_tokenize_cfg(model):
    """Collect tokenizer-related settings from a model's text tower.

    Reads ``context_length`` and ``vocab_size`` from ``model.text`` when
    present, otherwise from ``model`` itself, and returns a dict containing
    only the ones that exist and are not None.
    """
    text_module = getattr(model, 'text', model)
    lookups = (
        (key, getattr(text_module, key, None))
        for key in ('context_length', 'vocab_size')
    )
    return {key: value for key, value in lookups if value is not None}
|
|
|
|
|
|
# Optional Hugging Face Hub support: when the package is importable,
# `hf_hub_download` is pre-bound with this library's name and version;
# otherwise it is None and `_has_hf_hub` is False.
try:
    # Imported here so this block does not rely on `partial` having been
    # imported elsewhere; a NameError would not be caught by the
    # `except ImportError` below.
    from functools import partial

    from huggingface_hub import hf_hub_download
    # NOTE(review): `__version__` is expected to be defined earlier in the
    # file; it is not visible from this section -- confirm.
    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False
|
|
|
|
def _pcfg(url='', hf_hub='', **kwargs):
    """Build a pretrained-weights entry with the OpenAI dataset mean / std defaults.

    Defaults: bicubic interpolation, 'shortest' resize mode. Any keyword
    argument overrides the matching default or adds an extra field.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=OPENAI_DATASET_MEAN,
        std=OPENAI_DATASET_STD,
        interpolation='bicubic',
        resize_mode='shortest',
    )
    cfg.update(kwargs)
    return cfg
|
|
|
|
def _slpcfg(url='', hf_hub='', **kwargs):
    """Build a pretrained-weights entry with the Inception mean / std defaults.

    Defaults: bicubic interpolation, 'squash' resize mode. Any keyword
    argument overrides the matching default or adds an extra field.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=INCEPTION_MEAN,
        std=INCEPTION_STD,
        interpolation='bicubic',
        resize_mode='squash',
    )
    cfg.update(kwargs)
    return cfg
|
|
|
|
def _apcfg(url='', hf_hub='', **kwargs):
    """Build a pretrained-weights entry with the ImageNet mean / std defaults.

    Defaults: bilinear interpolation, 'squash' resize mode. Any keyword
    argument overrides the matching default or adds an extra field.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
        interpolation='bilinear',
        resize_mode='squash',
    )
    cfg.update(kwargs)
    return cfg
|
|
|
|
| def _mccfg(url='', hf_hub='', **kwargs): |
| |
| return { |
| 'url': url, |
| 'hf_hub': hf_hub, |
| 'mean': (0., 0., 0.), |
| 'std': (1., 1., 1.), |
| 'interpolation': 'bilinear', |
| 'resize_mode': 'shortest', |
| **kwargs, |
| } |
|
|
|
|
|
|
# ResNet image towers. Every entry below sets quick_gelu=True.
_RN50 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
        hf_hub="timm/resnet50_clip.openai/",
        quick_gelu=True,
    ),
    "yfcc15m": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
        hf_hub="timm/resnet50_clip.yfcc15m/",
        quick_gelu=True,
    ),
    "cc12m": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
        hf_hub="timm/resnet50_clip.cc12m/",
        quick_gelu=True,
    ),
}

_RN101 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
        hf_hub="timm/resnet101_clip.openai/",
        quick_gelu=True,
    ),
    "yfcc15m": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
        hf_hub="timm/resnet101_clip.yfcc15m/",
        quick_gelu=True,
    ),
}

_RN50x4 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
        hf_hub="timm/resnet50x4_clip.openai/",
        quick_gelu=True,
    ),
}

_RN50x16 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
        hf_hub="timm/resnet50x16_clip.openai/",
        quick_gelu=True,
    ),
}

_RN50x64 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
        hf_hub="timm/resnet50x64_clip.openai/",
        quick_gelu=True,
    ),
}
|
|
_VITB32 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        hf_hub="timm/vit_base_patch32_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints (quick_gelu=True)
    "laion400m_e31": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/",
        quick_gelu=True,
    ),
    "laion400m_e32": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/",
        quick_gelu=True,
    ),
    # LAION-2B checkpoints
    "laion2b_e16": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
        hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/",
    ),
    "laion2b_s34b_b79k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
    # DataComp-XL
    "datacomp_xl_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
    # DataComp-M / CommonPool-M
    "datacomp_m_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
    "commonpool_m_clip_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
    "commonpool_m_laion_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
    "commonpool_m_image_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
    "commonpool_m_text_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
    "commonpool_m_basic_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
    "commonpool_m_s128m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
    # DataComp-S / CommonPool-S
    "datacomp_s_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
    "commonpool_s_clip_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
    "commonpool_s_laion_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
    "commonpool_s_image_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
    "commonpool_s_text_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
    "commonpool_s_basic_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
    "commonpool_s_s13m_b4k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
    # MetaCLIP checkpoints (quick_gelu=True)
    "metaclip_400m": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt",
        hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    "metaclip_fullcc": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt",
        hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
}

_VITB32_256 = {
    "datacomp_s34b_b86k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
}
|
|
_VITB16 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
        hf_hub="timm/vit_base_patch16_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints
    "laion400m_e31": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/",
    ),
    "laion400m_e32": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/",
    ),
    # LAION-2B checkpoint
    "laion2b_s34b_b88k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
    # DataComp-XL
    "datacomp_xl_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
    # DataComp-L / CommonPool-L
    "datacomp_l_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
    "commonpool_l_clip_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
    "commonpool_l_laion_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
    "commonpool_l_image_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
    "commonpool_l_text_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
    "commonpool_l_basic_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
    "commonpool_l_s1b_b8k": _pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
    # DFN checkpoint (quick_gelu=True)
    "dfn2b": _pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-B-16/',
        quick_gelu=True,
    ),
    # MetaCLIP checkpoints (quick_gelu=True)
    "metaclip_400m": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt",
        hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    "metaclip_fullcc": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt",
        hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
}

_VITB16_PLUS_240 = {
    "laion400m_e31": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
    ),
    # NOTE(review): the hf_hub id below ends in `laion400m_e31` although this is
    # the e32 entry (its url points at the e32 file). Possibly a copy-paste
    # slip -- confirm the intended hub repo before changing it.
    "laion400m_e32": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
    ),
}
|
|
_VITL14 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
        hf_hub="timm/vit_large_patch14_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints
    "laion400m_e31": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt",
        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/",
    ),
    "laion400m_e32": _pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt",
        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/",
    ),
    # LAION-2B checkpoint; overrides the default mean / std with the Inception values
    "laion2b_s32b_b82k": _pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=INCEPTION_MEAN,
        std=INCEPTION_STD,
    ),
    # DataComp-XL / CommonPool-XL
    "datacomp_xl_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
    "commonpool_xl_clip_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
    "commonpool_xl_laion_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
    "commonpool_xl_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
    # MetaCLIP checkpoints (quick_gelu=True)
    "metaclip_400m": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt",
        hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    "metaclip_fullcc": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt",
        hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
    # DFN checkpoints: the first sets quick_gelu=True, the second does not
    "dfn2b": _pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-L-14/',
        quick_gelu=True,
    ),
    "dfn2b_s39b": _pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/',
    ),
}

_VITL14_336 = {
    "openai": _pcfg(
        url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
        hf_hub="timm/vit_large_patch14_clip_336.openai/",
        quick_gelu=True,
    ),
}
|
|
_VITH14 = {
    # LAION-2B checkpoint
    "laion2b_s32b_b79k": _pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
    # MetaCLIP checkpoints
    "metaclip_fullcc": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt",
        hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
    # Unlike the other MetaCLIP entries in this file, this one does not set quick_gelu.
    "metaclip_altogether": _pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt",
        hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/",
    ),
    # DFN checkpoint ('squash' resize mode)
    "dfn5b": _pcfg(
        hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
        quick_gelu=True,
        interpolation="bicubic",
        resize_mode="squash",
    ),
}

_VITH14_378 = {
    # DFN checkpoint ('squash' resize mode)
    "dfn5b": _pcfg(
        hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
        quick_gelu=True,
        interpolation="bicubic",
        resize_mode="squash",
    ),
}

_VITg14 = {
    "laion2b_s12b_b42k": _pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
    "laion2b_s34b_b88k": _pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
}

_VITbigG14 = {
    # LAION-2B checkpoint
    "laion2b_s39b_b160k": _pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
    # MetaCLIP checkpoint (quick_gelu=True)
    "metaclip_fullcc": _pcfg(
        url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',
        hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
}
|
|
# RoBERTa / XLM-RoBERTa text-tower variants
_robertaViTB32 = {
    "laion2b_s12b_b32k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
}

_xlmRobertaBaseViTB32 = {
    "laion5b_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
}

_xlmRobertaLargeFrozenViTH14 = {
    "frozen_laion5b_s13b_b90k": _pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
}

# ConvNeXt image-tower variants
_convnext_base = {
    "laion400m_s13b_b51k": _pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
}

_convnext_base_w = {
    "laion2b_s13b_b82k": _pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    "laion2b_s13b_b82k_augreg": _pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    "laion_aesthetic_s13b_b82k": _pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
}

_convnext_base_w_320 = {
    "laion_aesthetic_s13b_b82k": _pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    "laion_aesthetic_s13b_b82k_augreg": _pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
}

_convnext_large_d = {
    "laion2b_s26b_b102k_augreg": _pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
}

_convnext_large_d_320 = {
    "laion2b_s29b_b131k_ft": _pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
    "laion2b_s29b_b131k_ft_soup": _pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
}

_convnext_xxlarge = {
    "laion2b_s34b_b82k_augreg": _pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
    "laion2b_s34b_b82k_augreg_rewind": _pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
    "laion2b_s34b_b82k_augreg_soup": _pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
}

# CoCa variants
_coca_VITB32 = {
    "laion2b_s13b_b90k": _pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
    "mscoco_finetuned_laion2b_s13b_b90k": _pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/'),
}

_coca_VITL14 = {
    "laion2b_s13b_b90k": _pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
    "mscoco_finetuned_laion2b_s13b_b90k": _pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/'),
}
|
|
|
|
# Registry of pretrained weights: model architecture name -> {pretrain tag -> config dict}.
# Tags are looked up through `_clean_tag`, so they are stored lower-case with underscores.
# The config helpers (`_pcfg`, `_slpcfg`, `_apcfg`, `_mccfg`) are defined elsewhere in this file.
_PRETRAINED = {
    # ResNet image towers
    "RN50": _RN50,
    "RN101": _RN101,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,

    # ViT image towers
    "ViT-B-32": _VITB32,
    "ViT-B-32-256": _VITB32_256,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-H-14-378": _VITH14_378,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,

    # RoBERTa / XLM-RoBERTa text towers
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,

    # ConvNeXt image towers
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d": _convnext_large_d,
    "convnext_large_d_320": _convnext_large_d_320,
    "convnext_xxlarge": _convnext_xxlarge,

    # CoCa
    "coca_ViT-B-32": _coca_VITB32,
    "coca_ViT-L-14": _coca_VITL14,

    # EVA (weights hosted under the `timm` hub organisation)
    "EVA01-g-14": dict(
        laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
    ),
    "EVA01-g-14-plus": dict(
        merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
    ),
    "EVA02-B-16": dict(
        merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
    ),
    "EVA02-L-14": dict(
        merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
    ),
    "EVA02-L-14-336": dict(
        merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
    ),
    "EVA02-E-14": dict(
        laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
    ),
    "EVA02-E-14-plus": dict(
        laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
    ),

    # SigLIP (all entries use the `webli` tag)
    "ViT-B-16-SigLIP": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
    ),
    "ViT-B-16-SigLIP-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
    ),
    "ViT-B-16-SigLIP-i18n-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
    ),
    "ViT-B-16-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
    ),
    "ViT-B-16-SigLIP-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
    ),
    "ViT-L-16-SigLIP-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
    ),
    "ViT-L-16-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
    ),
    "ViT-SO400M-14-SigLIP": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
    ),
    "ViT-SO400M-16-SigLIP-i18n-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),
    ),
    # NOTE(review): this -378 entry points at the same -384 hub repo as the entry below;
    # presumably the 384 weights are reused with a different image size — confirm.
    "ViT-SO400M-14-SigLIP-378": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
    ),
    "ViT-SO400M-14-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
    ),

    # SigLIP2
    "ViT-B-32-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'),
    ),
    "ViT-B-16-SigLIP2": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'),
    ),
    "ViT-B-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'),
    ),
    "ViT-B-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'),
    ),
    "ViT-B-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'),
    ),
    "ViT-L-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'),
    ),
    "ViT-L-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'),
    ),
    "ViT-L-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'),
    ),
    "ViT-SO400M-14-SigLIP2": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'),
    ),
    "ViT-SO400M-14-SigLIP2-378": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'),
    ),
    "ViT-SO400M-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'),
    ),
    "ViT-SO400M-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'),
    ),
    "ViT-SO400M-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'),
    ),
    "ViT-gopt-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'),
    ),
    "ViT-gopt-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'),
    ),

    # CLIPA
    "ViT-L-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
    ),
    "ViT-L-14-CLIPA-336": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
    ),
    "ViT-H-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
    ),
    "ViT-H-14-CLIPA-336": dict(
        laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
    ),
    "ViT-bigG-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
    ),
    "ViT-bigG-14-CLIPA-336": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
    ),

    # NLLB-CLIP
    "nllb-clip-base": dict(
        v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
    ),
    "nllb-clip-large": dict(
        v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
    ),

    # NLLB-CLIP, SigLIP variants
    "nllb-clip-base-siglip": dict(
        v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
    ),
    "nllb-clip-large-siglip": dict(
        v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
    ),

    # MobileCLIP
    "MobileCLIP-S1": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
    "MobileCLIP-S2": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
    "MobileCLIP-B": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
        datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
    ),

    # ViTamin (hub ids name an explicit weight file, `pytorch_model.bin`)
    "ViTamin-S": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
    ),
    "ViTamin-S-LTT": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
    ),
    "ViTamin-B": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
    ),
    "ViTamin-B-LTT": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
    ),
    "ViTamin-L": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
    ),
    "ViTamin-L-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
    ),
    "ViTamin-L-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
    ),
    "ViTamin-L-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
    ),
    "ViTamin-L2": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
    ),
    "ViTamin-L2-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
    ),
    "ViTamin-L2-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
    ),
    "ViTamin-L2-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
    ),
    "ViTamin-XL-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
    ),
    "ViTamin-XL-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
    ),
    "ViTamin-XL-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
    ),
}
|
|
# Register a '<model>-quickgelu' alias for every model that has at least one pretrain
# tag whose config sets a truthy 'quick_gelu' entry. Aliased configs are deep copies,
# so editing one side does not affect the other.
_PRETRAINED_quickgelu = {}
for k, v in _PRETRAINED.items():
    quick_gelu_tags = {
        tk: copy.deepcopy(tv)
        for tk, tv in v.items()
        if tv.get('quick_gelu', False)
    }
    if quick_gelu_tags:
        _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags
_PRETRAINED.update(_PRETRAINED_quickgelu)
|
|
def _clean_tag(tag: str):
    """Normalize a pretrain tag: lower-case it and turn every '-' into '_'."""
    lowered = tag.lower()
    return lowered.replace('-', '_')
|
|
|
|
def list_pretrained(as_str: bool = False):
    """List every registered (model, pretrain tag) combination.

    Returns a list of ``(model_name, pretrain_tag)`` tuples by default, or a list of
    ``'name:tag'`` strings when ``as_str`` is true. Order follows `_PRETRAINED`.
    """
    entries = []
    for model_name, tags in _PRETRAINED.items():
        for tag_name in tags.keys():
            if as_str:
                entries.append(':'.join((model_name, tag_name)))
            else:
                entries.append((model_name, tag_name))
    return entries
|
|
|
|
def list_pretrained_models_by_tag(tag: str):
    """Return the names of all models that have the given pretrain tag.

    The tag is normalized with `_clean_tag` before the lookup.
    """
    wanted = _clean_tag(tag)
    return [name for name, tags in _PRETRAINED.items() if wanted in tags]
|
|
|
|
def list_pretrained_tags_by_model(model: str):
    """Return a new list of pretrain tags for `model`; empty if the model is unknown."""
    if model not in _PRETRAINED:
        return []
    return list(_PRETRAINED[model].keys())
|
|
|
|
def is_pretrained_cfg(model: str, tag: str):
    """Return True when `model` is registered and has the (normalized) pretrain `tag`."""
    # Short-circuit keeps `_clean_tag` from running for unknown models.
    return model in _PRETRAINED and _clean_tag(tag) in _PRETRAINED[model]
|
|
|
|
def get_pretrained_cfg(model: str, tag: str):
    """Return the pretrained config for (`model`, `tag`), or ``{}`` when either is unknown.

    The tag is normalized with `_clean_tag`; it is not touched when the model is unknown.
    """
    try:
        tags_for_model = _PRETRAINED[model]
    except KeyError:
        return {}
    return tags_for_model.get(_clean_tag(tag), {})
|
|
|
|
def get_pretrained_url(model: str, tag: str):
    """Return the 'url' entry of the pretrained config for (`model`, `tag`), or ''."""
    normalized_tag = _clean_tag(tag)
    pretrained_cfg = get_pretrained_cfg(model, normalized_tag)
    return pretrained_cfg.get('url', '')
|
|
|
|
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at `path`, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained_from_url(
        url: str,
        cache_dir: Optional[str] = None,
):
    """Download a checkpoint from `url` into `cache_dir` and return the local path.

    Args:
        url: Source URL; its basename becomes the cached file name.
        cache_dir: Target directory, created if missing. Defaults to ``~/.cache/clip``.

    Returns:
        Path of the cached file. An existing file is reused when no checksum can be
        derived from the URL, or when its SHA256 digest starts with the expected value.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or if the
            downloaded file does not match the expected checksum.
    """
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    # The expected checksum (possibly only a prefix) is embedded in the URL itself:
    # second-to-last path segment for 'openaipublic' URLs, last '-' separated token
    # of the file stem for 'mlfoundations' URLs. Other URLs are not verified.
    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            return download_target
        if _sha256_of_file(download_target).startswith(expected_sha256):
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length is optional; tqdm accepts total=None for an unknown size.
        content_length = source.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not _sha256_of_file(download_target).startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
|
|
|
|
def has_hf_hub(necessary=False):
    """Return the module-level `_has_hf_hub` flag.

    Raises:
        RuntimeError: if `necessary` is true and the flag is false.
    """
    if necessary and not _has_hf_hub:
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub
|
|
|
|
def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Yield candidate safetensors file names for `filename`.

    Use case:
        When downloading a model from the Hugging Face Hub, a .safetensors file is
        tried first and used if it exists.

    The default weights name maps to the default safetensors name; any other name
    ending in ".bin" or ".pth" gets that 4-character extension swapped for
    ".safetensors". Nothing is yielded otherwise.
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    elif filename.endswith((".bin", ".pth")):
        stem = filename[:-4]
        yield stem + ".safetensors"
|
|
|
|
def download_pretrained_from_hf(
        model_id: str,
        filename: Optional[str] = None,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
):
    """Download a file from a Hugging Face Hub repo and return its cached local path.

    Args:
        model_id: Hub repo id.
        filename: File to fetch; defaults to `HF_WEIGHTS_NAME`.
        revision: Optional revision forwarded to `hf_hub_download`.
        cache_dir: Optional cache directory forwarded to `hf_hub_download`.

    Raises:
        RuntimeError: from `has_hf_hub(True)` when the hub package is unavailable.
        FileNotFoundError: when the requested file cannot be downloaded; the
            underlying error is attached as ``__cause__``.
    """
    has_hf_hub(True)

    filename = filename or HF_WEIGHTS_NAME

    # Prefer a safetensors variant of the requested file when safetensors is available.
    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                return hf_hub_download(
                    repo_id=model_id,
                    filename=safe_filename,
                    revision=revision,
                    cache_dir=cache_dir,
                )
            except Exception:
                # Deliberate best effort: a missing safetensors variant is not an
                # error, so fall through to the originally requested file.
                continue

    try:
        return hf_hub_download(
            repo_id=model_id,
            filename=filename,
            revision=revision,
            cache_dir=cache_dir,
        )
    except Exception as e:
        raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") from e
|
|
|
|
def download_pretrained(
        cfg: Dict,
        prefer_hf_hub: bool = True,
        cache_dir: Optional[str] = None,
):
    """Resolve a pretrained config to a local checkpoint path.

    Resolution order: an empty config gives ''; a 'file' entry is returned as-is;
    otherwise the 'url' entry is downloaded, unless the hub is available, preferred
    and an 'hf_hub' entry exists, in which case the hub is used. An 'hf_hub' value is
    split with `os.path.split` into a repo id and an optional file name.
    """
    if not cfg:
        return ''
    if 'file' in cfg:
        return cfg['file']

    hub_available = has_hf_hub()
    url = cfg.get('url', '')
    hub_spec = cfg.get('hf_hub', '')
    if hub_available and prefer_hf_hub and hub_spec:
        # The hub takes precedence, so the direct URL is dropped.
        url = ''

    if url:
        return download_pretrained_from_url(url, cache_dir=cache_dir)
    if not hub_spec:
        return ''

    has_hf_hub(True)
    # A trailing '/' leaves the file-name part empty, selecting the default weights name.
    model_id, filename = os.path.split(hub_spec)
    if filename:
        return download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
    return download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
|
| |
def merge_preprocess_dict(
        base: Union[PreprocessCfg, Dict],
        overlay: Dict,
):
    """Overlay preprocess settings on top of a base config and return a new dict.

    `base` may be a `PreprocessCfg` (converted with `asdict`) or a dict (filtered to
    `PreprocessCfg` field names). Overlay entries are applied only when the key is a
    `PreprocessCfg` field and the value is not None.

    NOTE(review): an identical function is defined again further down in this file
    and replaces this one at import time.
    """
    if isinstance(base, PreprocessCfg):
        merged = asdict(base)
    else:
        merged = {key: value for key, value in base.items() if key in _PREPROCESS_KEYS}
    if overlay:
        for key, value in overlay.items():
            if value is not None and key in _PREPROCESS_KEYS:
                merged[key] = value
    return merged
|
|
|
|
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
    """Keyword-argument front end for `merge_preprocess_dict`."""
    overlay = kwargs
    return merge_preprocess_dict(base, overlay)
|
|
|
|
@dataclass
class PreprocessCfg:
    """Image preprocessing settings (size, normalization stats, resize behaviour).

    NOTE(review): the same dataclass is declared a second time further down in this
    file; the later declaration is the one in effect after import.
    """
    size: Union[int, Tuple[int, int]] = 224
    mode: str = 'RGB'
    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
    std: Tuple[float, ...] = OPENAI_DATASET_STD
    interpolation: str = 'bicubic'
    resize_mode: str = 'shortest'
    fill_color: int = 0

    def __post_init__(self):
        # 'RGB' is the only accepted mode.
        assert self.mode == 'RGB'

    @property
    def num_channels(self):
        """Channel count; fixed at 3."""
        return 3

    @property
    def input_size(self):
        """Channel count followed by `to_2tuple(self.size)`."""
        channels = (self.num_channels,)
        spatial = to_2tuple(self.size)
        return channels + spatial
|
|
|
|
|
|
|
|
@dataclass
class PreprocessCfg:
    """Image preprocessing settings (size, normalization stats, resize behaviour)."""
    size: Union[int, Tuple[int, int]] = 224
    mode: str = 'RGB'
    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
    std: Tuple[float, ...] = OPENAI_DATASET_STD
    interpolation: str = 'bicubic'
    resize_mode: str = 'shortest'
    fill_color: int = 0

    def __post_init__(self):
        # 'RGB' is the only accepted mode.
        assert self.mode == 'RGB'

    @property
    def num_channels(self):
        """Channel count; fixed at 3."""
        return 3

    @property
    def input_size(self):
        """Channel count followed by `to_2tuple(self.size)`."""
        channels = (self.num_channels,)
        spatial = to_2tuple(self.size)
        return channels + spatial


# Field names of PreprocessCfg; used to filter arbitrary dicts down to preprocess settings.
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()))
|
|
|
|
def merge_preprocess_dict(
        base: Union[PreprocessCfg, Dict],
        overlay: Dict,
):
    """Overlay preprocess settings on top of a base config and return a new dict.

    `base` may be a `PreprocessCfg` (converted with `asdict`) or a dict (filtered to
    `PreprocessCfg` field names). Overlay entries are applied only when the key is a
    `PreprocessCfg` field and the value is not None.
    """
    if isinstance(base, PreprocessCfg):
        result = asdict(base)
    else:
        result = {}
        for name, value in base.items():
            if name in _PREPROCESS_KEYS:
                result[name] = value
    if not overlay:
        return result
    result.update(
        (name, value)
        for name, value in overlay.items()
        if name in _PREPROCESS_KEYS and value is not None
    )
    return result
|
|
|
|
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
    """Merge keyword overrides into `base` via `merge_preprocess_dict`."""
    overrides = kwargs
    return merge_preprocess_dict(base, overrides)
|
|
|
|
@dataclass
class AugmentationCfg:
    """Training-time augmentation options read by `image_transform`.

    Fields left as None are dropped before the options are forwarded.
    """
    # Scale range passed to RandomResizedCrop on the non-timm path.
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    # Must have 4 entries when `color_jitter_prob` is set on the non-timm path.
    color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    # When true, `image_transform` builds the train pipeline with timm's `create_transform`.
    use_timm: bool = False

    # Options used only by the non-timm train pipeline; they are removed before calling timm.
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
|
|
|
|
def _setup_size(size, error_msg):
    """Normalize `size` to a pair.

    A number becomes ``(int(size), int(size))``; a one-element sequence is duplicated;
    anything else must have length 2 and is returned unchanged.

    Raises:
        ValueError: with `error_msg` when `size` does not have length 2.
    """
    if isinstance(size, numbers.Number):
        side = int(size)
        return side, side

    if isinstance(size, Sequence) and len(size) == 1:
        only = size[0]
        return only, only

    if len(size) == 2:
        return size

    raise ValueError(error_msg)
|
|
|
|
class ResizeKeepRatio:
    """ Resize and Keep Ratio

    Copy & paste from `timm`

    Resizes an image towards `size` while keeping its aspect ratio, with optional
    random scale and aspect jitter. `longest` blends between matching the shorter
    side (0.) and the longer side (1.) of the image to the target.
    """

    def __init__(
            self,
            size,
            longest=0.,
            interpolation=InterpolationMode.BICUBIC,
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11)
    ):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation = interpolation
        self.longest = float(longest)
        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range

    @staticmethod
    def get_params(
            img,
            target_size,
            longest,
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11)
    ):
        """Compute the resize target for `img`.

        Args:
            img: Object whose `.size` is reversed to obtain (h, w).
            target_size: Target (h, w).
            longest: Blend weight between the larger and smaller of the two
                source/target ratios.

        Returns:
            Two-element list of rounded output dimensions, in (h, w) order.
        """
        source_size = img.size[::-1]  # reversed so that it unpacks as (h, w)
        h, w = source_size
        target_h, target_w = target_size
        ratio_h = h / target_h
        ratio_w = w / target_w
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1., 1.)
        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            # Aspect jitter shrinks one dimension and grows the other by the same factor.
            aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
        size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
        return size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be resized.

        Returns:
            The image resized to the dimensions computed by `get_params`.
        """
        size = self.get_params(
            img, self.size, self.longest,
            self.random_scale_prob, self.random_scale_range,
            self.random_aspect_prob, self.random_aspect_range
        )
        # NOTE(review): `F` must be torchvision.transforms.functional here; the top of
        # this file binds `F` to torch.nn.functional — confirm it is rebound before use.
        img = F.resize(img, size, self.interpolation)
        return img

    def __repr__(self):
        # Single closing parenthesis at the very end (the interpolation field
        # previously carried a stray ')').
        return (
            f'{self.__class__.__name__}(size={self.size}'
            f', interpolation={self.interpolation}'
            f', longest={self.longest:.3f})'
        )
|
|
|
|
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
    """Center crop the image to `output_size`, padding first where it is too small.

    If the image is a torch Tensor, it is expected to have [..., H, W] shape, where ...
    means an arbitrary number of leading dimensions. Any edge shorter than the
    requested size is padded with `fill` on both sides before cropping.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. An int or a
            single-element sequence is used for both directions.
        fill (int, Tuple[int]): Padding color.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        output_size = (side, side)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = F.get_dimensions(img)
    crop_height, crop_width = output_size

    needs_padding = crop_width > image_width or crop_height > image_height
    if needs_padding:
        missing_w = max(crop_width - image_width, 0)
        missing_h = max(crop_height - image_height, 0)
        # Order is left, top, right, bottom; an odd remainder goes to right/bottom.
        padding_ltrb = [
            missing_w // 2,
            missing_h // 2,
            (missing_w + 1) // 2,
            (missing_h + 1) // 2,
        ]
        img = F.pad(img, padding_ltrb, fill=fill)
        _, image_height, image_width = F.get_dimensions(img)
        if (image_width, image_height) == (crop_width, crop_height):
            return img

    top = int(round((image_height - crop_height) / 2.0))
    left = int(round((image_width - crop_width) / 2.0))
    return F.crop(img, top, left, crop_height, crop_width)
|
|
|
|
class CenterCropOrPad(torch.nn.Module):
    """Module wrapper around `center_crop_or_pad`.

    If the image is a torch Tensor, it is expected to have [..., H, W] shape, where ...
    means an arbitrary number of leading dimensions. Edges shorter than the requested
    size are padded with `fill` before the center crop.

    Args:
        size (sequence or int): Desired output size of the crop. An int gives a square
            crop (size, size); a sequence of length 1 is read as (size[0], size[0]).
        fill: Padding value forwarded to `center_crop_or_pad`.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)
        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        cropped = center_crop_or_pad(img, self.size, fill=self.fill)
        return cropped

    def __repr__(self) -> str:
        return "{}(size={})".format(self.__class__.__name__, self.size)
|
|
|
|
def _convert_to_rgb(image):
    """Return `image.convert('RGB')`."""
    rgb_image = image.convert('RGB')
    return rgb_image
|
|
|
|
class color_jitter(object):
    """
    Apply Color Jitter to the PIL image with a specified probability.
    """
    def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
        assert 0. <= p <= 1.
        self.p = p
        jitter_kwargs = dict(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.transf = ColorJitter(**jitter_kwargs)

    def __call__(self, img):
        # One random draw per call decides whether the jitter is applied.
        should_apply = random.random() < self.p
        return self.transf(img) if should_apply else img
|
|
|
|
class gray_scale(object):
    """
    Apply Gray Scale to the PIL image with a specified probability.
    """
    def __init__(self, p=0.2):
        assert 0. <= p <= 1.
        self.p = p
        # Three output channels keep the channel count compatible with RGB pipelines.
        self.transf = Grayscale(num_output_channels=3)

    def __call__(self, img):
        # One random draw per call decides whether the conversion is applied.
        should_apply = random.random() < self.p
        return self.transf(img) if should_apply else img
|
|
|
|
def image_transform(
        image_size: Union[int, Tuple[int, int]],
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_mode: Optional[str] = None,
        interpolation: Optional[str] = None,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    """Build the image preprocessing pipeline for training or inference.

    Args:
        image_size: Target size, an int or a (h, w) pair.
        is_train: Build the augmentation pipeline when true, else the inference one.
        mean, std: Normalization statistics; default to the OpenAI dataset values.
            A scalar is repeated for 3 channels.
        resize_mode: 'shortest' (default), 'longest' or 'squash'; inference only.
        interpolation: 'bicubic' (default), 'bilinear' or 'random'.
        fill_color: Padding value used by the 'longest' resize mode.
        aug_cfg: Augmentation options, as an `AugmentationCfg` or a dict of its fields.

    Returns:
        The composed transform (timm's `create_transform` result when
        `aug_cfg.use_timm` is set for training, otherwise a `Compose`).
    """
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    interpolation = interpolation or 'bicubic'
    assert interpolation in ['bicubic', 'bilinear', 'random']
    # 'random' has no InterpolationMode equivalent here, so it falls back to BICUBIC;
    # the raw string is still forwarded to timm on the timm training path.
    interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC

    resize_mode = resize_mode or 'shortest'
    assert resize_mode in ('shortest', 'longest', 'squash')

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            # Imported lazily so timm is only needed when this path is used.
            from timm.data import create_transform
            if isinstance(image_size, (tuple, list)):
                assert len(image_size) >= 2
                # tuple() so that a list-valued image_size can be concatenated too.
                input_size = (3,) + tuple(image_size[-2:])
            else:
                input_size = (3, image_size, image_size)

            aug_cfg_dict.setdefault('color_jitter', None)
            # These two options belong to the non-timm pipeline only.
            aug_cfg_dict.pop('color_jitter_prob', None)
            aug_cfg_dict.pop('gray_scale_prob', None)

            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                interpolation=interpolation,
                **aug_cfg_dict,
            )
        else:
            train_transform = [
                RandomResizedCrop(
                    image_size,
                    scale=aug_cfg_dict.pop('scale'),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
            ]
            if aug_cfg.color_jitter_prob:
                assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
                train_transform.extend([
                    color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
                ])
            if aug_cfg.gray_scale_prob:
                train_transform.extend([
                    gray_scale(aug_cfg.gray_scale_prob)
                ])
            train_transform.extend([
                ToTensor(),
                normalize,
            ])
            train_transform = Compose(train_transform)
            # Whatever is still in the dict was not consumed by this pipeline.
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform
    else:
        if resize_mode == 'longest':
            transforms = [
                ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
                CenterCropOrPad(image_size, fill=fill_color)
            ]
        elif resize_mode == 'squash':
            if isinstance(image_size, int):
                image_size = (image_size, image_size)
            transforms = [
                Resize(image_size, interpolation=interpolation_mode),
            ]
        else:
            assert resize_mode == 'shortest'
            if not isinstance(image_size, (tuple, list)):
                image_size = (image_size, image_size)
            if image_size[0] == image_size[1]:
                # Square target: a scalar size is passed to Resize.
                transforms = [
                    Resize(image_size[0], interpolation=interpolation_mode)
                ]
            else:
                # Non-square target.
                # NOTE(review): `interpolation_mode` is not forwarded here, so this
                # branch always uses ResizeKeepRatio's default — confirm intended.
                transforms = [ResizeKeepRatio(image_size)]
            transforms += [CenterCrop(image_size)]

        transforms.extend([
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
        return Compose(transforms)
| |
| |
def image_transform_v2(
        cfg: PreprocessCfg,
        is_train: bool,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    """Build an image transform from a `PreprocessCfg` by delegating to `image_transform`."""
    preprocess_kwargs = dict(
        image_size=cfg.size,
        mean=cfg.mean,
        std=cfg.std,
        interpolation=cfg.interpolation,
        resize_mode=cfg.resize_mode,
        fill_color=cfg.fill_color,
    )
    return image_transform(is_train=is_train, aug_cfg=aug_cfg, **preprocess_kwargs)
|
|
@dataclass
class AugmentationCfg:
    """Training-time augmentation options read by `image_transform`.

    NOTE(review): this dataclass is also declared earlier in this file with the same
    fields; this later declaration rebinds the name at import time.
    """
    # Scale range passed to RandomResizedCrop on the non-timm path.
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    # Must have 4 entries when `color_jitter_prob` is set on the non-timm path.
    color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    # When true, `image_transform` builds the train pipeline with timm's `create_transform`.
    use_timm: bool = False

    # Options used only by the non-timm train pipeline; they are removed before calling timm.
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
|
|
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Attach preprocessing settings to `model.visual`, or to `model` if it has no `visual`.

    Sets `image_mean` and `image_std` from the 'mean' and 'std' entries, then stores a
    deep copy of the whole dict as `preprocess_cfg`.
    """
    target = getattr(model, 'visual', model)
    for attr_name, cfg_key in (('image_mean', 'mean'), ('image_std', 'std')):
        setattr(target, attr_name, preprocess_cfg[cfg_key])
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
|
|
|
|
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit=True):
    """Rename the keys of a MobileCLIP-style state dict to this file's naming scheme.

    Image weights go through timm's `checkpoint_filter_fn` (the fastvit variant when
    `fastvit` is true, otherwise the vision_transformer_hybrid variant) and are
    prefixed with 'visual.trunk.'. Text weights are the entries under 'text_encoder.',
    renamed and prefixed with 'text.'. 'logit_scale' is copied over unchanged.
    """
    # Substring rewrites for text keys, applied in this exact order.
    renames_before_pos_embed = (
        ('projection_layer', 'text_projection'),
        ('embedding_layer', 'token_embedding'),
    )
    pos_embed_source = 'positional_embedding.pos_embed.pos_embed'
    renames_after_pos_embed = (
        ('final_layer_norm', 'ln_final'),
        ('pre_norm_mha.0', 'ln_1'),
        ('pre_norm_mha.1', 'attn'),
        ('pre_norm_ffn.0', 'ln_2'),
        ('pre_norm_ffn.1', 'mlp.c_fc'),
        ('pre_norm_ffn.4', 'mlp.c_proj'),
        ('qkv_proj.weight', 'in_proj_weight'),
        ('qkv_proj.bias', 'in_proj_bias'),
        ('transformer.', 'transformer.resblocks.'),
    )

    def _image_tower_weights(source):
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        filtered = checkpoint_filter_fn(source, model.visual.trunk)
        return {'visual.trunk.' + name: tensor for name, tensor in filtered.items()}

    def _text_tower_weights(source, prefix='text_encoder.'):
        converted = {}
        for name, tensor in source.items():
            if not name.startswith(prefix):
                continue
            name = name.replace(prefix, '')
            for old, new in renames_before_pos_embed:
                name = name.replace(old, new)
            if name.startswith(pos_embed_source):
                name = name.replace(pos_embed_source, 'positional_embedding')
                # Size-1 dimensions are squeezed out of the positional embedding.
                tensor = tensor.squeeze()
            for old, new in renames_after_pos_embed:
                name = name.replace(old, new)
            converted['text.' + name] = tensor
        return converted

    image_weights = _image_tower_weights(state_dict)
    text_weights = _text_tower_weights(state_dict)
    merged = {**image_weights, **text_weights}
    merged['logit_scale'] = state_dict['logit_scale']
    return merged
|
|
|
|
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """Convert a recognised third-party state dict layout; return others unchanged.

    Each marker key identifies a MobileCLIP layout and selects the `fastvit` flag
    passed to `convert_mobile_clip_state_dict`. The checks run in sequence, each
    against the current (possibly already converted) dict.
    """
    layout_markers = (
        ('image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight', True),
        ('image_encoder.model.patch_emb.0.block.conv.weight', False),
    )
    for marker_key, use_fastvit in layout_markers:
        if marker_key in state_dict:
            state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=use_fastvit)
    return state_dict
|
|
def load_state_dict(
        checkpoint_path: str,
        device='cpu',
        weights_only=True,
):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: Path to a ``.safetensors`` file or a `torch.load`-able file.
        device: Device the tensors are mapped to while loading.
        weights_only: Forwarded to `torch.load`; dropped if `torch.load` rejects the
            argument with a TypeError.

    Returns:
        The state dict: the 'state_dict' entry of a dict checkpoint, the state dict of
        a TorchScript module (minus a few non-weight entries), or the checkpoint
        itself. A leading 'module.' prefix is stripped from the keys that carry it.
    """
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        checkpoint = load_file(checkpoint_path, device=device)
    else:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except TypeError:
            # Fallback for torch.load signatures without `weights_only`.
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint

    # Strip the 'module.' wrapper prefix. The first key decides whether the prefix is
    # present; an empty state dict is returned as-is, and keys that do not carry the
    # exact 'module.' prefix are left untouched.
    prefix = 'module.'
    first_key = next(iter(state_dict), '')
    if first_key.startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
|
|
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
        weights_only: bool = True,
        device='cpu',
):
    """Load weights from `checkpoint_path` into `model`, adapting the state dict first.

    Args:
        model: Target model; `logit_scale` and `logit_bias` attributes are read.
        checkpoint_path: Checkpoint file. ``.npz`` / ``.npy`` files take a separate path.
        strict: Forwarded to `model.load_state_dict`.
        weights_only: Forwarded to `load_state_dict` (the file-loading helper).
        device: Device the checkpoint tensors are mapped to while loading.

    Returns:
        The result of `model.load_state_dict`, or ``{}`` for NumPy checkpoints.
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # NumPy archives are handed to the big_vision loader; presumably it fills
        # `model` in place — nothing is returned from it here.
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Convert recognised third-party layouts (see `convert_state_dict`).
    state_dict = convert_state_dict(model, state_dict)

    # The checkpoint has a top-level text 'positional_embedding' but the model does
    # not: remap the checkpoint to the custom-text model layout.
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # Reshape logit_scale when its number of dimensions differs from the model's.
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # Same for logit_bias.
    # NOTE(review): this reads `model.logit_bias.ndim` and so assumes the model's
    # logit_bias is not None whenever the checkpoint has one — confirm.
    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)

    # The model has a logit_bias but the checkpoint does not: start it at zero,
    # shaped like the checkpoint's logit_scale.
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Drop the position_ids entry when the model has no attribute of that name.
    # NOTE(review): `hasattr` with a dotted name does not walk submodules, so this
    # check is effectively always true — confirm intended.
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally load the adapted weights into the model.
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
|
|
| |
# Prefix that marks a model name as a Hugging Face Hub model id (see create_model).
HF_HUB_PREFIX = 'hf-hub:'
# Files or directories scanned for model-architecture JSON configs by _rescan_model_configs().
_MODEL_CONFIG_PATHS = [Path("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs")]
# Registry filled by _rescan_model_configs(): config file stem -> parsed config dict.
_MODEL_CONFIGS = {}
|
|
| import json |
|
|
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Download ``open_clip_config.json`` for ``model_id`` from the Hugging Face Hub and parse it.

    Args:
        model_id: Hub model id (without the ``hf-hub:`` prefix).
        cache_dir: Optional cache directory forwarded to the downloader.

    Returns:
        The parsed JSON config.
    """
    local_json = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    raw_text = Path(local_json).read_text(encoding='utf-8')
    return json.loads(raw_text)
|
|
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``, or ``None`` if it is unknown.

    A copy is handed out so callers can edit the config without touching the registry.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return copy.deepcopy(_MODEL_CONFIGS[model_name])
|
|
def _natural_key(string_):
    """Build a sort key that orders digit runs numerically.

    The lower-cased string is split on digit runs; digit chunks become ``int``
    and everything else stays ``str``.
    """
    key = []
    for chunk in re.split(r'(\d+)', string_.lower()):
        key.append(int(chunk) if chunk.isdigit() else chunk)
    return key
|
|
|
|
def _rescan_model_configs():
    """Rebuild ``_MODEL_CONFIGS`` from the JSON files found under ``_MODEL_CONFIG_PATHS``.

    Each entry of ``_MODEL_CONFIG_PATHS`` may be a single ``.json`` file or a
    directory whose ``*.json`` files are collected. Only configs that contain
    ``embed_dim``, ``vision_cfg`` and ``text_cfg`` are registered, keyed by the
    file stem. The registry is then re-ordered using ``_natural_key``.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        # Explicit UTF-8 so parsing does not depend on the platform locale
        # (matches how _get_hf_config reads its JSON).
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    # keep the registry ordered by natural (numeric-aware) name order
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
|
|
|
# Populate the model-config registry once, when this module is executed.
_rescan_model_configs()
|
|
def list_models():
    """Return the names of all registered model architectures (one per loaded config file)."""
    return [*_MODEL_CONFIGS]
|
|
|
|
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble the keyword dictionary used for one decoding step.

    Args:
        input_ids: Token ids generated so far; only the last column is kept
            when ``past`` is truthy.
        image_inputs: Image batch, passed through unchanged.
        past: Cached key/values, passed through unchanged.
        **kwargs: May carry ``attention_mask`` and/or ``position_ids``.

    Returns:
        Dict with keys ``text``, ``images``, ``past_key_values``,
        ``position_ids`` and ``attention_mask``.
    """
    if past:
        # with a cache only the newest token has to be fed
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # derive positions from the mask: running count of attended tokens,
        # with masked-out slots set to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # Otherwise keep whatever the caller supplied (possibly None); it used to
    # be overwritten with None, silently dropping explicit position_ids.

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
|
|
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration of the multimodal text decoder; extends ``CLIPTextCfg`` with the fields below."""
    # NOTE(review): mlp_ratio, dim_head, n_queries and attn_pooler_heads are not read
    # by the code in this part of the file — presumably consumed elsewhere; confirm.
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8  # forwarded as `heads` to MultimodalTransformer in _build_text_decoder_tower
    n_queries: int = 256
    attn_pooler_heads: int = 8
|
|
# Optional dependency: the generation helpers from `transformers`.
# `_has_transformers` records whether the import succeeded; GENERATION_TYPES
# maps a generation_type name to its logits-warper class (None when unavailable).
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )
except ImportError:
    _has_transformers = False
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
else:
    _has_transformers = True
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
|
|
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    """Return ``token_id`` as a tensor.

    Tensors are returned untouched (``device`` is ignored for them); a single
    ``int`` becomes a one-element tensor; any other value is handed to
    ``torch.tensor`` as is.
    """
    if isinstance(token_id, torch.Tensor):
        return token_id
    values = [token_id] if isinstance(token_id, int) else token_id
    return torch.tensor(values, device=device)
|
|
|
|
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the ``MultimodalTransformer`` decoder described by ``multimodal_cfg``.

    Args:
        embed_dim: Used as the decoder's ``output_dim``.
        multimodal_cfg: ``MultimodalCfg`` instance or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU``.
        cast_dtype: When fp16/bf16, ``LayerNormFp32`` is used as the norm layer.

    Returns:
        The constructed decoder module.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    if quick_gelu:
        activation = QuickGELU
    else:
        activation = nn.GELU

    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    norm = LayerNormFp32 if low_precision else LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=norm,
    )
|
|
class CoCa(nn.Module):
    """Captioning model built from a text tower, a vision tower and a multimodal text decoder.

    Both towers return a pooled latent plus per-token embeddings. ``forward``
    feeds the image token embeddings and the text token embeddings to the
    decoder to obtain caption logits, and also returns the pooled latents.
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        """Build the towers, the decoder and the logit scale/bias parameters.

        Args:
            embed_dim: Embedding size handed to the text and vision towers.
            multimodal_cfg: Decoder config (``MultimodalCfg`` or dict of its fields).
            text_cfg: Text tower config (``CLIPTextCfg`` or dict of its fields).
            vision_cfg: Vision tower config (``CLIPVisionCfg`` or dict of its fields).
            quick_gelu: Forwarded to all tower builders.
            init_logit_scale: Initial value of the ``logit_scale`` parameter.
            init_logit_bias: Initial value of ``logit_bias``; no bias parameter
                is created when ``None``.
            nonscalar_logit_scale: Give scale/bias shape ``[1]`` instead of ``[]``.
            cast_dtype: Forwarded to all tower builders.
            pad_id: Default padding token id used by ``generate``.
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The former conditional returned text_cfg.vocab_size from both branches.
        vocab_size = text_cfg.vocab_size

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # the decoder projects to the vocabulary, hence vocab_size as its embed_dim
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Forward the gradient-checkpointing switch to both towers and the decoder."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        """Run the vision tower; return ``(latent, token_embeddings)`` with the latent optionally L2-normalized."""
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Run the text tower; return ``(latent, token_embeddings)`` with the latent optionally L2-normalized."""
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        """Return only the pooled image latent."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        """Return only the pooled text latent."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output (not implemented; asserts when requested)
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging the outputs of both towers' ``forward_intermediates``
            plus, when requested, ``logit_scale`` / ``logit_bias``.
        """
        output = {}
        if intermediates_only:
            # intermediates-only mode turns off final normalization and logits
            normalize = False
            output_logits = False
        if output_logits:
            assert False, 'FIXME, needs implementing'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        # the text decoder is not involved in this method
        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        """Encode image (and text) and compute decoder logits.

        Args:
            image: Image batch; only encoded when ``image_latent`` or
                ``image_embs`` is missing.
            text: Token ids. When ``None`` only the image outputs are returned.
            image_latent: Pre-computed pooled image latent.
            image_embs: Pre-computed image token embeddings.
            output_labels: When true, ``text[:, 1:]`` is returned as ``labels``
                and the last text token embedding is dropped before decoding.

        Returns:
            Dict with ``image_features`` and ``image_embs`` when ``text`` is
            ``None``; otherwise ``image_features``, ``text_features``,
            ``logits``, ``logit_scale`` and optionally ``labels`` / ``logit_bias``.
        """
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # shift by one so position t's logits line up with label token t+1
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

    def generate(
            self,
            image,
            text=None,
            seq_len=30,
            max_seq_len=77,
            temperature=1.,
            generation_type="beam_search",
            top_p=0.1,
            top_k=1,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            repetition_penalty=1.0,
            fixed_output_length=False
    ):
        """Generate token sequences for ``image``.

        The whole procedure runs in eval mode under ``torch.no_grad()``; the
        ``training`` flag of every submodule is restored afterwards, also when
        an exception is raised.

        Args:
            image: Image batch.
            text: Optional prompt token ids (sampling modes only; beam search
                always starts from ``sot_token_id``).
            seq_len: Default maximum length (used for ``MaxLengthCriteria`` when
                no ``stopping_criteria`` is given); in sampling modes the token
                at this length is forced to ``eos_token_id``.
            max_seq_len: Only the last ``max_seq_len`` tokens are fed to the model.
            temperature: Divides the filtered logits before the softmax.
            generation_type: ``"beam_search"``, ``"top_p"`` or ``"top_k"``.
            top_p: Argument of the top-p warper.
            top_k: Argument of the top-k warper.
            pad_token_id: Padding id; defaults to ``self.pad_id``.
            eos_token_id: End id; defaults to 49407.
            sot_token_id: Start id; defaults to 49406.
            num_beams: Beam-search width.
            num_beam_groups: Number of beam groups.
            min_seq_len: Minimum length enforced via ``MinLengthLogitsProcessor``.
            stopping_criteria: Optional list of stopping criteria.
            repetition_penalty: Argument of ``RepetitionPenaltyLogitsProcessor``.
            fixed_output_length: Beam search pads the result to ``seq_len``;
                sampling keeps appending padding until the stopping criteria fire.

        Returns:
            Tensor of generated token ids.

        Raises:
            ValueError: If ``generation_type`` is not one of the supported values.
        """
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        # Remember each submodule's own mode: a plain self.train(flag) on exit
        # would overwrite submodules that were deliberately kept in eval mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        # Switch to eval before ANY forward pass. Previously the beam-search
        # path and the image encoding of the sampling path ran in whatever
        # mode the model happened to be in.
        self.eval()
        try:
            with torch.no_grad():
                # NOTE(review): 49406 / 49407 are presumably the tokenizer's
                # start/end ids — confirm against the tokenizer in use.
                sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
                eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
                pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
                logit_processor = LogitsProcessorList(
                    [
                        MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                        RepetitionPenaltyLogitsProcessor(repetition_penalty),
                    ]
                )

                if stopping_criteria is None:
                    stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
                stopping_criteria = StoppingCriteriaList(stopping_criteria)

                if generation_type == "beam_search":
                    output = self._generate_beamsearch(
                        image_inputs=image,
                        pad_token_id=pad_token_id,
                        eos_token_id=eos_token_id,
                        sot_token_id=sot_token_id,
                        num_beams=num_beams,
                        num_beam_groups=num_beam_groups,
                        min_seq_len=min_seq_len,
                        stopping_criteria=stopping_criteria,
                        logit_processor=logit_processor,
                    )
                    if fixed_output_length and output.shape[1] < seq_len:
                        # right-pad every sequence up to seq_len
                        pad_len = seq_len - output.shape[1]
                        return torch.cat((
                                output,
                                torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                            ),
                            dim=1
                        )
                    return output

                elif generation_type == "top_p":
                    logit_warper = GENERATION_TYPES[generation_type](top_p)
                elif generation_type == "top_k":
                    logit_warper = GENERATION_TYPES[generation_type](top_k)
                else:
                    raise ValueError(
                        f"generation_type has to be one of "
                        f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                    )

                # encode the image once; the latents are reused at every step
                image_latent, image_embs = self._encode_image(image)

                if text is None:
                    text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

                num_dims = len(text.shape)

                if num_dims == 1:
                    # un-batched prompt: add a batch axis, removed again before returning
                    text = text[None, :]

                out = text

                while True:
                    x = out[:, -max_seq_len:]
                    cur_len = x.shape[1]
                    logits = self(
                        image,
                        x,
                        image_latent=image_latent,
                        image_embs=image_embs,
                        output_labels=False,
                    )["logits"][:, -1]
                    # rows whose last token is EOS or padding are finished
                    mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                    # finished rows simply receive padding
                    sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                    if mask.all():
                        if not fixed_output_length:
                            break
                    else:
                        logits = logits[~mask, :]
                        filtered_logits = logit_processor(x[~mask, :], logits)
                        filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                        probs = F.softmax(filtered_logits / temperature, dim=-1)

                        if (cur_len + 1 == seq_len):
                            # last allowed position: force EOS on unfinished rows
                            sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                        else:
                            sample[~mask, :] = torch.multinomial(probs, 1)

                    out = torch.cat((out, sample), dim=-1)

                    cur_len += 1

                    if all(stopping_criteria(out, None)):
                        break

                if num_dims == 1:
                    out = out.squeeze(0)

                return out
        finally:
            for module, mode in previous_modes:
                module.training = mode

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Grouped beam search driven by ``BeamSearchScorer``.

        Args:
            image_inputs: Image batch; repeated ``num_beams`` times per sample.
            pad_token_id: Padding id handed to the scorer.
            eos_token_id: End id handed to the scorer / min-length processor.
            sot_token_id: Start id every beam begins with.
            num_beams: Number of beams per sample.
            num_beam_groups: Number of beam groups.
            min_seq_len: Used for the default ``MinLengthLogitsProcessor``.
            stopping_criteria: Callable criteria list; its ``max_length`` is
                passed to ``beam_scorer.finalize``.
            logit_processor: Logits processor list; a min-length processor is
                created when ``None``.
            logit_warper: Accepted but not used by this method.

        Returns:
            The ``'sequences'`` entry of ``beam_scorer.finalize``.

        Raises:
            ValueError: If the batch dimension of the start tokens does not
                equal ``num_beams * batch_size``.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # fall back to a min-length processor when none was supplied
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # only the first beam of each group starts at score 0, every other beam
        # at -1e9, so beams of one group do not all pick the same first token
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # tokens chosen at this step, one per beam
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # source beam index for each beam of the next step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # one decoder step for all beams of all samples
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # flat indices of this group's beams across the whole batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # last-step logits of the current group's beams only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # flatten (beam, vocab) so top-k ranks across all beams of a sample
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # split the flat index back into (beam index, token id)
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # beam_indices is never populated here, so this stays None
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> sample index
                # (beam_idx % group_size)  -> offset inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # advance one position
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
|
|
|
|
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        load_weights_only: bool = True,
        **model_kwargs,
):
    """Creates and configures a contrastive vision-language model.

    Args:
        model_name: Name of the model architecture to create. Can be a local model name
            or a Hugging Face model ID prefixed with 'hf-hub:'.
        pretrained: Tag/path for pretrained model weights. Can be:
            - A pretrained tag name (e.g., 'openai')
            - A path to local weights
            - None to initialize with random weights
        precision: Model precision/AMP configuration. Options:
            - 'fp32': 32-bit floating point
            - 'fp16'/'bf16': Mixed precision with FP32 for certain layers
            - 'pure_fp16'/'pure_bf16': Pure 16-bit precision
        device: Device to load the model on ('cpu', 'cuda', or torch.device object)
        jit: If True, JIT compile the model
        force_quick_gelu: Force use of QuickGELU activation
        force_custom_text: Force use of custom text encoder
        force_patch_dropout: Override default patch dropout value
        force_image_size: Override default image size for vision encoder
        force_preprocess_cfg: Override default preprocessing configuration.
            The passed dict is copied and never modified.
        pretrained_image: Load pretrained weights for timm vision models
        pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
        cache_dir: Override default cache directory for downloaded model files
        output_dict: If True and model supports it, return dictionary of features
        require_pretrained: Raise error if pretrained weights cannot be loaded
        load_weights_only: Only deserialize model weights when unpickling torch checkpoints (for safety)
        **model_kwargs: Additional keyword arguments passed to model constructor

    Returns:
        Created and configured model instance

    Raises:
        RuntimeError: If model config is not found or required pretrained weights
            cannot be loaded

    Examples:
        # Create basic CLIP model
        model = create_model('ViT-B/32')

        # Create CLIP model with mixed precision on GPU
        model = create_model('ViT-B/32', precision='fp16', device='cuda')

        # Load pretrained OpenAI weights
        model = create_model('ViT-B/32', pretrained='openai')

        # Load Hugging Face model
        model = create_model('hf-hub:organization/model-name')
    """

    # Work on a private copy: 'size' is written into this dict further down
    # and the caller's dict must not change as a side effect.
    force_preprocess_cfg = dict(force_preprocess_cfg) if force_preprocess_cfg else {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir=cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        # the hub checkpoint is loaded below, so skip the HF text-model weights
        pretrained_hf = False
    else:
        # accept names written with '/' (e.g. 'ViT-B/32')
        model_name = model_name.replace('/', '-')
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = model_cfg or get_model_config(model_name)
    if model_cfg is not None:
        logging.info(f'Loaded {model_name} model config.')
    else:
        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
        raise RuntimeError(f'Model config for {model_name} not found.')

    if force_quick_gelu:
        # caller override: enable QuickGELU regardless of the config
        model_cfg["quick_gelu"] = True

    if force_patch_dropout is not None:
        # caller override of the vision tower's patch dropout
        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

    if force_image_size is not None:
        # caller override of the vision tower's image size
        model_cfg["vision_cfg"]["image_size"] = force_image_size

    is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
    if pretrained_image:
        if is_timm_model:
            # timm towers load their own pretrained weights via vision_cfg
            model_cfg['vision_cfg']['timm_model_pretrained'] = True
        else:
            assert False, 'pretrained image towers currently only supported for timm models'

    cast_dtype = get_cast_dtype(precision)
    is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_hf_model:
        # load HF text weights only when no pretrained checkpoint is requested
        model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

    # explicit keyword arguments win over config values
    model_cfg = dict(model_cfg, **model_kwargs)
    if custom_text:
        if "multimodal_cfg" in model_cfg:
            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
    else:
        model = CLIP(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        # manual mixed precision
        if is_timm_model:
            # timm-based model: cast everything to low precision, then put
            # only the LayerNormFp32 parameters back to float32
            model.to(device=device, dtype=dtype)

            def _convert_ln(m):
                if isinstance(m, LayerNormFp32):
                    m.weight.data = m.weight.data.to(torch.float32)
                    m.bias.data = m.bias.data.to(torch.float32)
            model.apply(_convert_ln)
        else:
            model.to(device=device)
            convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    pretrained_loaded = False
    if pretrained:
        checkpoint_path = ''
        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
        if pretrained_cfg:
            # `pretrained` is a known tag for this model
            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
            model_quick_gelu = model_cfg.get('quick_gelu', False)
            if pretrained_quick_gelu and not model_quick_gelu:
                warnings.warn(
                    f'These pretrained weights were trained with QuickGELU activation but the model config does '
                    f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
            elif not pretrained_quick_gelu and model_quick_gelu:
                warnings.warn(
                    f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
                    f'model config, consider using a model config without QuickGELU or disable override flags.')
        elif os.path.exists(pretrained):
            # `pretrained` is a local checkpoint path
            checkpoint_path = pretrained

        if checkpoint_path:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        else:
            # closing parenthesis after the tag list was missing in this message
            error_str = (
                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
            logging.warning(error_str)
            raise RuntimeError(error_str)
        pretrained_loaded = True
    elif has_hf_hub_prefix:
        logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
        load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        pretrained_loaded = True

    if require_pretrained and not pretrained_loaded:
        raise RuntimeError(
            f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # store the final preprocessing config on the model
    if getattr(model.visual, 'image_size', None) is not None:
        # the model's actual image size takes precedence
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model
|
|
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        load_weights_only: bool = True,
        **model_kwargs,
):
    """Create a model via ``create_model`` together with its train and eval image transforms.

    The ``image_*`` arguments are merged into a preprocessing override that is
    passed to ``create_model``; all other arguments are forwarded unchanged.
    Both transforms are built by ``image_transform_v2`` from the model's final
    ``visual.preprocess_cfg``; ``aug_cfg`` only applies to the training one.

    Returns:
        Tuple ``(model, preprocess_train, preprocess_val)``.
    """
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    built_model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        load_weights_only=load_weights_only,
        **model_kwargs,
    )

    # read back the preprocessing config that create_model stored on the model
    final_pp_cfg = PreprocessCfg(**built_model.visual.preprocess_cfg)

    train_transform = image_transform_v2(
        final_pp_cfg,
        is_train=True,
        aug_cfg=aug_cfg,
    )
    val_transform = image_transform_v2(
        final_pp_cfg,
        is_train=False,
    )
    return built_model, train_transform, val_transform
|
|
|
|
|
|
# Build the 'ViT-H-14' model with the 'laion2b_s32b_b79k' pretrained tag on `device`,
# together with its training-time (open_clip_imgaug) and evaluation-time
# (open_clip_preprocess) image transforms.
open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms(
    model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device
)
| |
| |
|
|
| |
| |
| |
class CSIP(nn.Module):
    """Contrastive image/audio model: frozen image encoder, audio encoder plus a linear audio projection.

    Args:
        image_encoder: Module mapping images to features; all its parameters are frozen.
        audio_encoder: Callable whose first output is the audio feature tensor.
        dim_img: Accepted but not used.
        dim_audio: Input size of the audio projection.
        dim_emb: Output size of the audio projection.
    """

    def __init__(self, image_encoder, audio_encoder,
                 dim_img=None, dim_audio=1024, dim_emb=1024):
        super().__init__()

        self.image_encoder = image_encoder
        self.audio_encoder = audio_encoder

        # the image tower is never trained
        for frozen in self.image_encoder.parameters():
            frozen.requires_grad_(False)

        # maps audio features into the embedding space
        self.audio_proj = nn.Linear(dim_audio, dim_emb)

        # learnable log-scale; logits are multiplied by exp(log_temp), initially 0.07
        self.log_temp = nn.Parameter(torch.tensor(0.07).log())

    def forward(self, images, audios):
        """Compute the symmetric contrastive loss for a batch of paired images and audios.

        Returns:
            Tuple ``(loss, loss_i, loss_a, logits, probs, avg_similarity)`` where
            ``loss`` is the mean of the image->audio and audio->image cross
            entropies and ``avg_similarity`` is the mean dot product of the
            matched, L2-normalized embeddings.
        """
        img_feats = self.image_encoder(images)
        aud_feats = self.audio_encoder(audios)[0]

        # unit-length embeddings for both modalities
        image_embeds = F.normalize(img_feats, dim=1)
        projected_audio = self.audio_proj(aud_feats)
        audio_embeds = F.normalize(projected_audio, dim=1)

        # scaled pairwise similarities, rows = images, columns = audios
        scale = self.log_temp.exp()
        logits = torch.matmul(image_embeds, audio_embeds.T) * scale
        probs = F.softmax(logits, dim=1)

        # the i-th image belongs to the i-th audio
        targets = torch.arange(len(images), device=images.device)
        loss_i = F.cross_entropy(logits, targets)
        loss_a = F.cross_entropy(logits.T, targets)
        loss = (loss_i + loss_a) / 2

        # similarity of the matched pairs only
        paired_similarity = torch.sum(image_embeds * audio_embeds, dim=1)
        avg_similarity = paired_similarity.mean()

        return loss, loss_i, loss_a, logits, probs, avg_similarity
|
|
|
|
| |
| |
| |
class VaaniImageAudioDataset(torch.utils.data.Dataset):
    """Dataset of paired image/audio file paths taken from a DataFrame.

    Args:
        df: Table with ``image_path`` and ``audio_path`` columns.
    """

    def __init__(self, df):
        # keep plain Python lists; files are opened later, in the collate function
        self.image_paths = df["image_path"].tolist()
        self.audio_paths = df["audio_path"].tolist()

    def __len__(self):
        """Number of pairs (length of the audio path list)."""
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as a dict of two paths."""
        image_path = self.image_paths[idx]
        audio_path = self.audio_paths[idx]
        return {'image_path': image_path, 'audio_path': audio_path}
|
|
|
|
def collate_fn(batch):
    """Turn a list of ``{'image_path', 'audio_path'}`` items into a model-ready batch.

    Images are opened and passed through ``open_clip_imgaug``; the audio paths
    are handed to ``CLAPAudioProcessor`` in one call.

    Returns:
        Dict with ``image_tensor`` (stacked transformed images),
        ``image_paths``, ``audio_tensor`` (output of ``CLAPAudioProcessor``)
        and ``audio_paths``.
    """
    image_paths = [item['image_path'] for item in batch]
    audio_paths = [item['audio_path'] for item in batch]

    image_tensor = []
    for path in image_paths:
        # Close each file as soon as it has been transformed; the handles
        # were previously left open until garbage collection.
        with Image.open(path) as img:
            image_tensor.append(open_clip_imgaug(img))

    audio_tensor = CLAPAudioProcessor(audio_paths, resample=True)

    return {
        'image_tensor': torch.stack(image_tensor),
        'image_paths': image_paths,
        'audio_tensor': audio_tensor,
        'audio_paths': audio_paths
    }
|
|
|
|
|
|
# Path manifests (image_path / audio_path columns) for the two splits.
train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv")
test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv")
train_dataset = VaaniImageAudioDataset(train_df)
test_dataset = VaaniImageAudioDataset(test_df)

print('Train Dataset:', len(train_dataset))
print('Test Dataset:', len(test_dataset))

BATCH_SIZE = 64

# Settings shared by both loaders; only `shuffle` differs between them.
# NOTE(review): both loaders use the same collate_fn, which applies
# open_clip_imgaug — presumably the train-time augmentation — so test batches
# are augmented as well. Confirm this is intended.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=48,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True,
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Pull one training batch as a smoke test and keep it for the model summary.
batch = next(iter(train_dataloader))
image_tensor_batch = batch['image_tensor'].to(device=device)
audio_tensor_batch = batch['audio_tensor'].to(device=device)
image_paths_batch = batch['image_paths']
audio_paths_batch = batch['audio_paths']
print("Image batch shape:", image_tensor_batch.shape)
print("Audio batch shape:", audio_tensor_batch.shape)
|
|
|
|
# Pair the OpenCLIP visual tower with the PEFT-wrapped CLAP audio encoder.
csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)

from torchinfo import summary
import subprocess

# Layer-by-layer overview using the sample batch, which was already moved to
# `device` when it was fetched.
summary(
    model=csip_model,
    input_data=(image_tensor_batch, audio_tensor_batch),
    dtypes=[torch.long],
    col_names=["input_size", "output_size", "num_params", "trainable", "params_percent"],
    col_width=20,
    row_settings=["var_names"],
    depth=2,
)
|
|
| |
| |
|
|
|
|
def train_batch(model, images, audio, optimizer):
    """Run one optimisation step on a single batch.

    Puts the model in training mode, clears old gradients, back-propagates
    the total loss and applies the optimizer update.

    Returns:
        ``(loss, loss_i, loss_a, logits, probs, avg_similarity)`` where the
        three losses and the similarity are Python floats and ``logits`` /
        ``probs`` are the tensors returned by the model.
    """
    model.train()
    optimizer.zero_grad()

    outputs = model(images, audio)
    total, image_loss, audio_loss, logits, probs, mean_sim = outputs

    total.backward()
    optimizer.step()

    return (
        total.item(),
        image_loss.item(),
        audio_loss.item(),
        logits,
        probs,
        mean_sim.item(),
    )
|
|
def evaluate_batch(model, images, audio):
    """Score a single batch without tracking gradients.

    Switches the model to eval mode and runs one forward pass.

    Returns:
        ``(loss, loss_i, loss_a, logits, probs, avg_similarity)`` where the
        three losses and the similarity are Python floats and ``logits`` /
        ``probs`` are the tensors returned by the model.
    """
    with torch.no_grad():
        model.eval()
        total, image_loss, audio_loss, logits, probs, mean_sim = model(images, audio)
        return (
            total.item(),
            image_loss.item(),
            audio_loss.item(),
            logits,
            probs,
            mean_sim.item(),
        )
|
|
def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
    """Write ``state`` as ``csip_best_epoch_<epoch+1>.pt`` and prune old ones.

    Args:
        state: Picklable object handed to ``torch.save``.
        checkpoint_dir: Existing directory that holds the checkpoints.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        max_checkpoints: Number of checkpoint files to keep. Those with the
            lowest epoch numbers are deleted first.
    """
    filename = f"csip_best_epoch_{epoch+1}.pt"
    path = os.path.join(checkpoint_dir, filename)

    # Write to a hidden temp file and rename, so an interrupted save can
    # never leave a truncated file under a name the loader would pick up.
    tmp_path = os.path.join(checkpoint_dir, f".{filename}.tmp")
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Only names of the exact form csip_best_epoch_<int>.pt count as
    # checkpoints; other files sharing the prefix are ignored, not parsed.
    pattern = re.compile(r"csip_best_epoch_(\d+)\.pt")
    ranked = sorted(
        (int(match.group(1)), name)
        for name in os.listdir(checkpoint_dir)
        if (match := pattern.fullmatch(name))
    )

    excess = max(len(ranked) - max_checkpoints, 0)
    for _, stale_name in ranked[:excess]:
        os.remove(os.path.join(checkpoint_dir, stale_name))
|
|
|
|
def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
    """Restore training state from the highest-numbered checkpoint.

    Looks for files named ``csip_best_epoch_<int>.pt`` in ``checkpoint_dir``
    and loads the one with the largest epoch number.

    Args:
        checkpoint_dir: Directory searched for checkpoints.
        model: Module whose weights are restored from ``'model_state'``.
        optimizer: Optimizer restored from ``'optimizer_state'`` when both it
            and the stored state are present; may be ``None``.
        scheduler: LR scheduler restored from ``'scheduler_state'`` when both
            it and the stored state are present; may be ``None``.

    Returns:
        ``(start_epoch, best_loss)``. ``start_epoch`` is the zero-based index
        of the next epoch to run (stored epoch + 1). ``(0, inf)`` is returned
        when no checkpoint exists.
    """
    pattern = re.compile(r"csip_best_epoch_(\d+)\.pt")
    candidates = []
    if os.path.isdir(checkpoint_dir):
        for name in os.listdir(checkpoint_dir):
            match = pattern.fullmatch(name)
            if match:
                candidates.append((int(match.group(1)), name))

    if not candidates:
        print("No checkpoint found to resume from.")
        return 0, float("inf")

    _, best_ckpt = max(candidates)
    path = os.path.join(checkpoint_dir, best_ckpt)

    # Load onto the CPU first so a checkpoint written on a GPU machine can be
    # opened anywhere; load_state_dict copies values to each module's device.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state['model_state'])

    if optimizer is not None and state.get('optimizer_state') is not None:
        optimizer.load_state_dict(state['optimizer_state'])

    if scheduler is not None and state.get('scheduler_state') is not None:
        scheduler.load_state_dict(state['scheduler_state'])
        # train_model stores the checkpoint before it steps the scheduler for
        # that epoch, so advance once to get the next epoch's learning rate.
        scheduler.step()

    # The stored epoch has already been trained; continue with the next one.
    start_epoch = state['epoch'] + 1
    best_loss = state['best_loss']
    print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
    return start_epoch, best_loss
|
|
|
|
def fig_to_tensor(fig):
    """Render a Matplotlib figure to PNG in memory and return it as a tensor.

    The PNG is decoded as RGB and converted with torchvision's ``to_tensor``.
    The figure is closed afterwards, so it cannot be reused by the caller.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        png_buffer.seek(0)
        rgb_image = Image.open(png_buffer).convert("RGB")
        tensor = torchvision.transforms.functional.to_tensor(rgb_image)
    plt.close(fig)
    return tensor
|
|
def _save_heatmap(matrix, title, file_path, tag, step, writer):
    """Draw one heatmap, save it to ``file_path`` and log it to TensorBoard."""
    fig = plt.figure(figsize=(15, 13))
    sns.heatmap(matrix, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(title)
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    fig.savefig(file_path)
    # fig_to_tensor also closes the figure.
    writer.add_image(tag, fig_to_tensor(fig), global_step=step)


def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
    """Save heatmaps of a logit matrix and of its row-wise softmax.

    PNGs are written to ``<save_dir>/logits`` and ``<save_dir>/probs``, and
    both images are also logged to ``writer`` at step ``epoch + 1``.

    Args:
        logits: 2-D similarity tensor (rows: images, columns: audios).
        epoch: Zero-based epoch index; titles and file names use ``epoch + 1``.
        loss: Loss value shown in the titles and embedded in the file names.
        save_dir: Root directory for the ``logits`` / ``probs`` sub-folders.
        writer: TensorBoard writer providing ``add_image``.
    """
    os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)

    # Detach once so both conversions work even if `logits` tracks gradients.
    logits = logits.detach()

    _save_heatmap(
        logits.cpu().numpy(),
        f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}",
        os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png"),
        "Heatmap/RawLogits",
        epoch + 1,
        writer,
    )

    _save_heatmap(
        logits.softmax(dim=1).cpu().numpy(),
        f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}",
        os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png"),
        "Heatmap/SoftmaxProbs",
        epoch + 1,
        writer,
    )
|
|
|
|
|
|
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def _run_csip_epoch(loader, device, step_fn, desc, colour, postfix_key):
    """Apply ``step_fn`` to every batch of ``loader`` and average the results.

    ``step_fn(images, audios)`` must return the 6-tuple produced by
    ``train_batch`` / ``evaluate_batch``. Returns the mean loss, mean image
    loss, mean audio loss, mean similarity and the logits of the last batch
    (``None`` when the loader yields nothing).
    """
    losses, i_losses, a_losses, sims = [], [], [], []
    last_logits = None

    loop = tqdm(loader, desc=desc, colour=colour, dynamic_ncols=True)
    for batch in loop:
        images = batch['image_tensor'].to(device)
        audios = batch['audio_tensor'].to(device)
        loss, loss_i, loss_a, last_logits, _probs, sim = step_fn(images, audios)
        losses.append(loss)
        i_losses.append(loss_i)
        a_losses.append(loss_a)
        sims.append(sim)
        loop.set_postfix(**{postfix_key: loss})

    return (_mean_or_nan(losses), _mean_or_nan(i_losses),
            _mean_or_nan(a_losses), _mean_or_nan(sims), last_logits)


def train_model(model, train_loader, test_loader,
                optimizer, scheduler, device, log_dir,
                checkpoint_dir, resume=False, epochs=10):
    """Train ``model`` and evaluate it after every epoch.

    Each epoch runs ``train_batch`` over ``train_loader`` and
    ``evaluate_batch`` over ``test_loader``, then logs the epoch means to
    TensorBoard, stdout and ``<log_dir>/training_log.csv``. Whenever the mean
    test loss improves, heatmaps of the last test batch's logits and a
    checkpoint are written to ``checkpoint_dir``.

    Args:
        model: Module returning ``(loss, loss_i, loss_a, logits, probs, sim)``.
        train_loader, test_loader: Iterables of dicts with ``'image_tensor'``
            and ``'audio_tensor'`` entries.
        optimizer: Optimizer passed to ``train_batch``.
        scheduler: LR scheduler stepped once per epoch; may be ``None``.
        device: Device the batch tensors are moved to.
        log_dir: Directory for TensorBoard events and the CSV log.
        checkpoint_dir: Directory for checkpoints and heatmaps.
        resume: When true, restore state via ``load_checkpoint`` and append
            to an existing CSV log instead of overwriting it.
        epochs: Total number of epochs (exclusive upper bound of the loop).
    """
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    csv_path = os.path.join(log_dir, "training_log.csv")

    start_epoch = 0
    best_loss = float("inf")
    best_epoch = -1
    # Test similarity of the best epoch seen in THIS run; stays NaN until the
    # test loss improves at least once.
    best_sim = float("nan")

    if resume:
        start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)

    # Start a fresh log unless we are resuming and one already exists.
    if not (resume and os.path.exists(csv_path)):
        with open(csv_path, mode='w', newline='') as f:
            writer_csv = csv.writer(f)
            writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss",
                                 "Best Loss", "Train Sim", "Test Sim", "Learning Rate",
                                 "Train I-Loss", "Test I-Loss", "Train A-Loss", "Test A-Loss"])

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
            (avg_train_loss, avg_train_i_loss, avg_train_a_loss,
             avg_train_sim, _) = _run_csip_epoch(
                train_loader, device,
                lambda im, au: train_batch(model, im, au, optimizer),
                f"[TrainEp {epoch+1}]", 'blue', 'trainLoss')

            (avg_test_loss, avg_test_i_loss, avg_test_a_loss,
             avg_test_sim, test_logits) = _run_csip_epoch(
                test_loader, device,
                lambda im, au: evaluate_batch(model, im, au),
                f"[TestEp {epoch+1}]", 'red', 'testLoss')

            current_lr = optimizer.param_groups[0]['lr']

            writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
            writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
            writer.add_scalar("Loss/Train/Image", avg_train_i_loss, epoch + 1)
            writer.add_scalar("Loss/Test/Image", avg_test_i_loss, epoch + 1)
            writer.add_scalar("Loss/Train/Audio", avg_train_a_loss, epoch + 1)
            writer.add_scalar("Loss/Test/Audio", avg_test_a_loss, epoch + 1)
            writer.add_scalar("Similarity/Train", avg_train_sim, epoch + 1)
            writer.add_scalar("Similarity/Test", avg_test_sim, epoch + 1)
            writer.add_scalar("Learning Rate", current_lr, epoch + 1)

            # best_loss printed here is the best value BEFORE this epoch.
            print(f"\n\n |"
                  f"Epoch {epoch+1} | Loss: ({avg_train_loss:.4f}, {avg_test_loss:.4f}, {best_loss:.4f}) |"
                  f"LR: {current_lr:.2e} |"
                  f"Similarity: ({avg_train_sim:.4f}, {avg_test_sim:.4f})"
                  f"| I-Loss: ({avg_train_i_loss:.4f}, {avg_test_i_loss:.4f})"
                  f"| A-Loss: ({avg_train_a_loss:.4f}, {avg_test_a_loss:.4f})"
                  )

            # A NaN mean (empty test loader) never compares as an improvement.
            if avg_test_loss < best_loss:
                save_similarity_heatmaps(test_logits, epoch, avg_test_loss, checkpoint_dir, writer)
                best_loss = avg_test_loss
                best_epoch = epoch + 1
                best_sim = avg_test_sim
                save_checkpoint({
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'best_loss': best_loss,
                    'train_loss': avg_train_loss,
                    'similarity': avg_test_sim,
                    'train_similarity': avg_train_sim,
                    'learning_rate': current_lr,
                    'scheduler_state': scheduler.state_dict() if scheduler is not None else None
                }, checkpoint_dir, epoch)
                print(f">>> Saved new best model at epoch {epoch+1}")

            # The scheduler is optional, matching the checkpoint code above.
            if scheduler is not None:
                scheduler.step()

            with open(csv_path, mode='a', newline='') as f:
                writer_csv = csv.writer(f)
                writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss,
                                     avg_train_sim, avg_test_sim, current_lr,
                                     avg_train_i_loss, avg_test_i_loss, avg_train_a_loss, avg_test_a_loss])
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    print(f"Training completed. Best epoch: {best_epoch}, Best loss: {best_loss:.4f}, Best similarity: {best_sim:.4f}")
    print(f"Training log saved to {csv_path}")
|
|
|
|
# Run configuration.
model_name = "csip_model_openClip_CLAP"
epochs = 500
learning_rate = 1e-5

# AdamW over all model parameters; the cosine schedule spans the whole run.
optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-10,
)

# Launch training; resume=True continues from a checkpoint when one is found
# in the checkpoint directory.
train_model(
    model=csip_model,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    log_dir=f"{model_name}/runs/csip",
    checkpoint_dir=f"{model_name}/checkpoints/csip",
    resume=True,
    epochs=epochs,
)
|
|
|
|
| |
| |
|
|
| |
|
|