# ================================================================== # C S I P # ================================================================== # Author:: ASHISH KUMAR UCHADIYA # Date:: 2024-May-27 # #<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # ================================================================== # I M P O R T S # ================================================================== from __future__ import annotations import warnings warnings.filterwarnings("ignore") import os import io import sys import math import random import collections import collections.abc import re from itertools import repeat from pathlib import Path from typing import Optional, Tuple, Union, List, Dict import csv import copy import numpy as np import pandas as pd from PIL import Image import seaborn as sns import matplotlib.pyplot as plt from tqdm import trange, tqdm import torch import torch.nn.functional as F from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out import torch.utils.checkpoint as checkpoint import torchvision from torchvision.transforms import v2 from torch.utils.tensorboard import SummaryWriter # from tensorboardX import SummaryWriter # os.environ["CUDA_VISIBLE_DEVICES"] = "1" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"Using device: {device}") import torchaudio import torchaudio.transforms as T from torchlibrosa.stft import Spectrogram, LogmelFilterBank from torchlibrosa.augmentation import SpecAugmentation from transformers import AutoModel, AutoTokenizer, logging from huggingface_hub.file_download import hf_hub_download from huggingface_hub.file_download import hf_hub_download from peft import get_peft_config, get_peft_model from transformers import CLIPVisionModel, AutoProcessor # from watermark import watermark # print(watermark( # author='Ashish', # # email='ashish@example.com', # current_date=True, # datename=True, # current_time=True, # iso8601=True, # 
#     timezone=True,
#     updated=True,
#     custom_time=None,
#     python=True,
#     # packages="torch,torchvision,numpy",
#     conda=True,
#     hostname=True,
#     machine=True,
#     watermark=False,
#     iversions=True,
#     gpu=True,
#     globals_=globals()
# ))

# ==================================================================
#                         H T S - A T
# ==================================================================


class HTSATConfig:
    """Static (class-attribute) configuration for HTS-AT training and evaluation.

    Original author: Ke Chen (knutchen@ucsd.edu).
    HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION.
    """

    exp_name = "exp_htsat_pretrain"  # the saved ckpt prefix name of the model
    workspace = "/home/kechen/Research/HTSAT"  # the folder of your code
    dataset_path = "/home/Research/audioset"  # the dataset path
    desed_folder = "/home/Research/DESED"  # the desed file

    dataset_type = "audioset"  # "audioset" "esc-50" "scv2"
    index_type = "full_train"  # only works for audioset
    balanced_data = True  # only works for audioset

    loss_type = "clip_bce"  # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"

    # trained from a checkpoint, or evaluate a single model
    resume_checkpoint = None  # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"

    esc_fold = 0  # just for esc dataset, select the fold you need for evaluation and (+1) validation

    debug = False

    random_seed = 970131  # 19970318 970131 12412 127777 1009 34047
    batch_size = 32 * 4  # batch size per GPU x GPU number, default is 32 x 4 = 128
    learning_rate = 1e-3  # 1e-4 also workable
    max_epoch = 100
    num_workers = 3

    lr_scheduler_epoch = [10, 20, 30]
    lr_rate = [0.02, 0.05, 0.1]

    # these data preparation optimizations do not bring many improvements, so deprecated
    enable_token_label = False  # token label
    class_map_path = "class_hier_map.npy"
    class_filter = None
    retrieval_index = [
        15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238,
        5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, 9840, 11318,
        8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418,
        2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260,
        4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336,
        8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710,
        10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535,
        7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274,
        4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229,
        2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900,
    ]
    token_label_range = [0.2, 0.6]
    enable_time_shift = False  # shift time
    enable_label_enhance = False  # enhance hierarchical label
    enable_repeat_mode = False  # repeat the spectrogram / reshape the spectrogram

    # for model's design
    enable_tscam = True  # enable the token-semantic layer

    # for signal processing
    sample_rate = 32000  # 16000 for scv2, 32000 for audioset and esc-50
    clip_samples = sample_rate * 10  # audio_set 10-sec clip
    window_size = 1024
    hop_size = 320  # 160 for scv2, 320 for audioset and esc-50
    mel_bins = 64
    fmin = 50
    fmax = 14000
    shift_max = int(clip_samples * 0.5)

    # for data collection
    classes_num = 527  # esc: 50 | audioset: 527 | scv2: 35
    patch_size = (25, 4)  # deprecated
    crop_size = None  # int(clip_samples * 0.5) deprecated

    # for htsat hyperparameters
    htsat_window_size = 8
    htsat_spec_size = 256
    htsat_patch_size = 4
    htsat_stride = (4, 4)
    htsat_num_head = [4, 8, 16, 32]
    htsat_dim = 96
    htsat_depth = [2, 2, 6, 2]

    swin_pretrain_path = None  # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"

    # Some deprecated optimizations in the model design; check the model code for details
    htsat_attn_heatmap = False
    htsat_hier_output = False
    htsat_use_max = False

    # for ensemble test
    ensemble_checkpoints = []
    ensemble_strides = []

    # weight average folder
    wa_folder = "/home/version_0/checkpoints/"
    # weight average output filename
    wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"

    esm_model_pathes = [
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt",
    ]

    # for framewise localization
    heatmap_dir = "/home/Research/heatmap_output"
    test_file = "htsat-test-ensemble"
    fl_local = False  # indicate if we need to use this dataset for the framewise detection
    fl_dataset = "/home/Research/desed/desedim_embval.npy"
    fl_class_num = [
        "Speech", "Frying", "Dishes", "Running_water",
        "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
        "Cat", "Dog", "Vacuum_cleaner",
    ]

    # map 527 classes into 10 classes (one index list per entry of fl_class_num)
    fl_audioset_mapping = [
        [0, 1, 2, 3, 4, 5, 6, 7],
        [366, 367, 368],
        [364],
        [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
        [369],
        [382],
        [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
        [81, 82, 83, 84, 85],
        [74, 75, 76, 77, 78, 79],
        [377],
    ]


def _ntuple(n):
    """Return a parser that broadcasts a non-iterable value to an ``n``-tuple.

    Iterable inputs (including strings) are returned unchanged.
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def do_mixup(x, mixup_lambda):
    """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes (1, 3, 5, ...).

    Args:
        x: (batch_size * 2, ...)
        mixup_lambda: (batch_size * 2,)

    Returns:
        out: (batch_size, ...)
    """
    # Transposing moves the batch axis last so the 1-D lambda broadcasts over it.
    out = (x[0::2].transpose(0, -1) * mixup_lambda[0::2]
           + x[1::2].transpose(0, -1) * mixup_lambda[1::2]).transpose(0, -1)
    return out


def interpolate(x, ratio):
    """Interpolate data in time domain.

    This is used to compensate the resolution reduction in downsampling of a CNN.
    Each time step is repeated ``ratio`` times (nearest-neighbour upsampling).

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate

    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample in the batch is zeroed with probability ``drop_prob``. Surviving
    samples are scaled by ``1 / (1 - drop_prob)`` so the expected value is unchanged.
    This is the same as the DropConnect impl used for EfficientNet etc. networks. See
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for why it is
    called 'drop path' rather than 'DropConnect'.

    Args:
        x: input tensor; its first dimension is the batch.
        drop_prob: probability of dropping a sample. ``None`` or ``0`` disables the op.
        training: the op is the identity unless this is True.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # `not drop_prob` also covers None (the DropPath() default), which previously
    # crashed with a TypeError on `1 - None` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around :func:`drop_path` that uses the module's ``training`` flag.
    A ``drop_prob`` of ``None`` or ``0`` makes the module an identity.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    A strided ``Conv2d`` projects ``(B, C, H, W)`` input to ``embed_dim`` channels.
    When ``flatten`` is True the result is returned as ``(B, num_patches, embed_dim)``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None,
                 flatten=True, patch_stride=16):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patch_stride = to_2tuple(patch_stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Symmetric padding so overlapping patches (patch_size > stride) keep the grid size.
        padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride,
                              padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    ``fc1 -> act -> dropout -> fc2 -> dropout``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std^2) truncated to [a, b], without autograd.

    Cut & paste from PyTorch official master until it's in a few official releases - RW.
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Returns:
        The same ``tensor`` object, modified in place.
    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then translate to
        # [2*lower-1, 2*upper-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


# Backward-compatible alias for the helper's previous (find/replace-garbled) name.
_no_gradim_audiorunc_normal_ = _no_grad_trunc_normal_


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialise ``tensor`` in place with variance ``scale / denom``.

    Args:
        tensor: weight tensor with at least 2 dimensions.
        scale: numerator of the target variance.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``; selects ``denom``.
        distribution: ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and failed later with UnboundLocalError on `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """LeCun normal init: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')


# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf


def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W must be divisible by ``window_size``.
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """Inverse of :func:`window_partition`.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # "ij" is the historical default; passing it explicitly silences the
        # deprecation warning emitted by recent PyTorch releases.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
                (0 for allowed pairs, a large negative value for blocked pairs)

        Returns:
            (x, attn): projected features of shape (num_windows*B, N, C) and the
            attention weights (after softmax and attention dropout) of shape
            (num_windows*B, num_heads, N, N).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

    def extra_repr(self):
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class _TokenBatchNorm1d(nn.Module):
    """BatchNorm1d over the channel axis of a (B, L, C) token sequence."""

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # BatchNorm1d expects channels on dim 1, so move C there and back.
        return self.bn(x.transpose(1, 2)).transpose(1, 2)


# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        norm_before_mlp (str, optional): Norm before the MLP, 'ln' (LayerNorm) or 'bn' (BatchNorm1d).
            Default: 'ln'

    ``forward`` returns ``(x, attn)`` where ``attn`` is the window attention map.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.norm_before_mlp == 'ln':
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == 'bn':
            # A registered submodule: the previous lambda built a fresh, untrained
            # BatchNorm1d on the CPU at every call (never learned, broke on CUDA).
            self.norm2 = _TokenBatchNorm1d(dim)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Tokens from different regions of the shifted map must not attend to each other.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, _, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"


class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Concatenates each 2x2 neighbourhood (4*C channels), normalises it, and
    linearly reduces it to 2*C channels, halving H and W.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one
            value for every block or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.

    ``forward`` returns ``(x, attn)``. In eval mode ``attn`` is the mean of the
    blocks' attention maps. In training mode it is the last block's attention map.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 norm_before_mlp='ln'):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Per-block rates may be a list or a tuple (the docstring promised tuple,
        # but tuples used to reach DropPath unindexed and fail on `> 0.`).
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_rates else drop_path,
                                 norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        attns = []
        for blk in self.blocks:
            if self.use_checkpoint:
                # blk returns (x, attn): unpack it here as well. Previously x became a
                # tuple (breaking the next block) and `attn` was never bound.
                x, attn = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training:
            attn = torch.cat(attns, dim=0)
            attn = torch.mean(attn, dim=0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


# The Core of HTSAT class HTSAT_Swin_Transformer(nn.Module): r"""HTSAT based on the Swin Transformer Args: spec_size (int | tuple(int)): Input Spectrogram size. Default 256 patch_size (int | tuple(int)): Patch size. Default: 4 path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 in_chans (int): Number of input image channels. Default: 1 (mono) num_classes (int): Number of classes for classification head. Default: 527 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 8 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory.
Default: False config (module): The configuration Module from config.py (HTSATConfig Class) """ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), in_chans=1, num_classes=527, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs): super(HTSAT_Swin_Transformer, self).__init__() self.config = config self.spec_size = spec_size self.patch_stride = patch_stride self.patch_size = patch_size self.window_size = window_size self.embed_dim = embed_dim self.depths = depths self.ape = ape self.in_chans = in_chans self.num_classes = num_classes self.num_heads = num_heads self.num_layers = len(self.depths) self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) self.drop_rate = drop_rate self.attn_drop_rate = attn_drop_rate self.drop_path_rate = drop_path_rate self.qkv_bias = qkv_bias self.qk_scale = None self.patch_norm = patch_norm self.norm_layer = norm_layer if self.patch_norm else None self.norm_before_mlp = norm_before_mlp self.mlp_ratio = mlp_ratio self.use_checkpoint = use_checkpoint # process mel-spec ; used only once self.freq_ratio = self.spec_size // self.config.mel_bins window = 'hann' center = True pad_mode = 'reflect' ref = 1.0 amin = 1e-10 top_db = None self.interpolate_ratio = 32 # Downsampled ratio # Spectrogram extractor self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, freeze_parameters=True) # Logmel feature extractor self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, freeze_parameters=True) # Spec augmenter self.spec_augmenter = 
SpecAugmentation(time_drop_width=64, time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2) # 2 2 self.bn0 = nn.BatchNorm2d(self.config.mel_bins) # split spctrogram into non-overlapping patches self.patch_embed = PatchEmbed( img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.grid_size self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=self.drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=self.depths[i_layer], num_heads=self.num_heads[i_layer], window_size=self.window_size, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], norm_layer=self.norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, norm_before_mlp=self.norm_before_mlp) self.layers.append(layer) self.norm = self.norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.maxpool = nn.AdaptiveMaxPool1d(1) if self.config.enable_tscam: SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio self.tscam_conv = nn.Conv2d( in_channels = self.num_features, out_channels = self.num_classes, kernel_size = (SF,3), padding = (0,1) ) 
self.head = nn.Linear(num_classes, num_classes) else: self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def forward_features(self, x): frames_num = x.shape[2] x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for i, layer in enumerate(self.layers): x, attn = layer(x) if self.config.enable_tscam: # for x x = self.norm(x) B, N, C = x.shape SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST) B, C, F, T = x.shape # group 2D CNN c_freq_bin = F // self.freq_ratio x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) # get latent_output latent_output = self.avgpool(torch.flatten(x,2)) latent_output = torch.flatten(latent_output, 1) # display the attention map, if needed if self.config.htsat_attn_heatmap: # for attn attn = torch.mean(attn, dim = 1) attn = torch.mean(attn, dim = 1) attn = attn.reshape(B, SF, ST) c_freq_bin = SF // self.freq_ratio attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1) attn = attn.mean(dim = 1) attn_max = torch.max(attn, dim = 1, keepdim = True)[0] attn_min = torch.min(attn, dim = 1, keepdim = True)[0] attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min) attn = attn.unsqueeze(dim = 2) x = 
self.tscam_conv(x) x = torch.flatten(x, 2) # B, C, T if self.config.htsat_attn_heatmap: fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) else: fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) x = self.avgpool(x) x = torch.flatten(x, 1) if self.config.loss_type == "clip_ce": output_dict = { 'framewise_output': fpx, # already sigmoided 'clipwise_output': x, 'latent_output': latent_output } else: output_dict = { 'framewise_output': fpx, # already sigmoided 'clipwise_output': torch.sigmoid(x), 'latent_output': latent_output } else: x = self.norm(x) # B N C B, N, C = x.shape fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) ) B, C, F, T = fpx.shape c_freq_bin = F // self.freq_ratio fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T) fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) fpx = torch.sum(fpx, dim = 2) fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) if self.num_classes > 0: x = self.head(x) fpx = self.head(fpx) output_dict = {'framewise_output': torch.sigmoid(fpx), 'clipwise_output': torch.sigmoid(x)} return output_dict def crop_wav(self, x, crop_size, spe_pos = None): time_steps = x.shape[2] tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) for i in range(len(x)): if spe_pos is None: crop_pos = random.randint(0, time_steps - crop_size - 1) else: crop_pos = spe_pos tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:] return tx # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model def reshape_wav2img(self, x): B, C, T, F = x.shape target_T = int(self.spec_size * self.freq_ratio) target_F = self.spec_size // self.freq_ratio assert T <= target_T and F <= target_F, "the wav size should less than or equal 
to the swin input size" # to avoid bicubic zero error if T < target_T: x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) if F < target_F: x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) x = x.permute(0,1,3,2).contiguous() x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio) # print(x.shape) x = x.permute(0,1,3,2,4).contiguous() x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) return x # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model def repeat_wat2img(self, x, cur_pos): B, C, T, F = x.shape target_T = int(self.spec_size * self.freq_ratio) target_F = self.spec_size // self.freq_ratio assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" # to avoid bicubic zero error if T < target_T: x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) if F < target_F: x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) x = x.permute(0,1,3,2).contiguous() # B C F T x = x[:,:,:,cur_pos:cur_pos + self.spec_size] x = x.repeat(repeats = (1,1,4,1)) return x def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None): x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) x = x.transpose(1, 3) x = self.bn0(x) x = x.transpose(1, 3) if self.training: x = self.spec_augmenter(x) if self.training and mixup_lambda is not None: x = do_mixup(x, mixup_lambda) if infer_mode: # in infer mode. 
we need to handle different length audio input frame_num = x.shape[2] target_T = int(self.spec_size * self.freq_ratio) repeat_ratio = math.floor(target_T / frame_num) x = x.repeat(repeats=(1,1,repeat_ratio,1)) x = self.reshape_wav2img(x) output_dict = self.forward_features(x) elif self.config.enable_repeat_mode: if self.training: cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1) x = self.repeat_wat2img(x, cur_pos) output_dict = self.forward_features(x) else: output_dicts = [] for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size): tx = x.clone() tx = self.repeat_wat2img(tx, cur_pos) output_dicts.append(self.forward_features(tx)) clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) for d in output_dicts: clipwise_output += d["clipwise_output"] framewise_output += d["framewise_output"] clipwise_output = clipwise_output / len(output_dicts) framewise_output = framewise_output / len(output_dicts) output_dict = { 'framewise_output': framewise_output, 'clipwise_output': clipwise_output } else: if x.shape[2] > self.freq_ratio * self.spec_size: if self.training: x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) x = self.reshape_wav2img(x) output_dict = self.forward_features(x) else: # Change: Hard code here overlap_size = 344 #(x.shape[2] - 1) // 4 output_dicts = [] crop_size = 689 #(x.shape[2] - 1) // 2 for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) tx = self.reshape_wav2img(tx) output_dicts.append(self.forward_features(tx)) clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device) 
for d in output_dicts: clipwise_output += d["clipwise_output"] framewise_output += d["framewise_output"] latent_output += d["latent_output"] clipwise_output = clipwise_output / len(output_dicts) framewise_output = framewise_output / len(output_dicts) latent_output = latent_output / len(output_dicts) output_dict = { 'framewise_output': framewise_output, 'clipwise_output': clipwise_output, 'latent_output': latent_output, } else: # this part is typically used, and most easy one x = self.reshape_wav2img(x) output_dict = self.forward_features(x) # x = self.head(x) return output_dict class HTSATWrapper(nn.Module): def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, out_emb): super().__init__() # print("parameters are being overidden when using HTSAT") # print("HTSAT only support loading a pretrained model on AudioSet") # @TODO later look at what parameters are same and can be merged self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig()) def forward(self, x): out_dict = self.htsat(x) out_dict['embedding'] = out_dict['latent_output'] return out_dict def get_audio_encoder(name: str): if name == "HTSAT": return HTSATWrapper else: raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) class Projection(nn.Module): def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None: super().__init__() self.linear1 = nn.Linear(dim_imgn, d_out, bias=False) self.linear2 = nn.Linear(d_out, d_out, bias=False) self.layer_norm = nn.LayerNorm(d_out) self.drop = nn.Dropout(p) def forward(self, x: torch.Tensor) -> torch.Tensor: embed1 = self.linear1(x) embed2 = self.drop(self.linear2(F.gelu(embed1))) embeds = self.layer_norm(embed1 + embed2) return embeds class AudioEncoder(nn.Module): def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int, hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: super().__init__() audio_encoder = 
get_audio_encoder(audioenc_name) self.base = audio_encoder( sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, dim_imgn) self.projection = Projection(dim_imgn, d_out) def forward(self, x): out_dict = self.base(x) audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] projected_vec = self.projection(audio_features) return projected_vec, audio_classification_output class TextEncoder(nn.Module): def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: super().__init__() self.text_model = text_model self.base = AutoModel.from_pretrained(text_model) if 'clip' in text_model: self.clip_text_projection = self.base.text_projection self.base = self.base.text_model if 'base' in text_model: transformer_embed_dim = 512 self.projection = Projection(transformer_embed_dim, d_out) def forward(self, x): if 'clip' in self.text_model: pooled_output = self.base(**x)[1] # get pooled output out = self.clip_text_projection(pooled_output) # get CLS token output elif 'gpt' in self.text_model: batch_size = x['input_ids'].shape[0] hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768) sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17]) out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768] else: out = self.base(**x)[0] out = out[:, 0, :] # get CLS token output projected_vec = self.projection(out) return projected_vec class CLAP(nn.Module): def __init__(self, # audio audioenc_name: str, sample_rate: int, window_size: int, hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int, out_emb: int, # text text_model: str, transformer_embed_dim: int, # common d_proj: int, ): super().__init__() self.audio_encoder = AudioEncoder( audioenc_name, out_emb, d_proj, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) self.caption_encoder = TextEncoder( d_proj, text_model, 
transformer_embed_dim ) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def forward(self, audio, text): audio_embed, _ = self.audio_encoder(audio) caption_embed = self.caption_encoder(text) return caption_embed, audio_embed, self.logit_scale.exp() # ================================================================== # A U D I O - P R E - P R O C E S S I N G # ================================================================== def read_audio(audio_path, resample=True): r"""Loads audio file or array and returns a torch tensor""" # Randomly sample a segment of audio_duration from the clip or pad to match duration audio_time_series, sample_rate = torchaudio.load(audio_path) resample_rate = clapConfig.sample_rate if resample and resample_rate != sample_rate: resampler = T.Resample(sample_rate, resample_rate) audio_time_series = resampler(audio_time_series) return audio_time_series, resample_rate def load_audio_into_tensor(audio_path, audio_duration, resample=False): r"""Loads audio file and returns raw audio.""" # Randomly sample a segment of audio_duration from the clip or pad to match duration audio_time_series, sample_rate = read_audio(audio_path, resample) audio_time_series = audio_time_series.reshape(-1) # audio_time_series is shorter than predefined audio duration, # so audio_time_series is extended if audio_duration*sample_rate >= audio_time_series.shape[0]: repeat_factor = int(np.ceil((audio_duration*sample_rate) / audio_time_series.shape[0])) # Repeat audio_time_series by repeat_factor to match audio_duration audio_time_series = audio_time_series.repeat(repeat_factor) # remove excess part of audio_time_series audio_time_series = audio_time_series[0:audio_duration*sample_rate] else: # audio_time_series is longer than predefined audio duration, # so audio_time_series is trimmed start_index = random.randrange( audio_time_series.shape[0] - audio_duration*sample_rate) audio_time_series = audio_time_series[start_index:start_index + 
audio_duration*sample_rate] return torch.FloatTensor(audio_time_series) np_str_obj_array_pattern = re.compile(r'[SaUO]') default_collate_err_msg_format = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}") def default_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" elem = batch[0] elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) return torch.stack(batch, 0, out=out) elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError( default_collate_err_msg_format.format(elem.dtype)) return default_collate([torch.as_tensor(b) for b in batch]) elif elem.shape == (): # scalars return torch.as_tensor(batch) elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float64) elif isinstance(elem, int): return torch.tensor(batch) elif isinstance(elem, str): return batch elif isinstance(elem, collections.abc.Mapping): return {key: default_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple return elem_type(*(default_collate(samples) for samples in zip(*batch))) elif isinstance(elem, collections.abc.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError( 'each element in list of batch should be of equal size') transposed = zip(*batch) return 
[default_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type)) def preprocess_audio(audio_files, resample): r"""Load list of audio files and return raw audio""" audio_tensors = [] for audio_file in audio_files: audio_tensor = load_audio_into_tensor( audio_file, clapConfig.duration, resample) audio_tensor = audio_tensor.reshape(1, -1) audio_tensors.append(audio_tensor) return default_collate(audio_tensors) # ================================================================== # A U D I O - E M B E D D I N G S - H E L P E R # ================================================================== def CLAPAudioProcessor(audio_files: List[str], resample=True): preprocessed_audio = preprocess_audio(audio_files, resample) preprocessed_audio = preprocessed_audio.reshape( preprocessed_audio.shape[0], preprocessed_audio.shape[2]) preprocessed_audio = preprocessed_audio return preprocessed_audio def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True): """Load list of audio files and return audio embeddings""" # preprocessed_audio = preprocess_audio(audio_files, resample) # with torch.no_grad(): # preprocessed_audio = preprocessed_audio.reshape( # preprocessed_audio.shape[0], preprocessed_audio.shape[2]) with torch.no_grad(): preprocessed_audio = CLAPAudioProcessor(audio_files, resample) return audio_encoder(preprocessed_audio)[0] # ================================================================== # C L A P # ================================================================== class ClapConfig: # TEXT ENCODER CONFIG text_model = 'gpt2' text_len = 77 transformer_embed_dim = 768 freeze_text_encoder_weights = True # AUDIO ENCODER CONFIG audioenc_name = 'HTSAT' out_emb = 768 sample_rate = 44100 duration = 7 fmin = 50 fmax = 8000 # 14000 n_fft = 1024 # 1028 hop_size = 320 mel_bins = 64 window_size = 1024 # PROJECTION SPACE CONFIG d_proj = 1024 temperature = 0.003 # TRAINING AND EVALUATION CONFIG num_classes = 
527 batch_size = 1024 demo = False clapConfig = ClapConfig() clap = CLAP( audioenc_name=clapConfig.audioenc_name, sample_rate=clapConfig.sample_rate, window_size=clapConfig.window_size, hop_size=clapConfig.hop_size, mel_bins=clapConfig.mel_bins, fmin=clapConfig.fmin, fmax=clapConfig.fmax, classes_num=clapConfig.num_classes, out_emb=clapConfig.out_emb, text_model=clapConfig.text_model, transformer_embed_dim=clapConfig.transformer_embed_dim, d_proj=clapConfig.d_proj ) model_repo = "microsoft/msclap" model_name = { '2022': 'CLAP_weights_2022.pth', '2023': 'CLAP_weights_2023.pth', 'clapcap': 'clapcap_weights_2023.pth' } version = '2023' model_fp = hf_hub_download(model_repo, model_name[version]) model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model'] clap.load_state_dict(model_state_dict, strict=False) # clap.eval() clap_audio_encoder = clap.audio_encoder.to(device) # ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English" # audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")] # audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder) # print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024] # ================================================================== # C L A P - L o R A - M O D E L # ================================================================== LoRAconfig = { "peft_type": "LORA", "task_type": "FEATURE_EXTRACTION", "inference_mode": False, "r": 16, "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"], "lora_alpha": 32, "lora_dropout": 0.05, "fan_in_fan_out": False, "bias": "all", } peft_config = get_peft_config(LoRAconfig) peft_model = get_peft_model(clap_audio_encoder, peft_config) # peft_model.print_trainable_parameters() peft_clap_audio_encoder = peft_model.base_model # audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder) # print("CLAP LoRA Audio Encoder Embeddings:", 
audio_embedding.shape) # [5, 1024] # ================================================================== # O P E N - C L I P - M O D E L # ================================================================== # ================================================================== # I M P O R T S # ================================================================== import os import io import sys import math import random import collections import collections.abc import re from itertools import repeat from pathlib import Path from typing import Optional, Tuple, Union, List, Dict import csv import numpy as np import pandas as pd from PIL import Image import seaborn as sns import matplotlib.pyplot as plt from tqdm import trange, tqdm import torch import torch.nn.functional as F from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out import torch.utils.checkpoint as checkpoint import torchvision from torchvision.transforms import v2 # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"Using device: {device}") import torchaudio import torchaudio.transforms as T from torchlibrosa.stft import Spectrogram, LogmelFilterBank from torchlibrosa.augmentation import SpecAugmentation from transformers import AutoModel, AutoTokenizer, logging from huggingface_hub.file_download import hf_hub_download from huggingface_hub.file_download import hf_hub_download from peft import get_peft_config, get_peft_model from typing import Any, Dict, Optional, Tuple, Union import numbers import random import warnings from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torchvision.transforms.functional as F from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop, ColorJitter, Grayscale OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = 
(0.26862954, 0.26130258, 0.27577711) IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) INCEPTION_MEAN = (0.5, 0.5, 0.5) INCEPTION_STD = (0.5, 0.5, 0.5) # Default name for a weights file hosted on the Huggingface Hub. HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version HF_CONFIG_NAME = 'open_clip_config.json' import collections.abc from itertools import repeat from typing import List, Optional, Tuple, Union import torch from torch import nn as nn from torch import _assert from torchvision.ops.misc import FrozenBatchNorm2d def freeze_batch_norm_2d(module, module_match={}, name=''): """ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and returned. Otherwise, the module is walked recursively and submodules are converted in place. Args: module (torch.nn.Module): Any PyTorch module. 
module_match (dict): Dictionary of full module names to freeze (all if empty) name (str): Full module name (prefix) Returns: torch.nn.Module: Resulting module Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module is_match = True if module_match: is_match = name in module_match if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): res = FrozenBatchNorm2d(module.num_features) res.num_features = module.num_features res.affine = module.affine if module.affine: res.weight.data = module.weight.data.clone().detach() res.bias.data = module.bias.data.clone().detach() res.running_mean.data = module.running_mean.data res.running_var.data = module.running_var.data res.eps = module.eps else: for child_name, child in module.named_children(): full_child_name = '.'.join([name, child_name]) if name else child_name new_child = freeze_batch_norm_2d(child, module_match, full_child_name) if new_child is not child: res.add_module(child_name, new_child) return res # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = lambda n, x: _ntuple(n)(x) # Replaces all linear layers with linear_replacement # TODO: add int8 support for other linear layers including attn and convnets def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): for name, module in model.named_children(): if len(list(module.children())) > 0: replace_linear(module, linear_replacement, include_modules, copy_weights) if isinstance(module, torch.nn.Linear) and name in include_modules: old_module = model._modules[name] model._modules[name] = linear_replacement( module.in_features, module.out_features, module.bias is not None, ) if copy_weights: 
model._modules[name].weight.data.copy_(old_module.weight.data) if model._modules[name].bias is not None: model._modules[name].bias.data.copy_(old_module.bias) return model def convert_int8_model_to_inference_mode(model): for m in model.modules(): if hasattr(m, 'prepare_for_eval'): int8_original_dtype = m.weight.dtype m.prepare_for_eval() m.int8_original_dtype = int8_original_dtype def feature_take_indices( num_features: int, indices: Optional[Union[int, List[int]]] = None, as_set: bool = False, ) -> Tuple[List[int], int]: """ Determine the absolute feature indices to 'take' from. Note: This function can be called in forward() so must be torchscript compatible, which requires some incomplete typing and workaround hacks. Args: num_features: total number of features to select from indices: indices to select, None -> select all int -> select last n list/tuple of int -> return specified (-ve indices specify from end) as_set: return as a set Returns: List (or set) of absolute (from beginning) indices, Maximum index """ if indices is None: indices = num_features # all features if None if isinstance(indices, int): # convert int -> last n indices _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') take_indices = [num_features - indices + i for i in range(indices)] else: take_indices: List[int] = [] for i in indices: idx = num_features + i if i < 0 else i _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') take_indices.append(idx) if not torch.jit.is_scripting() and as_set: return set(take_indices), max(take_indices) return take_indices, max(take_indices) def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: if isinstance(x, int): # if indices is an int, take last N features return tuple(range(-x, 0)) return tuple(x) import copy import copy import hashlib import os import urllib import warnings from functools import partial from typing import Dict, Iterable, 
Optional, Union from tqdm import tqdm try: import safetensors.torch _has_safetensors = True except ImportError: _has_safetensors = False __version__ = '2.32.0' """ CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import copy import logging import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from functools import partial # from .hf_model import HFTextEncoder # from .modified_resnet import ModifiedResNet from collections import OrderedDict import math from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import nn from torch.nn import functional as F from torch.utils.checkpoint import checkpoint # from .utils import to_2tuple, feature_take_indices # from .pos_embed import get_2d_sincos_pos_embed # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# -------------------------------------------------------- # Position embedding utils # -------------------------------------------------------- import numpy as np import torch # -------------------------------------------------------- # 2D sine-cosine position embedding # References: # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py # MoCo v3: https://github.com/facebookresearch/moco-v3 # -------------------------------------------------------- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) omega /= embed_dim / 2. omega = 1. 
/ 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb # -------------------------------------------------------- # Interpolate position embeddings for high-resolution # References: # DeiT: https://github.com/facebookresearch/deit # -------------------------------------------------------- def interpolate_pos_embed(model, checkpoint_model): if 'pos_embed' in checkpoint_model: pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed from collections import OrderedDict from typing import Dict, List, Optional, Union import torch from torch import nn from torch.nn import functional as F # from .utils import freeze_batch_norm_2d, feature_take_indices class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, 
planes, stride=1): super().__init__() # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.act2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.act3 = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential(OrderedDict([ ("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion)) ])) def forward(self, x: torch.Tensor): identity = x out = self.act1(self.bn1(self.conv1(x))) out = self.act2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.act3(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + 
self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0., out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__( self, layers: List[int], output_dim: int, heads: int, image_size: int = 224, width: int = 64, ): super().__init__() self.output_dim = output_dim self.image_size = image_size # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.act2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.act3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = 
self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) self.init_parameters() def _make_layer(self, planes, blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def init_parameters(self): if self.attnpool is not None: std = self.attnpool.c_proj.in_features ** -0.5 nn.init.normal_(self.attnpool.q_proj.weight, std=std) nn.init.normal_(self.attnpool.k_proj.weight, std=std) nn.init.normal_(self.attnpool.v_proj.weight, std=std) nn.init.normal_(self.attnpool.c_proj.weight, std=std) for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) def lock(self, unlocked_groups=0, freeze_bn_stats=False): assert unlocked_groups == 0, 'partial locking not currently supported for this model' for param in self.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): # FIXME support for non-transformer pass def stem(self, x): x = self.act1(self.bn1(self.conv1(x))) x = self.act2(self.bn2(self.conv2(x))) x = self.act3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x def forward_intermediates( self, x: torch.Tensor, indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize_intermediates: bool = False, intermediates_only: bool = False, output_fmt: str = 'NCHW', output_extra_tokens: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. 
Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize_intermediates: Apply final norm layer to all intermediates intermediates_only: Only return intermediate features output_fmt: Shape of intermediate feature outputs output_extra_tokens: Return both extra class, eot tokens Returns: """ assert output_fmt in ('NCHW',), 'Output format must be == NCHW.' # NOTE normalize_intermediates and return_extra_tokens don't apply take_indices, max_index = feature_take_indices(5, indices) output = {} intermediates = [] blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4] if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = blocks[:max_index + 1] for i, blk in enumerate(blocks): x = blk(x) if i in take_indices: intermediates.append(x) output['image_intermediates'] = intermediates if intermediates_only: return output x = self.attnpool(x) output['image_features'] = x return output def forward(self, x): x = self.stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) return x """ huggingface model adapter Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
""" import re import torch import torch.nn as nn from torch import TensorType try: import transformers from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ BaseModelOutputWithPoolingAndCrossAttentions except ImportError as e: transformers = None class BaseModelOutput: pass class PretrainedConfig: pass # from .hf_configs import arch_dict # HF architecture dict: arch_dict = { # https://huggingface.co/docs/transformers/model_doc/roberta#roberta "roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig "xlm-roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 "mt5": { "config_names": { # unlimited seqlen # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 "context_length": "", "vocab_size": "vocab_size", "width": "d_model", "heads": "num_heads", "layers": "num_layers", "layer_attr": "block", "token_embeddings_attr": "embed_tokens" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/bert "bert": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", }, "pooler": 
"cls_pooler", }, # https://huggingface.co/docs/transformers/model_doc/m2m_100 "m2m_100": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "d_model", "heads": "encoder_attention_heads", "layers": "encoder_layers", }, "pooler": "cls_pooler", }, } # utils def _camel2snake(s): return re.sub(r'(? Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize_intermediates: Apply norm layer to all intermediates intermediates_only: Only return intermediate features output_fmt: Shape of intermediate feature outputs output_extra_tokens: Return both prefix and spatial intermediate tokens Returns: """ extra_args = {} if output_extra_tokens: extra_args['return_prefix_tokens'] = True trunk_output = self.trunk.forward_intermediates( x, indices=indices, intermediates_only=intermediates_only, norm=normalize_intermediates, stop_early=stop_early, output_fmt=output_fmt, **extra_args, ) return_dict = {} intermediates = trunk_output if intermediates_only else trunk_output[1] if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple): intermediates_prefix = [xi[1] for xi in intermediates] intermediates = [xi[0] for xi in intermediates] return_dict['image_intermediates_prefix'] = intermediates_prefix return_dict['image_intermediates'] = intermediates if intermediates_only: return return_dict image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection image_features = self.head(image_features) # run through adapter pooling / projection return_dict['image_features'] = image_features return return_dict def forward(self, x): x = self.trunk(x) x = self.head(x) return x class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by 
casting to float32 and back).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm (with cast back to input dtype).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class QuickGELU(nn.Module): # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 """ def __init__( self, prob: float = 0.5, exclude_first_token: bool = True ): super().__init__() assert 0 <= prob < 1. self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token def forward(self, x): if not self.training or self.prob == 0.: return x if self.exclude_first_token: cls_tokens, x = x[:, :1], x[:, 1:] else: cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) batch = x.size()[0] num_tokens = x.size()[1] batch_indices = torch.arange(batch) batch_indices = batch_indices[..., None] keep_prob = 1 - self.prob num_patches_keep = max(1, int(num_tokens * keep_prob)) rand = torch.randn(batch, num_tokens) patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices x = x[batch_indices, patch_indices_keep] if self.exclude_first_token: x = torch.cat((cls_tokens, x), dim=1) return x class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, scaled_cosine: bool = False, scale_heads: bool = False, logit_scale_max: float = math.log(1. 
/ 0.01), batch_first: bool = True, attn_drop: float = 0., proj_drop: float = 0. ): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max self.batch_first = batch_first self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention') # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) if qkv_bias: self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) else: self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) else: self.logit_scale = None self.attn_drop = nn.Dropout(attn_drop) if self.scale_heads: self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) else: self.head_scale = None self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) def forward(self, x, attn_mask: Optional[torch.Tensor] = None): if self.batch_first: x = x.transpose(0, 1) L, N, C = x.shape q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1) k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1) v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1) if attn_mask is not None and attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = new_attn_mask if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(N, self.num_heads, L, L) * logit_scale attn = attn.view(-1, L, L) if attn_mask is not None: attn = attn + attn_mask attn = 
attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) else: if self.use_fsdpa: x = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale attn = torch.bmm(q, k.transpose(-1, -2)) if attn_mask is not None: attn += attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(N, self.num_heads, L, C) * self.head_scale x = x.view(-1, L, C) x = x.transpose(0, 1).reshape(L, N, C) if self.batch_first: x = x.transpose(0, 1) x = self.out_proj(x) x = self.out_drop(x) return x class AttentionalPooler(nn.Module): def __init__( self, d_model: int, context_dim: int, n_head: int = 8, n_queries: int = 256, norm_layer: Callable = LayerNorm, ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True) self.ln_q = norm_layer(d_model) self.ln_k = norm_layer(context_dim) def forward(self, x: torch.Tensor): N = x.shape[0] x = self.ln_k(x) q = self.ln_q(self.query) out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0] return out class ResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, is_cross_attention: bool = False, batch_first: bool = True, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() if is_cross_attention: self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, 
d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def attention( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] def forward( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x class CustomResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, scale_cosine_attn: bool = False, scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, batch_first: bool = True, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, batch_first=batch_first, ) self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not 
None else nn.Identity() def get_reference_weight(self): return self.mlp.c_fc.weight def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x class CustomTransformer(nn.Module): """ A custom transformer that can use different block types. """ def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, batch_first: bool = True, block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock', ): super().__init__() self.width = width self.layers = layers self.batch_first = batch_first # run transformer stack in batch first (N, L, D) self.grad_checkpointing = False if isinstance(block_types, str): block_types = [block_types] * layers assert len(block_types) == layers def _create_block(bt: str): if bt == 'CustomResidualAttentionBlock': return CustomResidualAttentionBlock( width, heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, batch_first=batch_first, ) else: assert False self.resblocks = nn.ModuleList([ _create_block(bt) for bt in block_types ]) def get_cast_dtype(self) -> torch.dtype: weight = self.resblocks[0].get_reference_weight() if hasattr(weight, 'int8_original_dtype'): return weight.int8_original_dtype return weight.dtype def forward_intermediates( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, ): take_indices, max_index = feature_take_indices(len(self.resblocks), indices) if not self.batch_first: x = x.transpose(0, 1).contiguous() # NLD -> LND intermediates = [] if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.resblocks else: blocks = self.resblocks[:max_index + 1] for i, blk in enumerate(blocks): if 
self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False) else: x = blk(x, attn_mask=attn_mask) if i in take_indices: intermediates.append(x.transpose(0, 1) if not self.batch_first else x) if not self.batch_first: x = x.transpose(0, 1) # LND -> NLD return x, intermediates def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1): """ Prune layers not required for specified intermediates. """ take_indices, max_index = feature_take_indices(len(self.resblocks), indices) self.resblocks = self.resblocks[:max_index + 1] # truncate blocks return take_indices def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): if not self.batch_first: x = x.transpose(0, 1) # NLD -> LND for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) else: x = r(x, attn_mask=attn_mask) if not self.batch_first: x = x.transpose(0, 1) # NLD -> LND return x class Transformer(nn.Module): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, batch_first: bool = True, ): super().__init__() self.width = width self.layers = layers self.batch_first = batch_first self.grad_checkpointing = False self.resblocks = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, batch_first=batch_first, ) for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): return self.resblocks[0].mlp.c_fc.int8_original_dtype return self.resblocks[0].mlp.c_fc.weight.dtype def forward_intermediates( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, indices: 
Optional[Union[int, List[int]]] = None, stop_early: bool = False, ): take_indices, max_index = feature_take_indices(len(self.resblocks), indices) if not self.batch_first: x = x.transpose(0, 1).contiguous() # NLD -> LND intermediates = [] if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.resblocks else: blocks = self.resblocks[:max_index + 1] for i, blk in enumerate(blocks): if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False) else: x = blk(x, attn_mask=attn_mask) if i in take_indices: intermediates.append(x.transpose(0, 1) if not self.batch_first else x) if not self.batch_first: x = x.transpose(0, 1) # LND -> NLD return x, intermediates def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1): """ Prune layers not required for specified intermediates. """ take_indices, max_index = feature_take_indices(len(self.resblocks), indices) self.resblocks = self.resblocks[:max_index + 1] # truncate blocks return take_indices def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): if not self.batch_first: x = x.transpose(0, 1).contiguous() # NLD -> LND for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) else: x = r(x, attn_mask=attn_mask) if not self.batch_first: x = x.transpose(0, 1) # LND -> NLD return x def _expand_token(token, batch_size: int): return token.view(1, 1, -1).expand(batch_size, -1, -1) class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, ls_init_value: float = None, attentional_pool: bool = False, attn_pooler_queries: int = 256, attn_pooler_heads: int = 8, output_dim: int = 512, 
patch_dropout: float = 0., no_ln_pre: bool = False, pos_embed_type: str = 'learnable', pool_type: str = 'tok', final_ln_after_pool: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): super().__init__() assert pool_type in ('tok', 'avg', 'none') self.output_tokens = output_tokens image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled self.output_dim = output_dim self.conv1 = nn.Conv2d( in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False, ) # class embeddings and positional embeddings scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) if pos_embed_type == 'learnable': self.positional_embedding = nn.Parameter( scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) elif pos_embed_type == 'sin_cos_2d': # fixed sin-cos embedding assert self.grid_size[0] == self.grid_size[1],\ 'currently sin cos 2d pos embedding only supports square input' self.positional_embedding = nn.Parameter( torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) else: raise ValueError # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) self.transformer = Transformer( width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) if attentional_pool: if isinstance(attentional_pool, str): self.attn_pool_type = attentional_pool self.pool_type = 'none' if attentional_pool in ('parallel', 'cascade'): self.attn_pool = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=attn_pooler_queries, ) self.attn_pool_contrastive = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=1, ) else: assert False else: self.attn_pool_type = '' self.pool_type = pool_type self.attn_pool = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=attn_pooler_queries, ) self.attn_pool_contrastive = None pool_dim = output_dim else: self.attn_pool = None pool_dim = width self.pool_type = pool_type self.ln_post = norm_layer(pool_dim) self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) self.init_parameters() def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False): for param in self.parameters(): param.requires_grad = False if unlocked_groups != 0: groups = [ [ self.conv1, self.class_embedding, self.positional_embedding, self.ln_pre, ], *self.transformer.resblocks[:-1], [ self.transformer.resblocks[-1], self.ln_post, ], self.proj, ] def _unlock(x): if isinstance(x, Sequence): for g in x: _unlock(g) else: if isinstance(x, torch.nn.Parameter): x.requires_grad = True else: for p in x.parameters(): p.requires_grad = True _unlock(groups[-unlocked_groups:]) def init_parameters(self): # FIXME OpenAI CLIP did not define an init for the VisualTransformer # TODO experiment if default PyTorch init, below, or alternate init is best. 
# nn.init.normal_(self.class_embedding, std=self.scale) # nn.init.normal_(self.positional_embedding, std=self.scale) # # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) # attn_std = self.transformer.width ** -0.5 # fc_std = (2 * self.transformer.width) ** -0.5 # for block in self.transformer.resblocks: # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) # # if self.text_projection is not None: # nn.init.normal_(self.text_projection, std=self.scale) pass @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True): self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = {'positional_embedding', 'class_embedding'} return no_wd def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.pool_type == 'avg': pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] elif self.pool_type == 'tok': pooled, tokens = x[:, 0], x[:, 1:] else: pooled = tokens = x return pooled, tokens def _embeds(self, x:torch.Tensor) -> torch.Tensor: x = self.conv1(x) # shape = [*, dim, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] # class embeddings and positional embeddings x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) # patch dropout (if active) x = self.patch_dropout(x) # apply norm before transformer x = self.ln_pre(x) return x def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.attn_pool is not None: if self.attn_pool_contrastive is not None: # This is untested, 
WIP pooling that should match paper x = self.ln_post(x) # TBD LN first or separate one after each pool? tokens = self.attn_pool(x) if self.attn_pool_type == 'parallel': pooled = self.attn_pool_contrastive(x) else: assert self.attn_pool_type == 'cascade' pooled = self.attn_pool_contrastive(tokens) else: # this is the original OpenCLIP CoCa setup, does not match paper x = self.attn_pool(x) x = self.ln_post(x) pooled, tokens = self._global_pool(x) elif self.final_ln_after_pool: pooled, tokens = self._global_pool(x) pooled = self.ln_post(pooled) else: x = self.ln_post(x) pooled, tokens = self._global_pool(x) return pooled, tokens def forward_intermediates( self, x: torch.Tensor, indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize_intermediates: bool = False, intermediates_only: bool = False, output_fmt: str = 'NCHW', output_extra_tokens: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit intermediates_only: Only return intermediate features normalize_intermediates: Apply final norm layer to all intermediates output_fmt: Shape of intermediate feature outputs output_extra_tokens: Return both extra prefix class tokens Returns: """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' # forward pass B, _, height, width = x.shape x = self._embeds(x) x, intermediates = self.transformer.forward_intermediates( x, indices=indices, stop_early=stop_early, ) # process intermediates if normalize_intermediates: # apply final norm to all intermediates intermediates = [self.ln_post(xi) for xi in intermediates] num_prefix_tokens = 1 # one class token that's always there (as of now) if num_prefix_tokens: # split prefix (e.g. 
class, distill) and spatial feature tokens prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates] intermediates = [y[:, num_prefix_tokens:] for y in intermediates] else: prefix_tokens = None if reshape: # reshape to BCHW output format H, W = height // self.patch_size[0], width // self.patch_size[1] intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] output = {'image_intermediates': intermediates} if prefix_tokens is not None and output_extra_tokens: output['image_intermediates_prefix'] = prefix_tokens if intermediates_only: return output pooled, _ = self._pool(x) if self.proj is not None: pooled = pooled @ self.proj output['image_features'] = pooled return output def prune_intermediate_layers( self, indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ take_indices = self.transformer.prune_intermediate_layers(indices) if prune_norm: self.ln_post = nn.Identity() if prune_head: self.proj = None return take_indices def forward(self, x: torch.Tensor): x = self._embeds(x) x = self.transformer(x) pooled, tokens = self._pool(x) if self.proj is not None: pooled = pooled @ self.proj if self.output_tokens: return pooled, tokens return pooled def text_global_pool( x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax', ) -> torch.Tensor: if pool_type == 'first': pooled = x[:, 0] elif pool_type == 'last': pooled = x[:, -1] elif pool_type == 'argmax': # take features from the eot embedding (eot_token is the highest number in each sequence) assert text is not None pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] else: pooled = x return pooled class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, context_length: int = 77, vocab_size: int = 49408, width: int = 512, heads: int = 8, layers: int = 12, mlp_ratio: float = 4.0, ls_init_value: float = None, 
output_dim: Optional[int] = 512, embed_cls: bool = False, no_causal_mask: bool = False, pad_id: int = 0, pool_type: str = 'argmax', proj_type: str = 'linear', proj_bias: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): super().__init__() assert pool_type in ('first', 'last', 'argmax', 'none') self.output_tokens = output_tokens self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim self.heads = heads self.pad_id = pad_id self.pool_type = pool_type self.token_embedding = nn.Embedding(vocab_size, width) if embed_cls: self.cls_emb = nn.Parameter(torch.empty(width)) self.num_pos += 1 else: self.cls_emb = None self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, heads=heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.ln_final = norm_layer(width) if no_causal_mask: self.attn_mask = None else: self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) if proj_type == 'none' or not output_dim: self.text_projection = None else: if proj_bias: self.text_projection = nn.Linear(width, output_dim) else: self.text_projection = nn.Parameter(torch.empty(width, output_dim)) self.init_parameters() def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) if self.cls_emb is not None: nn.init.normal_(self.cls_emb, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) if self.text_projection.bias is not None: nn.init.zeros_(self.text_projection.bias) else: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = {'positional_embedding'} if self.cls_emb is not None: no_wd.add('cls_emb') return no_wd def build_causal_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.num_pos, self.num_pos) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def build_cls_mask(self, text, cast_dtype: torch.dtype): cls_mask = (text != self.pad_id).unsqueeze(1) cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) additive_mask.fill_(0) additive_mask.masked_fill_(~cls_mask, float("-inf")) additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) cls_mask = self.build_cls_mask(text, cast_dtype) if attn_mask is not None: attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] x = x + 
self.positional_embedding[:seq_len].to(cast_dtype) return x, attn_mask def forward_intermediates( self, text: torch.Tensor, indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize_intermediates: bool = False, intermediates_only: bool = False, output_fmt: str = 'NCHW', output_extra_tokens: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: text: Input text ids indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize_intermediates: Apply norm layer to all intermediates intermediates_only: Only return intermediate features output_fmt: Shape of intermediate feature outputs output_extra_tokens: Return both prefix and intermediate tokens Returns: """ assert output_fmt in ('NLC',), 'Output format must be NLC.' # forward pass x, attn_mask = self._embeds(text) x, intermediates = self.transformer.forward_intermediates( x, attn_mask=attn_mask, indices=indices, stop_early=stop_early, ) # process intermediates if normalize_intermediates: # apply final norm to all intermediates intermediates = [self.ln_final(xi) for xi in intermediates] output = {} if self.cls_emb is not None: seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence if output_extra_tokens: # return suffix class tokens separately cls_intermediates = [xi[:, -1:] for xi in intermediates] output['text_intermediates_suffix'] = cls_intermediates intermediates = seq_intermediates output['text_intermediates'] = intermediates if intermediates_only: return output if self.cls_emb is not None: # presence of appended cls embed (CoCa) overrides pool_type, always take last token pooled = text_global_pool(x, pool_type='last') pooled = self.ln_final(pooled) # final LN applied after pooling in this case else: x = self.ln_final(x) pooled = text_global_pool(x, text, 
pool_type=self.pool_type) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): pooled = self.text_projection(pooled) else: pooled = pooled @ self.text_projection output['text_features'] = pooled return output def prune_intermediate_layers( self, indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ take_indices = self.transformer.prune_intermediate_layers(indices) if prune_norm: self.ln_final = nn.Identity() if prune_head: self.text_projection = None return take_indices def forward(self, text): x, attn_mask = self._embeds(text) x = self.transformer(x, attn_mask=attn_mask) # x.shape = [batch_size, n_ctx, transformer.width] if self.cls_emb is not None: # presence of appended cls embed (CoCa) overrides pool_type, always take last token pooled = text_global_pool(x, pool_type='last') pooled = self.ln_final(pooled) # final LN applied after pooling in this case tokens = x[:, :-1] else: x = self.ln_final(x) pooled = text_global_pool(x, text, pool_type=self.pool_type) tokens = x if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): pooled = self.text_projection(pooled) else: pooled = pooled @ self.text_projection if self.output_tokens: return pooled, tokens return pooled class MultimodalTransformer(Transformer): def __init__( self, width: int, layers: int, heads: int, context_length: int = 77, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_dim: int = 512, batch_first: bool = True, ): super().__init__( width=width, layers=layers, heads=heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, batch_first=batch_first, ) self.context_length = context_length self.cross_attn = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, 
norm_layer=norm_layer, is_cross_attention=True, batch_first=batch_first, ) for _ in range(layers) ]) self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.ln_final = norm_layer(width) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) def init_parameters(self): proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) for block in self.transformer.cross_attn: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def forward_intermediates( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, ): assert False, "Not currently implemented for MultimodalTransformer w/ xattn" def forward(self, image_embs, text_embs): seq_len = text_embs.shape[1] if not self.batch_first: image_embs = image_embs.permute(1, 0, 2) # NLD -> LND text_embs = text_embs.permute(1, 0, 2) # NLD -> LND for resblock, cross_attn in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not 
torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 text_embs = checkpoint( resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len], use_reentrant=False) text_embs = checkpoint( cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False) else: text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) if not self.batch_first: text_embs = text_embs.permute(1, 0, 2) # LND -> NLD out = self.ln_final(text_embs) if self.text_projection is not None: out = out @ self.text_projection return out @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling no_ln_pre: bool = False # disable pre transformer LayerNorm pos_embed_type: str = 'learnable' final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'tok' output_tokens: bool = False act_kwargs: Optional[dict] = None norm_kwargs: Optional[dict] = None timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. 
# head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth @dataclass class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 hf_tokenizer_name: Optional[str] = None tokenizer_kwargs: Optional[dict] = None width: int = 512 heads: int = 8 layers: int = 12 mlp_ratio: float = 4.0 ls_init_value: Optional[float] = None # layer scale initial value embed_cls: bool = False pad_id: int = 0 no_causal_mask: bool = False # disable causal masking final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'argmax' proj_bias: bool = False proj_type: str = 'linear' # control final text projection, 'none' forces no projection output_tokens: bool = False act_kwargs: dict = None norm_kwargs: dict = None # HuggingFace specific text tower config hf_model_name: Optional[str] = None hf_model_pretrained: bool = True hf_proj_type: str = 'mlp' hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def get_input_dtype(precision: str): input_dtype = None if precision in ('bf16', 'pure_bf16'): input_dtype = torch.bfloat16 elif precision in ('fp16', 'pure_fp16'): input_dtype = torch.float16 return input_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.timm_model_name: visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, embed_dim=embed_dim, image_size=vision_cfg.image_size, ) elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if vision_cfg.norm_kwargs: norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) if vision_cfg.act_kwargs is not None: act_layer = partial(act_layer, **vision_cfg.act_kwargs) visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, width=vision_cfg.width, layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, attentional_pool=vision_cfg.attentional_pool, attn_pooler_queries=vision_cfg.attn_pooler_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, pos_embed_type=vision_cfg.pos_embed_type, no_ln_pre=vision_cfg.no_ln_pre, final_ln_after_pool=vision_cfg.final_ln_after_pool, pool_type=vision_cfg.pool_type, output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return visual def _build_text_tower( embed_dim: int, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) if text_cfg.hf_model_name: text 
= HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, proj_type=text_cfg.hf_proj_type, pooler_type=text_cfg.hf_pooler_type, pretrained=text_cfg.hf_model_pretrained, output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if text_cfg.norm_kwargs: norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) if text_cfg.act_kwargs is not None: act_layer = partial(act_layer, **text_cfg.act_kwargs) text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, mlp_ratio=text_cfg.mlp_ratio, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, no_causal_mask=text_cfg.no_causal_mask, pad_id=text_cfg.pad_id, pool_type=text_cfg.pool_type, proj_type=text_cfg.proj_type, proj_bias=text_cfg.proj_bias, output_tokens=text_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, ) return text class CLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.context_length = text.context_length self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection self.text_pool_type = text.pool_type self.register_buffer('attn_mask', 
text.attn_mask, persistent=False) lshape = [1] if nonscalar_logit_scale else [] self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = {'positional_embedding'} if hasattr(self.visual, 'no_weight_decay'): for n in self.visual.no_weight_decay(): no_wd.add('visual.' + n) return no_wd def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x = text_global_pool(x, text, self.text_pool_type) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): x = self.text_projection(x) else: x = x @ self.text_projection return F.normalize(x, dim=-1) if normalize else x def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = 
image_logits.T return image_logits, text_logits def forward_intermediates( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, image_indices: Optional[Union[int, List[int]]] = None, text_indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize: bool = True, normalize_intermediates: bool = False, intermediates_only: bool = False, image_output_fmt: str = 'NCHW', image_output_extra_tokens: bool = False, text_output_fmt: str = 'NLC', text_output_extra_tokens: bool = False, output_logits: bool = False, output_logit_scale_bias: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: image: Input image tensor text: Input text tensor image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence text_indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize_intermediates: Apply final norm layer to all intermediates normalize: L2 Normalize final features intermediates_only: Only return intermediate features, do not return final features image_output_fmt: Shape of intermediate image feature outputs image_output_extra_tokens: Return both prefix and spatial intermediate tokens text_output_fmt: Shape of intermediate text feature outputs (ignored for this model) text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model) output_logits: Include logits in output output_logit_scale_bias: Include the logit scale bias in the output Returns: """ output = {} if intermediates_only: # intermediates only disables final feature normalization, and include logits normalize = False output_logits = False if output_logits: assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' if image is not None: image_output = 
self.visual.forward_intermediates( image, indices=image_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=image_output_fmt, output_extra_tokens=image_output_extra_tokens, ) if normalize and "image_features" in image_output: image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) output.update(image_output) if text is not None: cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x, intermediates = self.transformer.forward_intermediates( x, attn_mask=self.attn_mask, indices=text_indices ) if normalize_intermediates: intermediates = [self.ln_final(xi) for xi in intermediates] # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens output["text_intermediates"] = intermediates if not intermediates_only: x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x = text_global_pool(x, text, self.text_pool_type) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): x = self.text_projection(x) else: x = x @ self.text_projection if normalize: x = F.normalize(x, dim=-1) output["text_features"] = x logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None if output_logits: image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T output["image_logits"] = image_logits output["text_logits"] = text_logits if output_logit_scale_bias: output["logit_scale"] = logit_scale_exp if self.logit_bias is not None: output['logit_bias'] = self.logit_bias return output def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image 
is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.context_length = self.text.context_length self.vocab_size = self.text.vocab_size lshape = [1] if nonscalar_logit_scale else [] self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): self.text.lock(unlocked_layers, freeze_layer_norm) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d 
params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = set() if hasattr(self.visual, 'no_weight_decay'): for n in self.visual.no_weight_decay(): no_wd.add('visual.' + n) if hasattr(self.text, 'no_weight_decay'): for n in self.visual.no_weight_decay(): no_wd.add('text.' + n) return no_wd def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T return image_logits, text_logits def forward_intermediates( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, image_indices: Optional[Union[int, List[int]]] = None, text_indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize: bool = True, normalize_intermediates: bool = False, intermediates_only: bool = False, image_output_fmt: str = 'NCHW', image_output_extra_tokens: bool = False, text_output_fmt: str = 'NLC', text_output_extra_tokens: bool = False, output_logits: bool = False, output_logit_scale_bias: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. 
Args: image: Input image tensor text: Input text tensor image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence text_indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize: L2 Normalize final image and text features (if present) normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible) intermediates_only: Only return intermediate features, do not return final features image_output_fmt: Shape of intermediate image feature outputs image_output_extra_tokens: Return both prefix and spatial intermediate tokens text_output_fmt: Shape of intermediate text feature outputs text_output_extra_tokens: Return both prefix and spatial intermediate tokens output_logits: Include logits in output output_logit_scale_bias: Include the logit scale bias in the output Returns: """ output = {} if intermediates_only: # intermediates only disables final feature normalization, and include logits normalize = False output_logits = False if output_logits: assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' if image is not None: image_output = self.visual.forward_intermediates( image, indices=image_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=image_output_fmt, output_extra_tokens=image_output_extra_tokens, ) if normalize and "image_features" in image_output: image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) output.update(image_output) if text is not None: text_output = self.text.forward_intermediates( text, indices=text_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=text_output_fmt, output_extra_tokens=text_output_extra_tokens, ) if normalize 
and "text_features" in text_output: text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1) output.update(text_output) logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None if output_logits: image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T output["image_logits"] = image_logits output["text_logits"] = text_logits if output_logit_scale_bias: output["logit_scale"] = logit_scale_exp if self.logit_bias is not None: output['logit_bias'] = self.logit_bias return output def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) if l.bias is not None: l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.to(dtype) if isinstance(l, (CLIP, TextTransformer)): # convert text nn.Parameter projections attr = getattr(l, 
"text_projection", None) if attr is not None: attr.data = attr.data.to(dtype) if isinstance(l, VisionTransformer): # convert vision nn.Parameter projections attr = getattr(l, "proj", None) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] 
transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) text_cfg = CLIPTextCfg( context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers, ) model = CLIP( embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict old_pos_embed = state_dict.get('visual.positional_embedding', None) if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, 
pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, antialias=antialias, align_corners=False, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img state_dict['visual.positional_embedding'] = new_pos_embed def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): old_pos_embed = state_dict.get('positional_embedding', None) if old_pos_embed is None: return # FIXME add support for text cls_token model_pos_embed = getattr(model, 'positional_embedding', None) if model_pos_embed is None: model_pos_embed = getattr(model.text, 'positional_embedding', None) old_num_pos = old_pos_embed.shape[0] old_width = old_pos_embed.shape[1] num_pos = model_pos_embed.shape[0] width = model_pos_embed.shape[1] assert old_width == width, 'text pos_embed width changed!' 
if old_num_pos == num_pos: return logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) old_pos_embed = F.interpolate( old_pos_embed, size=num_pos, mode=interpolation, antialias=antialias, align_corners=False, ) old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] new_pos_embed = old_pos_embed state_dict['positional_embedding'] = new_pos_embed def get_model_preprocess_cfg(model): module = getattr(model, 'visual', model) preprocess_cfg = getattr(module, 'preprocess_cfg', {}) if not preprocess_cfg: # use separate legacy attributes if preprocess_cfg dict not found size = getattr(module, 'image_size') if size is not None: preprocess_cfg['size'] = size mean = getattr(module, 'image_mean', None) if mean is not None: preprocess_cfg['mean'] = mean std = getattr(module, 'image_std', None) if std is not None: preprocess_cfg['std'] = std return preprocess_cfg def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): module = getattr(model, 'visual', model) module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict def get_model_tokenize_cfg(model): module = getattr(model, 'text', model) cfg = {} context_length = getattr(module, 'context_length', None) if context_length is not None: cfg['context_length'] = context_length vocab_size = getattr(module, 'vocab_size', None) if vocab_size is not None: cfg['vocab_size'] = vocab_size return cfg try: from huggingface_hub import hf_hub_download hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) _has_hf_hub = True except ImportError: hf_hub_download = None _has_hf_hub = False def _pcfg(url='', hf_hub='', **kwargs): # OpenAI / OpenCLIP defaults return { 'url': url, 
'hf_hub': hf_hub, 'mean': OPENAI_DATASET_MEAN, 'std': OPENAI_DATASET_STD, 'interpolation': 'bicubic', 'resize_mode': 'shortest', **kwargs, } def _slpcfg(url='', hf_hub='', **kwargs): # SiGLIP defaults return { 'url': url, 'hf_hub': hf_hub, 'mean': INCEPTION_MEAN, 'std': INCEPTION_STD, 'interpolation': 'bicubic', 'resize_mode': 'squash', **kwargs, } def _apcfg(url='', hf_hub='', **kwargs): # CLIPA defaults return { 'url': url, 'hf_hub': hf_hub, 'mean': IMAGENET_MEAN, 'std': IMAGENET_STD, 'interpolation': 'bilinear', 'resize_mode': 'squash', **kwargs, } def _mccfg(url='', hf_hub='', **kwargs): # MobileCLIP return { 'url': url, 'hf_hub': hf_hub, 'mean': (0., 0., 0.), 'std': (1., 1., 1.), 'interpolation': 'bilinear', 'resize_mode': 'shortest', **kwargs, } _RN50 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", hf_hub="timm/resnet50_clip.openai/", quick_gelu=True, ), yfcc15m=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", hf_hub="timm/resnet50_clip.yfcc15m/", quick_gelu=True, ), cc12m=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", hf_hub="timm/resnet50_clip.cc12m/", quick_gelu=True, ), ) _RN101 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", hf_hub="timm/resnet101_clip.openai/", quick_gelu=True, ), yfcc15m=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", hf_hub="timm/resnet101_clip.yfcc15m/", quick_gelu=True, ), ) _RN50x4 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", hf_hub="timm/resnet50x4_clip.openai/", quick_gelu=True, ), ) _RN50x16 = dict( 
openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", hf_hub="timm/resnet50x16_clip.openai/", quick_gelu=True, ), ) _RN50x64 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", hf_hub="timm/resnet50x64_clip.openai/", quick_gelu=True, ), ) _VITB32 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", hf_hub="timm/vit_base_patch32_clip_224.openai/", quick_gelu=True, ), # LAION 400M (quick gelu) laion400m_e31=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/", quick_gelu=True, ), laion400m_e32=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/", quick_gelu=True, ), # LAION 2B-en laion2b_e16=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/", ), laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), # DataComp-M models datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), 
commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), # DataComp-S models datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), # MetaClip models (NOTE quick-gelu activation used) metaclip_400m=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/", quick_gelu=True, ), metaclip_fullcc=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/", quick_gelu=True, ), ) _VITB32_256 = dict( datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), ) _VITB16 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", hf_hub="timm/vit_base_patch16_clip_224.openai/", quick_gelu=True, ), # LAION-400M laion400m_e31=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/", ), laion400m_e32=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/", ), # LAION-2B 
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), # DataComp-L models datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), # DFN dfn2b=_pcfg( hf_hub='apple/DFN2B-CLIP-ViT-B-16/', quick_gelu=True, ), # MetaCLIP (these are quick-gelu) metaclip_400m=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/", quick_gelu=True, ), metaclip_fullcc=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/", quick_gelu=True, ), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", ), laion400m_e32=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", ), ) _VITL14 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", hf_hub="timm/vit_large_patch14_clip_224.openai/", quick_gelu=True, ), # LAION-400M laion400m_e31=_pcfg( 
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt", hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/", ), laion400m_e32=_pcfg( url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt", hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/", ), # LAION-2B-en laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=INCEPTION_MEAN, std=INCEPTION_STD), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), # MetaCLIP metaclip_400m=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/", quick_gelu=True, ), metaclip_fullcc=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/", quick_gelu=True, ), # DFN-2B (quick-gelu) dfn2b=_pcfg( hf_hub='apple/DFN2B-CLIP-ViT-L-14/', quick_gelu=True, ), # DFN-2B 39B SS dfn2b_s39b=_pcfg( hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/', ), ) _VITL14_336 = dict( openai=_pcfg( url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", hf_hub="timm/vit_large_patch14_clip_336.openai/", quick_gelu=True, ), ) _VITH14 = dict( # LAION-2B-en laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), # MetaCLIP (quick-gelu) metaclip_fullcc=_pcfg( url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/", quick_gelu=True, ), metaclip_altogether=_pcfg( 
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt", hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/", # NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay! ), # DFN-5B (quick-gelu) dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14/', quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), ) _VITH14_378 = dict( # DFN-5B (quick-gelu) dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), ) _VITg14 = dict( laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), ) _VITbigG14 = dict( # LAION-2B-en laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), # MetaCLIP (quick-gelu) metaclip_fullcc=_pcfg( url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt', hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/", quick_gelu=True, ), ) _robertaViTB32 = dict( laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), ) _xlmRobertaBaseViTB32 = dict( laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), ) _xlmRobertaLargeFrozenViTH14 = dict( frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), ) _convnext_base = dict( laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), ) _convnext_base_w = dict( laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), ) _convnext_base_w_320 = dict( laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), 
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), ) _convnext_large_d = dict( laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), ) _convnext_large_d_320 = dict( laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), ) _convnext_xxlarge = dict( laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), ) _coca_VITB32 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') ) _coca_VITL14 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') ) _PRETRAINED = { "RN50": _RN50, "RN101": _RN101, "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, "ViT-B-32": _VITB32, "ViT-B-32-256": _VITB32_256, "ViT-B-16": _VITB16, "ViT-B-16-plus-240": _VITB16_PLUS_240, "ViT-L-14": _VITL14, "ViT-L-14-336": _VITL14_336, "ViT-H-14": _VITH14, "ViT-H-14-378": _VITH14_378, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, "roberta-ViT-B-32": _robertaViTB32, "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, "convnext_large_d": _convnext_large_d, "convnext_large_d_320": _convnext_large_d_320, "convnext_xxlarge": _convnext_xxlarge, "coca_ViT-B-32": 
_coca_VITB32, "coca_ViT-L-14": _coca_VITL14, "EVA01-g-14": dict( # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), ), "EVA01-g-14-plus": dict( # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), ), "EVA02-B-16": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), ), "EVA02-L-14": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), ), "EVA02-L-14-336": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), ), "EVA02-E-14": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), ), "EVA02-E-14-plus": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), ), "ViT-B-16-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), ), "ViT-B-16-SigLIP-256": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), ), "ViT-B-16-SigLIP-i18n-256": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), ), "ViT-B-16-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), ), "ViT-B-16-SigLIP-512": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), ), "ViT-L-16-SigLIP-256": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), ), "ViT-L-16-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), ), "ViT-SO400M-14-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), ), "ViT-SO400M-16-SigLIP-i18n-256": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'), 
), "ViT-SO400M-14-SigLIP-378": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used ), "ViT-SO400M-14-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), ), "ViT-B-32-SigLIP2-256": dict( webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'), ), "ViT-B-16-SigLIP2": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'), ), "ViT-B-16-SigLIP2-256": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'), ), "ViT-B-16-SigLIP2-384": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'), ), "ViT-B-16-SigLIP2-512": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'), ), "ViT-L-16-SigLIP2-256": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'), ), "ViT-L-16-SigLIP2-384": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'), ), "ViT-L-16-SigLIP2-512": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'), ), "ViT-SO400M-14-SigLIP2": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'), ), "ViT-SO400M-14-SigLIP2-378": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'), ), "ViT-SO400M-16-SigLIP2-256": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'), ), "ViT-SO400M-16-SigLIP2-384": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'), ), "ViT-SO400M-16-SigLIP2-512": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'), ), "ViT-gopt-16-SigLIP2-256": dict( webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'), ), "ViT-gopt-16-SigLIP2-384": dict( webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'), ), "ViT-L-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), ), "ViT-L-14-CLIPA-336": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), ), "ViT-H-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), ), "ViT-H-14-CLIPA-336": dict( laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), 
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), ), "ViT-bigG-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), ), "ViT-bigG-14-CLIPA-336": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), ), "nllb-clip-base": dict( v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), ), "nllb-clip-large": dict( v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), ), "nllb-clip-base-siglip": dict( v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'), ), "nllb-clip-large-siglip": dict( v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), ), "MobileCLIP-S1": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')), "MobileCLIP-S2": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')), "MobileCLIP-B": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'), datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'), ), "ViTamin-S": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'), ), "ViTamin-S-LTT": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'), ), "ViTamin-B": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'), ), "ViTamin-B-LTT": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'), ), "ViTamin-L": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'), ), "ViTamin-L-256": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'), ), "ViTamin-L-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'), ), "ViTamin-L-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'), ), "ViTamin-L2": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'), ), "ViTamin-L2-256": dict( 
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'), ), "ViTamin-L2-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'), ), "ViTamin-L2-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'), ), "ViTamin-XL-256": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'), ), "ViTamin-XL-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'), ), "ViTamin-XL-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'), ), } _PRETRAINED_quickgelu = {} for k, v in _PRETRAINED.items(): quick_gelu_tags = {} for tk, tv in v.items(): if tv.get('quick_gelu', False): quick_gelu_tags[tk] = copy.deepcopy(tv) if quick_gelu_tags: _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags _PRETRAINED.update(_PRETRAINED_quickgelu) def _clean_tag(tag: str): # normalize pretrained tags return tag.lower().replace('-', '_') def list_pretrained(as_str: bool = False): """ returns list of pretrained models Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True """ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] def list_pretrained_models_by_tag(tag: str): """ return all models having the specified pretrain tag """ models = [] tag = _clean_tag(tag) for k in _PRETRAINED.keys(): if tag in _PRETRAINED[k]: models.append(k) return models def list_pretrained_tags_by_model(model: str): """ return all pretrain tags for the specified model architecture """ tags = [] if model in _PRETRAINED: tags.extend(_PRETRAINED[model].keys()) return tags def is_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return False return _clean_tag(tag) in _PRETRAINED[model] def get_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return {} model_pretrained = _PRETRAINED[model] return model_pretrained.get(_clean_tag(tag), {}) def 
get_pretrained_url(model: str, tag: str): cfg = get_pretrained_cfg(model, _clean_tag(tag)) return cfg.get('url', '') def download_pretrained_from_url( url: str, cache_dir: Optional[str] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") os.makedirs(cache_dir, exist_ok=True) filename = os.path.basename(url) if 'openaipublic' in url: expected_sha256 = url.split("/")[-2] elif 'mlfoundations' in url: expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] else: expected_sha256 = '' download_target = os.path.join(cache_dir, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): if expected_sha256: if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): return download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") else: return download_target with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") return download_target def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') return _has_hf_hub def _get_safe_alternatives(filename: str) -> Iterable[str]: """Returns potential safetensors alternatives for a given filename. 
Use case: When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it. """ if filename == HF_WEIGHTS_NAME: yield HF_SAFE_WEIGHTS_NAME if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")): yield filename[:-4] + ".safetensors" def download_pretrained_from_hf( model_id: str, filename: Optional[str] = None, revision: Optional[str] = None, cache_dir: Optional[str] = None, ): has_hf_hub(True) filename = filename or HF_WEIGHTS_NAME # Look for .safetensors alternatives and load from it if it exists if _has_safetensors: for safe_filename in _get_safe_alternatives(filename): try: cached_file = hf_hub_download( repo_id=model_id, filename=safe_filename, revision=revision, cache_dir=cache_dir, ) return cached_file except Exception: pass try: # Attempt to download the file cached_file = hf_hub_download( repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir, ) return cached_file # Return the path to the downloaded file if successful except Exception as e: raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") def download_pretrained( cfg: Dict, prefer_hf_hub: bool = True, cache_dir: Optional[str] = None, ): target = '' if not cfg: return target if 'file' in cfg: return cfg['file'] has_hub = has_hf_hub() download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') if has_hub and prefer_hf_hub and download_hf_hub: # prefer to use HF hub, remove url info download_url = '' if download_url: target = download_pretrained_from_url(download_url, cache_dir=cache_dir) elif download_hf_hub: has_hf_hub(True) # we assume the hf_hub entries in pretrained config combine model_id + filename in # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
model_id, filename = os.path.split(download_hf_hub) if filename: target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) else: target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) return target # ================================================================== def merge_preprocess_dict( base: Union[PreprocessCfg, Dict], overlay: Dict, ): """ Merge overlay key-value pairs on top of base preprocess cfg or dict. Input dicts are filtered based on PreprocessCfg fields. """ if isinstance(base, PreprocessCfg): base_clean = asdict(base) else: base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} if overlay: overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} base_clean.update(overlay_clean) return base_clean def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): return merge_preprocess_dict(base, kwargs) @dataclass class PreprocessCfg: size: Union[int, Tuple[int, int]] = 224 mode: str = 'RGB' mean: Tuple[float, ...] = OPENAI_DATASET_MEAN std: Tuple[float, ...] = OPENAI_DATASET_STD interpolation: str = 'bicubic' resize_mode: str = 'shortest' fill_color: int = 0 def __post_init__(self): assert self.mode in ('RGB',) @property def num_channels(self): return 3 @property def input_size(self): return (self.num_channels,) + to_2tuple(self.size) @dataclass class PreprocessCfg: size: Union[int, Tuple[int, int]] = 224 mode: str = 'RGB' mean: Tuple[float, ...] = OPENAI_DATASET_MEAN std: Tuple[float, ...] 
= OPENAI_DATASET_STD interpolation: str = 'bicubic' resize_mode: str = 'shortest' fill_color: int = 0 def __post_init__(self): assert self.mode in ('RGB',) @property def num_channels(self): return 3 @property def input_size(self): return (self.num_channels,) + to_2tuple(self.size) _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) def merge_preprocess_dict( base: Union[PreprocessCfg, Dict], overlay: Dict, ): """ Merge overlay key-value pairs on top of base preprocess cfg or dict. Input dicts are filtered based on PreprocessCfg fields. """ if isinstance(base, PreprocessCfg): base_clean = asdict(base) else: base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} if overlay: overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} base_clean.update(overlay_clean) return base_clean def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): return merge_preprocess_dict(base, kwargs) @dataclass class AugmentationCfg: scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None re_prob: Optional[float] = None re_count: Optional[int] = None use_timm: bool = False # params for simclr_jitter_gray color_jitter_prob: float = None gray_scale_prob: float = None def _setup_size(size, error_msg): if isinstance(size, numbers.Number): return int(size), int(size) if isinstance(size, Sequence) and len(size) == 1: return size[0], size[0] if len(size) != 2: raise ValueError(error_msg) return size class ResizeKeepRatio: """ Resize and Keep Ratio Copy & paste from `timm` """ def __init__( self, size, longest=0., interpolation=InterpolationMode.BICUBIC, random_scale_prob=0., random_scale_range=(0.85, 1.05), random_aspect_prob=0., random_aspect_range=(0.9, 1.11) ): if isinstance(size, (list, tuple)): self.size = tuple(size) else: self.size = (size, size) self.interpolation = interpolation self.longest = 
float(longest) # [0, 1] where 0 == shortest edge, 1 == longest self.random_scale_prob = random_scale_prob self.random_scale_range = random_scale_range self.random_aspect_prob = random_aspect_prob self.random_aspect_range = random_aspect_range @staticmethod def get_params( img, target_size, longest, random_scale_prob=0., random_scale_range=(0.85, 1.05), random_aspect_prob=0., random_aspect_range=(0.9, 1.11) ): """Get parameters """ source_size = img.size[::-1] # h, w h, w = source_size target_h, target_w = target_size ratio_h = h / target_h ratio_w = w / target_w ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) if random_scale_prob > 0 and random.random() < random_scale_prob: ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) ratio_factor = (ratio_factor, ratio_factor) else: ratio_factor = (1., 1.) if random_aspect_prob > 0 and random.random() < random_aspect_prob: aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] return size def __call__(self, img): """ Args: img (PIL Image): Image to be cropped and resized. Returns: PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size """ size = self.get_params( img, self.size, self.longest, self.random_scale_prob, self.random_scale_range, self.random_aspect_prob, self.random_aspect_range ) img = F.resize(img, size, self.interpolation) return img def __repr__(self): format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += f', interpolation={self.interpolation})' format_string += f', longest={self.longest:.3f})' return format_string def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: """Center crops and/or pads the given image. 
If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. Args: img (PIL Image or Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, it is used for both directions. fill (int, Tuple[int]): Padding color Returns: PIL Image or Tensor: Cropped image. """ if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: output_size = (output_size[0], output_size[0]) _, image_height, image_width = F.get_dimensions(img) crop_height, crop_width = output_size if crop_width > image_width or crop_height > image_height: padding_ltrb = [ (crop_width - image_width) // 2 if crop_width > image_width else 0, (crop_height - image_height) // 2 if crop_height > image_height else 0, (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, ] img = F.pad(img, padding_ltrb, fill=fill) _, image_height, image_width = F.get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img crop_top = int(round((image_height - crop_height) / 2.0)) crop_left = int(round((image_width - crop_width) / 2.0)) return F.crop(img, crop_top, crop_left, crop_height, crop_width) class CenterCropOrPad(torch.nn.Module): """Crops the given image at the center. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. 
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). """ def __init__(self, size, fill=0): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.fill = fill def forward(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. Returns: PIL Image or Tensor: Cropped image. """ return center_crop_or_pad(img, self.size, fill=self.fill) def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size})" def _convert_to_rgb(image): return image.convert('RGB') class color_jitter(object): """ Apply Color Jitter to the PIL image with a specified probability. """ def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): assert 0. <= p <= 1. self.p = p self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) def __call__(self, img): if random.random() < self.p: return self.transf(img) else: return img class gray_scale(object): """ Apply Gray Scale to the PIL image with a specified probability. """ def __init__(self, p=0.2): assert 0. <= p <= 1. 
self.p = p self.transf = Grayscale(num_output_channels=3) def __call__(self, img): if random.random() < self.p: return self.transf(img) else: return img def image_transform( image_size: Union[int, Tuple[int, int]], is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, resize_mode: Optional[str] = None, interpolation: Optional[str] = None, fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 std = std or OPENAI_DATASET_STD if not isinstance(std, (list, tuple)): std = (std,) * 3 interpolation = interpolation or 'bicubic' assert interpolation in ['bicubic', 'bilinear', 'random'] # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC resize_mode = resize_mode or 'shortest' assert resize_mode in ('shortest', 'longest', 'squash') if isinstance(aug_cfg, dict): aug_cfg = AugmentationCfg(**aug_cfg) else: aug_cfg = aug_cfg or AugmentationCfg() normalize = Normalize(mean=mean, std=std) if is_train: aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} use_timm = aug_cfg_dict.pop('use_timm', False) if use_timm: from timm.data import create_transform # timm can still be optional if isinstance(image_size, (tuple, list)): assert len(image_size) >= 2 input_size = (3,) + image_size[-2:] else: input_size = (3, image_size, image_size) aug_cfg_dict.setdefault('color_jitter', None) # disable by default # drop extra non-timm items aug_cfg_dict.pop('color_jitter_prob', None) aug_cfg_dict.pop('gray_scale_prob', None) train_transform = create_transform( input_size=input_size, is_training=True, hflip=0., mean=mean, std=std, re_mode='pixel', interpolation=interpolation, **aug_cfg_dict, ) else: train_transform = [ RandomResizedCrop( image_size, 
scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ] if aug_cfg.color_jitter_prob: assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 train_transform.extend([ color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) ]) if aug_cfg.gray_scale_prob: train_transform.extend([ gray_scale(aug_cfg.gray_scale_prob) ]) train_transform.extend([ ToTensor(), normalize, ]) train_transform = Compose(train_transform) if aug_cfg_dict: warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') return train_transform else: if resize_mode == 'longest': transforms = [ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), CenterCropOrPad(image_size, fill=fill_color) ] elif resize_mode == 'squash': if isinstance(image_size, int): image_size = (image_size, image_size) transforms = [ Resize(image_size, interpolation=interpolation_mode), ] else: assert resize_mode == 'shortest' if not isinstance(image_size, (tuple, list)): image_size = (image_size, image_size) if image_size[0] == image_size[1]: # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) transforms = [ Resize(image_size[0], interpolation=interpolation_mode) ] else: # resize shortest edge to matching target dim for non-square target transforms = [ResizeKeepRatio(image_size)] transforms += [CenterCrop(image_size)] transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) return Compose(transforms) def image_transform_v2( cfg: PreprocessCfg, is_train: bool, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): return image_transform( image_size=cfg.size, is_train=is_train, mean=cfg.mean, std=cfg.std, interpolation=cfg.interpolation, resize_mode=cfg.resize_mode, fill_color=cfg.fill_color, aug_cfg=aug_cfg, ) @dataclass class AugmentationCfg: scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: 
Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None re_prob: Optional[float] = None re_count: Optional[int] = None use_timm: bool = False # params for simclr_jitter_gray color_jitter_prob: float = None gray_scale_prob: float = None def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): module = getattr(model, 'visual', model) module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict @torch.no_grad() def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): def _convert_timm_img(state_dict): if fastvit: from timm.models.fastvit import checkpoint_filter_fn else: from timm.models.vision_transformer_hybrid import checkpoint_filter_fn timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()} return timm_state_dict def _convert_openclip_txt(state_dict, prefix='text_encoder.'): text_dict = {} for k, v in state_dict.items(): if not k.startswith(prefix): continue k = k.replace(prefix, '') k = k.replace('projection_layer', 'text_projection') k = k.replace('embedding_layer', 'token_embedding') if k.startswith('positional_embedding.pos_embed.pos_embed'): k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') v = v.squeeze() k = k.replace('final_layer_norm', 'ln_final') k = k.replace('pre_norm_mha.0', 'ln_1') k = k.replace('pre_norm_mha.1', 'attn') k = k.replace('pre_norm_ffn.0', 'ln_2') k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') k = k.replace('qkv_proj.weight', 'in_proj_weight') k = k.replace('qkv_proj.bias', 'in_proj_bias') k = k.replace('transformer.', 'transformer.resblocks.') text_dict['text.' 
+ k] = v return text_dict image_dict = _convert_timm_img(state_dict) text_dict = _convert_openclip_txt(state_dict) out_dict = {**image_dict, **text_dict} out_dict['logit_scale'] = state_dict['logit_scale'] return out_dict def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) state_dict = convert_mobile_clip_state_dict(model, state_dict) if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: # convert b model state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False) return state_dict def load_state_dict( checkpoint_path: str, device='cpu', weights_only=True, ): # Check if safetensors or not and load weights accordingly if str(checkpoint_path).endswith(".safetensors"): from safetensors.torch import load_file checkpoint = load_file(checkpoint_path, device=device) else: try: checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only) except TypeError: checkpoint = torch.load(checkpoint_path, map_location=device) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif isinstance(checkpoint, torch.jit.ScriptModule): state_dict = checkpoint.state_dict() for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict def load_checkpoint( model: Union[CLIP, CustomTextCLIP], checkpoint_path: str, strict: bool = True, weights_only: bool = True, device='cpu', ): if Path(checkpoint_path).suffix in ('.npz', '.npy'): # Separate path loading numpy big_vision (SigLIP) weights from open_clip.convert import load_big_vision_weights load_big_vision_weights(model, checkpoint_path) return {} state_dict = 
load_state_dict(checkpoint_path, device=device, weights_only=weights_only) # Detect & convert 3rd party state_dicts -> open_clip state_dict = convert_state_dict(model, state_dict) # Detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) # correct if logit_scale differs in being scaler vs 1d param if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim: state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape) # correct if logit_bias differs in being scaler vs 1d param if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim: state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape) # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 if 'logit_bias' not in state_dict and model.logit_bias is not None: state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) # Certain text transformers no longer expect position_ids after transformers==4.31 position_id_key = 'text.transformer.embeddings.position_ids' if position_id_key in state_dict and not hasattr(model, position_id_key): del state_dict[position_id_key] resize_pos_embed(state_dict, model) resize_text_pos_embed(state_dict, model) # Finally, load the massaged state_dict into model incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys # /home/IITB/ai-at-ieor/23m1521/.conda/envs/openclip2/lib/python3.11/site-packages/open_clip/factory.py HF_HUB_PREFIX = 'hf-hub:' # _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIG_PATHS = [Path("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs")] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs import json def 
_get_hf_config( model_id: str, cache_dir: Optional[str] = None, ): """ Fetch model config from HuggingFace Hub. """ config_path = download_pretrained_from_hf( model_id, filename='open_clip_config.json', cache_dir=cache_dir, ) with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) return config def get_model_config(model_name): """ Fetch model config from builtin (local library) configs. """ if model_name in _MODEL_CONFIGS: return copy.deepcopy(_MODEL_CONFIGS[model_name]) else: return None def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS config_ext = ('.json',) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: config_files.extend(config_path.glob(f'*{ext}')) for cf in config_files: with open(cf, 'r') as f: model_cfg = json.load(f) if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} _rescan_model_configs() # initial populate of model config registry def list_models(): """ enumerate available model architectures based on config files """ return list(_MODEL_CONFIGS.keys()) def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): if past: input_ids = input_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = None return { "text": input_ids, "images": image_inputs, "past_key_values": past, "position_ids": 
position_ids, "attention_mask": attention_mask, } @dataclass class MultimodalCfg(CLIPTextCfg): mlp_ratio: int = 4 dim_head: int = 64 heads: int = 8 n_queries: int = 256 attn_pooler_heads: int = 8 try: from transformers import ( BeamSearchScorer, LogitsProcessorList, TopPLogitsWarper, TopKLogitsWarper, RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, StopStringCriteria, EosTokenCriteria, StoppingCriteriaList ) GENERATION_TYPES = { "top_k": TopKLogitsWarper, "top_p": TopPLogitsWarper, "beam_search": "beam_search" } _has_transformers = True except ImportError as e: GENERATION_TYPES = { "top_k": None, "top_p": None, "beam_search": "beam_search" } _has_transformers = False def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor: if not isinstance(token_id, torch.Tensor): if isinstance(token_id, int): token_id = [token_id] token_id = torch.tensor(token_id, device=device) return token_id def _build_text_decoder_tower( embed_dim, multimodal_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) decoder = MultimodalTransformer( context_length=multimodal_cfg.context_length, width=multimodal_cfg.width, heads=multimodal_cfg.heads, layers=multimodal_cfg.layers, ls_init_value=multimodal_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return decoder class CoCa(nn.Module): def __init__( self, embed_dim, multimodal_cfg: MultimodalCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, pad_id: int = 0, ): super().__init__() multimodal_cfg = 
MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg self.text = _build_text_tower( embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) vocab_size = ( text_cfg.vocab_size # for hf models if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else text_cfg.vocab_size ) self.visual = _build_vision_tower( embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) self.text_decoder = _build_text_decoder_tower( vocab_size, multimodal_cfg=multimodal_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) lshape = [1] if nonscalar_logit_scale else [] self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) else: self.logit_bias = None self.pad_id = pad_id self.context_length = multimodal_cfg.context_length @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) self.text_decoder.set_grad_checkpointing(enable) def _encode_image(self, images, normalize: bool = True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, tokens_embs def _encode_text(self, text, normalize: bool = True): text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb def encode_image(self, images, normalize: bool = True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent def encode_text(self, text, normalize: bool = True): text_latent, _ = self._encode_text(text, normalize=normalize) 
return text_latent def forward_intermediates( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, image_indices: Optional[Union[int, List[int]]] = None, text_indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize: bool = True, normalize_intermediates: bool = False, intermediates_only: bool = False, image_output_fmt: str = 'NCHW', image_output_extra_tokens: bool = False, text_output_fmt: str = 'NLC', text_output_extra_tokens: bool = False, output_logits: bool = False, output_logit_scale_bias: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: image: Input image tensor text: Input text tensor image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence text_indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize: L2 Normalize final image and text features (if present) normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible) intermediates_only: Only return intermediate features, do not return final features image_output_fmt: Shape of intermediate image feature outputs image_output_extra_tokens: Return both prefix and spatial intermediate tokens text_output_fmt: Shape of intermediate text feature outputs text_output_extra_tokens: Return both prefix and spatial intermediate tokens output_logits: Include logits in output output_logit_scale_bias: Include the logit scale bias in the output Returns: """ output = {} if intermediates_only: # intermediates only disables final feature normalization, and include logits normalize = False output_logits = False if output_logits: assert False, 'FIXME, needs implementing' if image is not None: image_output = self.visual.forward_intermediates( image, indices=image_indices, stop_early=stop_early, 
normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=image_output_fmt, output_extra_tokens=image_output_extra_tokens, ) if normalize and "image_features" in image_output: image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) output.update(image_output) if text is not None: text_output = self.text.forward_intermediates( text, indices=text_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=text_output_fmt, output_extra_tokens=text_output_extra_tokens, ) if normalize and "text_features" in text_output: text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1) output.update(text_output) # FIXME text decoder logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None if output_logit_scale_bias: output["logit_scale"] = logit_scale_exp if self.logit_bias is not None: output['logit_bias'] = self.logit_bias return output def forward( self, image, text: Optional[torch.Tensor] = None, image_latent: Optional[torch.Tensor] = None, image_embs: Optional[torch.Tensor] = None, output_labels: bool = True, ): if image_latent is None or image_embs is None: image_latent, image_embs = self._encode_image(image) if text is None: return {"image_features": image_latent, "image_embs": image_embs} text_latent, token_embs = self._encode_text(text) # FIXME this isn't an ideal solution, would like to improve -RW labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None if output_labels: # align text_embs and thus logits with labels for teacher-forcing caption loss token_embs = token_embs[:, :-1] logits = self.text_decoder(image_embs, token_embs) out_dict = { "image_features": image_latent, "text_features": text_latent, "logits": logits, "logit_scale": self.logit_scale.exp() } if labels is not None: out_dict["labels"] = labels if self.logit_bias is not None: 
out_dict["logit_bias"] = self.logit_bias return out_dict def generate( self, image, text=None, seq_len=30, max_seq_len=77, temperature=1., generation_type="beam_search", top_p=0.1, # keep tokens in the 1 - top_p quantile top_k=1, # keeps the top_k most probable tokens pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, repetition_penalty=1.0, fixed_output_length=False # if True output.shape == (batch_size, seq_len) ): # taking many ideas and components from HuggingFace GenerationMixin # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" device = image.device with torch.no_grad(): sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device) eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device) pad_token_id = self.pad_id if pad_token_id is None else pad_token_id logit_processor = LogitsProcessorList( [ MinLengthLogitsProcessor(min_seq_len, eos_token_id), RepetitionPenaltyLogitsProcessor(repetition_penalty), ] ) if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] stopping_criteria = StoppingCriteriaList(stopping_criteria) if generation_type == "beam_search": output = self._generate_beamsearch( image_inputs=image, pad_token_id=pad_token_id, eos_token_id=eos_token_id, sot_token_id=sot_token_id, num_beams=num_beams, num_beam_groups=num_beam_groups, min_seq_len=min_seq_len, stopping_criteria=stopping_criteria, logit_processor=logit_processor, ) if fixed_output_length and output.shape[1] < seq_len: pad_len = seq_len - output.shape[1] return torch.cat(( output, torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id ), dim=1 ) return output elif 
generation_type == "top_p": logit_warper = GENERATION_TYPES[generation_type](top_p) elif generation_type == "top_k": logit_warper = GENERATION_TYPES[generation_type](top_k) else: raise ValueError( f"generation_type has to be one of " f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." ) image_latent, image_embs = self._encode_image(image) if text is None: text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id was_training = self.training num_dims = len(text.shape) if num_dims == 1: text = text[None, :] self.eval() out = text while True: x = out[:, -max_seq_len:] cur_len = x.shape[1] logits = self( image, x, image_latent=image_latent, image_embs=image_embs, output_labels=False, )["logits"][:, -1] mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id if mask.all(): if not fixed_output_length: break else: logits = logits[~mask, :] filtered_logits = logit_processor(x[~mask, :], logits) filtered_logits = logit_warper(x[~mask, :], filtered_logits) probs = F.softmax(filtered_logits / temperature, dim=-1) if (cur_len + 1 == seq_len): sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id else: sample[~mask, :] = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) cur_len += 1 if all(stopping_criteria(out, None)): break if num_dims == 1: out = out.squeeze(0) self.train(was_training) return out def _generate_beamsearch( self, image_inputs, pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, logit_processor=None, logit_warper=None, ): device = image_inputs.device batch_size = image_inputs.shape[0] image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) image_latent, image_embs = self._encode_image(image_inputs) input_ids = torch.ones((batch_size * num_beams, 1), device=device, 
dtype=torch.long) input_ids = input_ids * sot_token_id beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, device=device, num_beam_groups=num_beam_groups, ) # instantiate logits processors logits_processor = ( LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) if logit_processor is None else logit_processor ) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups batch_size = len(beam_scorer._beam_hyps) // num_beam_groups batch_beam_size, cur_len = input_ids.shape beam_indices = None if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in # the same group don't produce same tokens everytime. 
beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) while True: # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) outputs = self( model_inputs['images'], model_inputs['text'], image_latent=image_latent, image_embs=image_embs, output_labels=False, ) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of currentg group only next_token_logits = outputs['logits'][batch_group_indices, -1, :] vocab_size = next_token_logits.shape[-1] next_token_scores_processed = logits_processor( group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless process_beam_indices 
= sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=process_beam_indices, group_index=beam_group_idx, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) ) input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or all(stopping_criteria(input_ids, None)): break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, ) return sequence_outputs['sequences'] def create_model( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, force_preprocess_cfg: Optional[Dict[str, Any]] = None, pretrained_image: bool = False, pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, 
load_weights_only: bool = True, **model_kwargs, ): """Creates and configures a contrastive vision-language model. Args: model_name: Name of the model architecture to create. Can be a local model name or a Hugging Face model ID prefixed with 'hf-hub:'. pretrained: Tag/path for pretrained model weights. Can be: - A pretrained tag name (e.g., 'openai') - A path to local weights - None to initialize with random weights precision: Model precision/AMP configuration. Options: - 'fp32': 32-bit floating point - 'fp16'/'bf16': Mixed precision with FP32 for certain layers - 'pure_fp16'/'pure_bf16': Pure 16-bit precision device: Device to load the model on ('cpu', 'cuda', or torch.device object) jit: If True, JIT compile the model force_quick_gelu: Force use of QuickGELU activation force_custom_text: Force use of custom text encoder force_patch_dropout: Override default patch dropout value force_image_size: Override default image size for vision encoder force_preprocess_cfg: Override default preprocessing configuration pretrained_image: Load pretrained weights for timm vision models pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights cache_dir: Override default cache directory for downloaded model files output_dict: If True and model supports it, return dictionary of features require_pretrained: Raise error if pretrained weights cannot be loaded load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety) **model_kwargs: Additional keyword arguments passed to model constructor Returns: Created and configured model instance Raises: RuntimeError: If model config is not found or required pretrained weights cannot be loaded Examples: # Create basic CLIP model model = create_model('ViT-B/32') # Create CLIP model with mixed precision on GPU model = create_model('ViT-B/32', precision='fp16', device='cuda') # Load pretrained OpenAI weights model = create_model('ViT-B/32', pretrained='openai') # Load Hugging Face 
model model = create_model('hf-hub:organization/model-name') """ force_preprocess_cfg = force_preprocess_cfg or {} preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) config = _get_hf_config(model_id, cache_dir=cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) model_cfg = config['model_cfg'] pretrained_hf = False # override, no need to load original HF text weights else: model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names checkpoint_path = None model_cfg = None if isinstance(device, str): device = torch.device(device) model_cfg = model_cfg or get_model_config(model_name) if model_cfg is not None: logging.info(f'Loaded {model_name} model config.') else: logging.error(f'Model config for {model_name} not found; available models {list_models()}.') raise RuntimeError(f'Model config for {model_name} not found.') if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models model_cfg["quick_gelu"] = True if force_patch_dropout is not None: # override the default patch dropout value model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout if force_image_size is not None: # override model config's image size model_cfg["vision_cfg"]["image_size"] = force_image_size is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image: if is_timm_model: # pretrained weight loading for timm models set via vision_cfg model_cfg['vision_cfg']['timm_model_pretrained'] = True else: assert False, 'pretrained image towers currently only supported for timm models' # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes cast_dtype = get_cast_dtype(precision) is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) if 
is_hf_model: # load pretrained weights for HF text model IFF no CLIP weights being loaded model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) if custom_text: if "multimodal_cfg" in model_cfg: model = CoCa(**model_cfg, cast_dtype=cast_dtype) else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) if precision in ("fp16", "bf16"): dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 # manual mixed precision that matches original OpenAI behaviour if is_timm_model: # FIXME this is a bit janky, create timm based model in low-precision and # then cast only LayerNormFp32 instances back to float32 so they don't break. # Why? The convert_weights_to_lp fn only works with native models. model.to(device=device, dtype=dtype) # from .transformer import LayerNormFp32 def _convert_ln(m): if isinstance(m, LayerNormFp32): m.weight.data = m.weight.data.to(torch.float32) m.bias.data = m.bias.data.to(torch.float32) model.apply(_convert_ln) else: model.to(device=device) convert_weights_to_lp(model, dtype=dtype) elif precision in ("pure_fp16", "pure_bf16"): dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 model.to(device=device, dtype=dtype) else: model.to(device=device) pretrained_loaded = False if pretrained: checkpoint_path = '' pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False) model_quick_gelu = model_cfg.get('quick_gelu', False) if pretrained_quick_gelu and not model_quick_gelu: warnings.warn( f'These pretrained weights were trained with QuickGELU 
activation but the model config does ' f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.') elif not pretrained_quick_gelu and model_quick_gelu: warnings.warn( f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the ' f'model config, consider using a model config without QuickGELU or disable override flags.') elif os.path.exists(pretrained): checkpoint_path = pretrained if checkpoint_path: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) else: error_str = ( f'Pretrained weights ({pretrained}) not found for model {model_name}.' f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') logging.warning(error_str) raise RuntimeError(error_str) pretrained_loaded = True elif has_hf_hub_prefix: logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) pretrained_loaded = True if require_pretrained and not pretrained_loaded: # callers of create_model_from_pretrained always expect pretrained weights raise RuntimeError( f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') if output_dict and hasattr(model, "output_dict"): model.output_dict = True if jit: model = torch.jit.script(model) # set image preprocessing configuration in model attributes for convenience if getattr(model.visual, 'image_size', None) is not None: # use image_size set on model creation (via config or force_image_size arg) force_preprocess_cfg['size'] = model.visual.image_size set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) return model def create_model_and_transforms( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, 
force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, image_interpolation: Optional[str] = None, image_resize_mode: Optional[str] = None, # only effective for inference aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, pretrained_image: bool = False, pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, load_weights_only: bool = True, **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode, ) model = create_model( model_name, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text, force_patch_dropout=force_patch_dropout, force_image_size=force_image_size, force_preprocess_cfg=force_preprocess_cfg, pretrained_image=pretrained_image, pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, load_weights_only=load_weights_only, **model_kwargs, ) pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) preprocess_train = image_transform_v2( pp_cfg, is_train=True, aug_cfg=aug_cfg, ) preprocess_val = image_transform_v2( pp_cfg, is_train=False, ) return model, preprocess_train, preprocess_val open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms( model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device ) # print("ashish 1") # exit() # ================================================================== # C S I P - M O D U L E # ================================================================== class CSIP(nn.Module): def __init__(self, image_encoder, audio_encoder, dim_img=None, dim_audio=1024, dim_emb=1024): super(CSIP, self).__init__() self.image_encoder = 
image_encoder # CLIPVisionModel self.audio_encoder = audio_encoder # CLAP_audio_encoder for param in self.image_encoder.parameters(): param.requires_grad = False # self.image_proj = nn.Linear(dim_img, dim_emb) self.audio_proj = nn.Linear(dim_audio, dim_emb) # Learnable temperature parameter # self.log_temp = nn.Parameter(torch.tensor(1/0.07).log()) self.log_temp = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def forward(self, images, audios): # image_features = self.image_encoder(images) # shape: [n, dim_img] image_features = images # shape: [n, dim_img] audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio] # Step 2: Project and normalize image_embeds = F.normalize(image_features, dim=1) # [n, dim_emb] audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb] # Step 3: Cosine similarity with temperature logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n] probs = logits.softmax(dim=1) # Step 4: Symmetric cross-entropy loss labels = torch.arange(len(images), device=images.device) loss_i = F.cross_entropy(logits, labels) loss_a = F.cross_entropy(logits.T, labels) loss = (loss_i + loss_a) / 2 # Step 5: Similarity metric (average cosine similarity on matched pairs) similarity_scores = (image_embeds * audio_embeds).sum(dim=1) # Cosine similarity of matching pairs avg_similarity = similarity_scores.mean() return loss, loss_i, loss_a, logits, probs, avg_similarity if __name__ == "__main__": # ================================================================== # I M A G E - A U D I O - D A T A S E T # ================================================================== class VaaniImageAudioDataset(torch.utils.data.Dataset): def __init__(self, df, image_features_savedir, audio_tensors_savedir): self.image_paths = df.image_path.tolist() self.audio_paths = df.audio_path.tolist() self.image_features_savedir = image_features_savedir self.audio_tensors_savedir = audio_tensors_savedir def __len__(self): 
return len(self.audio_paths) def __getitem__(self, idx): return { 'image_path': self.image_paths[idx], 'image_feature': torch.load(os.path.join( self.image_features_savedir, f"{os.path.basename(self.image_paths[idx])}.pt"))['image_features'], 'audio_path': self.audio_paths[idx], 'audio_tensor': torch.load(os.path.join( audio_tensors_savedir, f"{os.path.basename(self.audio_paths[idx])}.pt"))['audio_tensor'] } train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv") test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv") image_features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/' audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/' train_dataset = VaaniImageAudioDataset(train_df, image_features_savedir, audio_tensors_savedir) test_dataset = VaaniImageAudioDataset(test_df, image_features_savedir, audio_tensors_savedir) print('Train Dataset:', len(train_dataset)) print('Test Dataset:', len(test_dataset)) BATCH_SIZE = int(128) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=48, pin_memory=True, drop_last=False, persistent_workers=True ) test_dataloader = torch.utils.data.DataLoader( test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=48, pin_memory=True, drop_last=False, persistent_workers=True ) batch = next(iter(train_dataloader)) image_features_batch = batch['image_feature'].to(device=device) audio_tensor_batch = batch['audio_tensor'].to(device=device) image_paths_batch = batch['image_path'] audio_paths_batch = batch['audio_path'] print("Image batch shape:", image_features_batch.shape) # [BATCH_SIZE, 3, 224, 224] print("Audio batch shape:", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100] csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device) # csip_model = 
nn.DataParallel(CSIP2(open_clip_model.visual, peft_clap_audio_encoder), device_ids=[0, 1]).to(device) from torchinfo import summary import subprocess for param in csip_model.audio_encoder.model.projection.parameters(): param.requires_grad = True summary(model=csip_model, input_data=((image_features_batch.to(device)), (audio_tensor_batch.to(device))), # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE), dtypes=[torch.long], col_names = ["trainable", "params_percent", "input_size", "output_size", "num_params"], col_width=20, row_settings=["var_names"], depth = 4, # verbose=2, # device=device ) # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device)) # loss, logits, probs, logits.shape, probs.shape