# Vaani-Audio2Img-LDM / Vaani / Img_Audio_Alignment / _2.1.1_Train_OpenCLIP.py
# ==================================================================
# L A T E N T D I F F U S I O N M O D E L
# ==================================================================
# Author : Ashish Kumar Uchadiya
# Created : May 11, 2025
# Description: This script trains the audio-image alignment stage of the
# Vaani audio-to-image latent diffusion pipeline. Audio clips are encoded
# with a CLAP-style encoder (HTS-AT Swin backbone, fine-tuned here with
# LoRA adapters) and images with an OpenCLIP vision encoder; both are
# projected into a shared embedding space and aligned with a CLIP-style
# contrastive objective. The resulting audio embeddings are later used as
# conditioning inputs that guide the latent diffusion model (LDM) for
# audio-driven image generation.
# ==================================================================
# I M P O R T S
# ==================================================================
from __future__ import annotations
import warnings
warnings.filterwarnings("ignore")
import os
import io
import sys
import math
import random
import collections
import collections.abc
import re
from itertools import repeat
from pathlib import Path
from typing import Optional, Tuple, Union, List, Dict
import csv
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
import torch.utils.checkpoint as checkpoint
import torchvision
from torchvision.transforms import v2
from torch.utils.tensorboard import SummaryWriter
# from tensorboardX import SummaryWriter
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
import torchaudio
import torchaudio.transforms as T
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from transformers import AutoModel, AutoTokenizer, logging
from huggingface_hub.file_download import hf_hub_download
from peft import get_peft_config, get_peft_model
from transformers import CLIPVisionModel, AutoProcessor
from watermark import watermark
print(watermark(
author='Ashish',
# email='ashish@example.com',
current_date=True,
datename=True,
current_time=True,
iso8601=True,
timezone=True,
updated=True,
custom_time=None,
python=True,
# packages="torch,torchvision,numpy",
conda=True,
hostname=True,
machine=True,
watermark=False,
iversions=True,
gpu=True,
globals_=globals()
))
# ==================================================================
# H T S - A T
# ==================================================================
class HTSATConfig:
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# The configuration for training the model
exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
workspace = "/home/kechen/Research/HTSAT" # the folder of your code
dataset_path = "/home/Research/audioset" # the dataset path
desed_folder = "/home/Research/DESED" # the desed file
dataset_type = "audioset" # "audioset" "esc-50" "scv2"
index_type = "full_train" # only works for audioset
balanced_data = True # only works for audioset
loss_type = "clip_bce" #
# AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
# trained from a checkpoint, or evaluate a single model
resume_checkpoint = None
# "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
debug = False
random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
learning_rate = 1e-3 # 1e-4 also workable
max_epoch = 100
num_workers = 3
lr_scheduler_epoch = [10,20,30]
lr_rate = [0.02, 0.05, 0.1]
# these data preparation optimizations do not bring many improvements, so deprecated
enable_token_label = False # token label
class_map_path = "class_hier_map.npy"
class_filter = None
retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
token_label_range = [0.2,0.6]
enable_time_shift = False # shift time
enable_label_enhance = False # enhance hierarchical label
enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
# for model's design
enable_tscam = True # enable the token-semantic layer
# for signal processing
sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
clip_samples = sample_rate * 10 # audio_set 10-sec clip
window_size = 1024
hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
mel_bins = 64
fmin = 50
fmax = 14000
shift_max = int(clip_samples * 0.5)
# for data collection
classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
patch_size = (25, 4) # deprecated
crop_size = None # int(clip_samples * 0.5) deprecated
# for htsat hyperparameters
htsat_window_size = 8
htsat_spec_size = 256
htsat_patch_size = 4
htsat_stride = (4, 4)
htsat_num_head = [4,8,16,32]
htsat_dim = 96
htsat_depth = [2,2,6,2]
swin_pretrain_path = None
# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
# Some Deprecated Optimization in the model design, check the model code for details
htsat_attn_heatmap = False
htsat_hier_output = False
htsat_use_max = False
# for ensemble test
ensemble_checkpoints = []
ensemble_strides = []
# weight average folder
wa_folder = "/home/version_0/checkpoints/"
# weight average output filename
wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
esm_model_pathes = [
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
]
# for framewise localization
heatmap_dir = "/home/Research/heatmap_output"
test_file = "htsat-test-ensemble"
fl_local = False # indicate if we need to use this dataset for the framewise detection
fl_dataset = "/home/Research/desed/desedim_embval.npy"
fl_class_num = [
"Speech", "Frying", "Dishes", "Running_water",
"Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
"Cat", "Dog", "Vacuum_cleaner"
]
# map 527 classes into 10 classes
fl_audioset_mapping = [
[0,1,2,3,4,5,6,7],
[366, 367, 368],
[364],
[288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
[369],
[382],
[310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
[81, 82, 83, 84, 85],
[74, 75, 76, 77, 78, 79],
[377]
]
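# Quick sanity sketch (illustrative only, assuming the default values above):
# a 10-second 32 kHz AudioSet clip yields roughly clip_samples / hop_size ~ 1000
# STFT frames with 64 mel bins; reshape_wav2img later folds this into a
# 256 x 256 "image" for the Swin backbone, since freq_ratio = 256 // 64 = 4.
# _cfg = HTSATConfig()
# print(_cfg.clip_samples // _cfg.hop_size, _cfg.mel_bins,
#       _cfg.htsat_spec_size // _cfg.mel_bins)   # -> 1000 64 4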
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
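# e.g. to_2tuple(8) -> (8, 8), while an existing pair such as (25, 4) is
# returned unchanged (illustrative note; behaviour follows _ntuple above).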
def do_mixup(x, mixup_lambda):
"""Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
(1, 3, 5, ...).
Args:
x: (batch_size * 2, ...)
mixup_lambda: (batch_size * 2,)
Returns:
out: (batch_size, ...)
"""
out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
return out
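# Shape sketch (illustrative): mixing a batch of 4 spectrograms pairwise
# halves the batch dimension, e.g.
# _x = torch.randn(4, 1, 100, 64); _lam = torch.tensor([0.7, 0.3, 0.6, 0.4])
# do_mixup(_x, _lam).shape   # -> torch.Size([2, 1, 100, 64])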
def interpolate(x, ratio):
"""Interpolate data in time domain. This is used to compensate the
resolution reduction in downsampling of a CNN.
Args:
x: (batch_size, time_steps, classes_num)
ratio: int, ratio to interpolate
Returns:
upsampled: (batch_size, time_steps * ratio, classes_num)
"""
(batch_size, time_steps, classes_num) = x.shape
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
return upsampled
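# Shape sketch (illustrative): interpolate(torch.randn(2, 32, 527), ratio=8)
# repeats each time step 8 times, giving a (2, 256, 527) tensor.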
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patch_stride = to_2tuple(patch_stride)
self.img_size = img_size
self.patch_size = patch_size
self.patch_stride = patch_stride
self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
# The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
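# Round-trip sketch (illustrative): partitioning and then reversing windows
# reproduces the input exactly, since both ops are pure reshapes/permutes.
# _feat = torch.randn(2, 8, 8, 96)
# _wins = window_partition(_feat, 4)                    # (2*4, 4, 4, 96)
# assert torch.equal(window_reverse(_wins, 4, 8, 8), _feat)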
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
def extra_repr(self):
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
# The model is built on Swin Transformer blocks, so pretrained Swin Transformer weights can be reused
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm_before_mlp = norm_before_mlp
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if self.norm_before_mlp == 'ln':
self.norm2 = nn.LayerNorm(dim)
elif self.norm_before_mlp == 'bn':
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
else:
raise NotImplementedError
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
# pdb.set_trace()
H, W = self.input_resolution
# print("H: ", H)
# print("W: ", W)
# pdb.set_trace()
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self):
return f"input_resolution={self.input_resolution}, dim={self.dim}"
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
attns = []
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x, attn = blk(x)
if not self.training:
attns.append(attn.unsqueeze(0))
if self.downsample is not None:
x = self.downsample(x)
if not self.training:
attn = torch.cat(attns, dim = 0)
attn = torch.mean(attn, dim = 0)
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# The Core of HTSAT
class HTSAT_Swin_Transformer(nn.Module):
r"""HTSAT based on the Swin Transformer
Args:
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
patch_size (int | tuple(int)): Patch size. Default: 4
patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
in_chans (int): Number of input image channels. Default: 1 (mono)
num_classes (int): Number of classes for classification head. Default: 527
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 8
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
config (module): The configuration Module from config.py (HTSATConfig Class)
"""
def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
in_chans=1, num_classes=527,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False, patch_norm=True,
use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
super(HTSAT_Swin_Transformer, self).__init__()
self.config = config
self.spec_size = spec_size
self.patch_stride = patch_stride
self.patch_size = patch_size
self.window_size = window_size
self.embed_dim = embed_dim
self.depths = depths
self.ape = ape
self.in_chans = in_chans
self.num_classes = num_classes
self.num_heads = num_heads
self.num_layers = len(self.depths)
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.qkv_bias = qkv_bias
self.qk_scale = None
self.patch_norm = patch_norm
self.norm_layer = norm_layer if self.patch_norm else None
self.norm_before_mlp = norm_before_mlp
self.mlp_ratio = mlp_ratio
self.use_checkpoint = use_checkpoint
# process mel-spec ; used only once
self.freq_ratio = self.spec_size // self.config.mel_bins
window = 'hann'
center = True
pad_mode = 'reflect'
ref = 1.0
amin = 1e-10
top_db = None
self.interpolate_ratio = 32 # Downsampled ratio
# Spectrogram extractor
self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
freeze_parameters=True)
# Logmel feature extractor
self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
freeze_parameters=True)
# Spec augmenter
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
freq_drop_width=8, freq_stripes_num=2) # 2 2
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
# split spectrogram into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.grid_size
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=self.drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=self.depths[i_layer],
num_heads=self.num_heads[i_layer],
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate,
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
norm_layer=self.norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
norm_before_mlp=self.norm_before_mlp)
self.layers.append(layer)
self.norm = self.norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.maxpool = nn.AdaptiveMaxPool1d(1)
if self.config.enable_tscam:
SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
self.tscam_conv = nn.Conv2d(
in_channels = self.num_features,
out_channels = self.num_classes,
kernel_size = (SF,3),
padding = (0,1)
)
self.head = nn.Linear(num_classes, num_classes)
else:
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
frames_num = x.shape[2]
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for i, layer in enumerate(self.layers):
x, attn = layer(x)
if self.config.enable_tscam:
# for x
x = self.norm(x)
B, N, C = x.shape
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
B, C, F, T = x.shape
# group 2D CNN
c_freq_bin = F // self.freq_ratio
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
# get latent_output
latent_output = self.avgpool(torch.flatten(x,2))
latent_output = torch.flatten(latent_output, 1)
# display the attention map, if needed
if self.config.htsat_attn_heatmap:
# for attn
attn = torch.mean(attn, dim = 1)
attn = torch.mean(attn, dim = 1)
attn = attn.reshape(B, SF, ST)
c_freq_bin = SF // self.freq_ratio
attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
attn = attn.mean(dim = 1)
attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
attn = attn.unsqueeze(dim = 2)
x = self.tscam_conv(x)
x = torch.flatten(x, 2) # B, C, T
if self.config.htsat_attn_heatmap:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
else:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x)
x = torch.flatten(x, 1)
if self.config.loss_type == "clip_ce":
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': x,
'latent_output': latent_output
}
else:
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': torch.sigmoid(x),
'latent_output': latent_output
}
else:
x = self.norm(x) # B N C
B, N, C = x.shape
fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
B, C, F, T = fpx.shape
c_freq_bin = F // self.freq_ratio
fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
fpx = torch.sum(fpx, dim = 2)
fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
if self.num_classes > 0:
x = self.head(x)
fpx = self.head(fpx)
output_dict = {'framewise_output': torch.sigmoid(fpx),
'clipwise_output': torch.sigmoid(x)}
return output_dict
def crop_wav(self, x, crop_size, spe_pos = None):
time_steps = x.shape[2]
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
for i in range(len(x)):
if spe_pos is None:
crop_pos = random.randint(0, time_steps - crop_size - 1)
else:
crop_pos = spe_pos
tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
return tx
# Reshape the waveform spectrogram to an image-like size, so the pretrained Swin Transformer model can be used
def reshape_wav2img(self, x):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
# print(x.shape)
x = x.permute(0,1,3,2,4).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
return x
# Repeat the waveform spectrogram to an image-like size, so the pretrained Swin Transformer model can be used
def repeat_wat2img(self, x, cur_pos):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous() # B C F T
x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
x = x.repeat(repeats = (1,1,4,1))
return x
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
if self.training:
x = self.spec_augmenter(x)
if self.training and mixup_lambda is not None:
x = do_mixup(x, mixup_lambda)
if infer_mode:
# in infer mode, we need to handle variable-length audio input
frame_num = x.shape[2]
target_T = int(self.spec_size * self.freq_ratio)
repeat_ratio = math.floor(target_T / frame_num)
x = x.repeat(repeats=(1,1,repeat_ratio,1))
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
elif self.config.enable_repeat_mode:
if self.training:
cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
x = self.repeat_wat2img(x, cur_pos)
output_dict = self.forward_features(x)
else:
output_dicts = []
for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
tx = x.clone()
tx = self.repeat_wat2img(tx, cur_pos)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output
}
else:
if x.shape[2] > self.freq_ratio * self.spec_size:
if self.training:
x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
else:
# Change: Hard code here
overlap_size = 344 #(x.shape[2] - 1) // 4
output_dicts = []
crop_size = 689 #(x.shape[2] - 1) // 2
for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
tx = self.reshape_wav2img(tx)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
latent_output += d["latent_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
latent_output = latent_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output,
'latent_output': latent_output,
}
else: # this is the typical path, and the simplest one
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
# x = self.head(x)
return output_dict
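# Usage sketch (illustrative, not executed here): the model consumes a raw
# mono waveform of shape (batch, samples) and returns clip-wise logits over
# 527 classes plus a 768-d latent (embed_dim 96 * 2**3) that serves as the
# audio embedding downstream.
# _net = HTSAT_Swin_Transformer(config=HTSATConfig())
# _wav = torch.randn(2, HTSATConfig.clip_samples)       # two 10-s clips @ 32 kHz
# _out = _net(_wav)
# _out['clipwise_output'].shape, _out['latent_output'].shape   # (2, 527), (2, 768)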
class HTSATWrapper(nn.Module):
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, classes_num, out_emb):
super().__init__()
# print("parameters are being overidden when using HTSAT")
# print("HTSAT only support loading a pretrained model on AudioSet")
# @TODO later look at what parameters are same and can be merged
self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())
def forward(self, x):
out_dict = self.htsat(x)
out_dict['embedding'] = out_dict['latent_output']
return out_dict
def get_audio_encoder(name: str):
if name == "HTSAT":
return HTSATWrapper
else:
raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
class Projection(nn.Module):
def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
super().__init__()
self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
self.linear2 = nn.Linear(d_out, d_out, bias=False)
self.layer_norm = nn.LayerNorm(d_out)
self.drop = nn.Dropout(p)
def forward(self, x: torch.Tensor) -> torch.Tensor:
embed1 = self.linear1(x)
embed2 = self.drop(self.linear2(F.gelu(embed1)))
embeds = self.layer_norm(embed1 + embed2)
return embeds
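# Shape sketch (illustrative): Projection(768, 1024) maps (B, 768) -> (B, 1024)
# via a linear layer plus a residual MLP branch followed by LayerNorm.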
class AudioEncoder(nn.Module):
def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
super().__init__()
audio_encoder = get_audio_encoder(audioenc_name)
self.base = audio_encoder(
sample_rate, window_size,
hop_size, mel_bins, fmin, fmax,
classes_num, dim_imgn)
self.projection = Projection(dim_imgn, d_out)
def forward(self, x):
out_dict = self.base(x)
audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
projected_vec = self.projection(audio_features)
return projected_vec, audio_classification_output
class TextEncoder(nn.Module):
def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
super().__init__()
self.text_model = text_model
self.base = AutoModel.from_pretrained(text_model)
if 'clip' in text_model:
self.clip_text_projection = self.base.text_projection
self.base = self.base.text_model
if 'base' in text_model:
transformer_embed_dim = 512
self.projection = Projection(transformer_embed_dim, d_out)
def forward(self, x):
if 'clip' in self.text_model:
pooled_output = self.base(**x)[1] # get pooled output
out = self.clip_text_projection(pooled_output) # get CLS token output
elif 'gpt' in self.text_model:
batch_size = x['input_ids'].shape[0]
hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
else:
out = self.base(**x)[0]
out = out[:, 0, :] # get CLS token output
projected_vec = self.projection(out)
return projected_vec
class CLAP(nn.Module):
def __init__(self,
# audio
audioenc_name: str,
sample_rate: int,
window_size: int,
hop_size: int,
mel_bins: int,
fmin: int,
fmax: int,
classes_num: int,
out_emb: int,
# text
text_model: str,
transformer_embed_dim: int,
# common
d_proj: int,
):
super().__init__()
self.audio_encoder = AudioEncoder(
audioenc_name, out_emb, d_proj,
sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
self.caption_encoder = TextEncoder(
d_proj, text_model, transformer_embed_dim
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, audio, text):
audio_embed, _ = self.audio_encoder(audio)
caption_embed = self.caption_encoder(text)
return caption_embed, audio_embed, self.logit_scale.exp()
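# Illustrative sketch (not necessarily the loss used later in this script):
# the (caption_embed, audio_embed, logit_scale) triple returned by CLAP.forward
# is typically consumed by a symmetric InfoNCE / CLIP-style contrastive loss,
# as in the helper below (the name _clip_style_contrastive_loss is ours).
def _clip_style_contrastive_loss(caption_embed, audio_embed, logit_scale):
    import torch.nn.functional as _F  # local alias; `F` is rebound later in this script
    # L2-normalise both embeddings, build the pairwise similarity matrix,
    # and treat matching (audio, caption) pairs as the positive classes.
    a = _F.normalize(audio_embed, dim=-1)
    t = _F.normalize(caption_embed, dim=-1)
    logits = logit_scale * a @ t.t()
    labels = torch.arange(logits.shape[0], device=logits.device)
    return 0.5 * (_F.cross_entropy(logits, labels) + _F.cross_entropy(logits.t(), labels))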
# ==================================================================
# A U D I O - P R E - P R O C E S S I N G
# ==================================================================
def read_audio(audio_path, resample=True):
r"""Loads audio file or array and returns a torch tensor"""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = torchaudio.load(audio_path)
resample_rate = clapConfig.sample_rate
if resample and resample_rate != sample_rate:
resampler = T.Resample(sample_rate, resample_rate)
audio_time_series = resampler(audio_time_series)
return audio_time_series, resample_rate
def load_audio_into_tensor(audio_path, audio_duration, resample=False):
r"""Loads audio file and returns raw audio."""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = read_audio(audio_path, resample)
audio_time_series = audio_time_series.reshape(-1)
# audio_time_series is shorter than predefined audio duration,
# so audio_time_series is extended
if audio_duration*sample_rate >= audio_time_series.shape[0]:
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
audio_time_series.shape[0]))
# Repeat audio_time_series by repeat_factor to match audio_duration
audio_time_series = audio_time_series.repeat(repeat_factor)
# remove excess part of audio_time_series
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
else:
# audio_time_series is longer than predefined audio duration,
# so audio_time_series is trimmed
start_index = random.randrange(
audio_time_series.shape[0] - audio_duration*sample_rate)
audio_time_series = audio_time_series[start_index:start_index +
audio_duration*sample_rate]
return torch.FloatTensor(audio_time_series)
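# Behaviour sketch (illustrative, assuming clapConfig.duration = 7 and 44.1 kHz
# resampling): a 3-second clip is tiled by repetition and truncated to exactly
# duration * sample_rate samples, while a 10-second clip yields a random
# 7-second crop. The returned tensor is always 1-D of length
# audio_duration * sample_rate.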
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(
default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError(
'each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
def preprocess_audio(audio_files, resample):
r"""Load list of audio files and return raw audio"""
audio_tensors = []
for audio_file in audio_files:
audio_tensor = load_audio_into_tensor(
audio_file, clapConfig.duration, resample)
audio_tensor = audio_tensor.reshape(1, -1)
audio_tensors.append(audio_tensor)
return default_collate(audio_tensors)
# ==================================================================
# A U D I O - E M B E D D I N G S - H E L P E R
# ==================================================================
def CLAPAudioProcessor(audio_files: List[str], resample=True):
preprocessed_audio = preprocess_audio(audio_files, resample)
preprocessed_audio = preprocessed_audio.reshape(
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
return preprocessed_audio
def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
"""Load list of audio files and return audio embeddings"""
# preprocessed_audio = preprocess_audio(audio_files, resample)
# with torch.no_grad():
# preprocessed_audio = preprocessed_audio.reshape(
# preprocessed_audio.shape[0], preprocessed_audio.shape[2])
with torch.no_grad():
preprocessed_audio = CLAPAudioProcessor(audio_files, resample)
return audio_encoder(preprocessed_audio)[0]
# ==================================================================
# C L A P
# ==================================================================
class ClapConfig:
# TEXT ENCODER CONFIG
text_model = 'gpt2'
text_len = 77
transformer_embed_dim = 768
freeze_text_encoder_weights = True
# AUDIO ENCODER CONFIG
audioenc_name = 'HTSAT'
out_emb = 768
sample_rate = 44100
duration = 7
fmin = 50
fmax = 8000 # 14000
n_fft = 1024 # 1028
hop_size = 320
mel_bins = 64
window_size = 1024
# PROJECTION SPACE CONFIG
d_proj = 1024
temperature = 0.003
# TRAINING AND EVALUATION CONFIG
num_classes = 527
batch_size = 1024
demo = False
clapConfig = ClapConfig()
clap = CLAP(
audioenc_name=clapConfig.audioenc_name,
sample_rate=clapConfig.sample_rate,
window_size=clapConfig.window_size,
hop_size=clapConfig.hop_size,
mel_bins=clapConfig.mel_bins,
fmin=clapConfig.fmin,
fmax=clapConfig.fmax,
classes_num=clapConfig.num_classes,
out_emb=clapConfig.out_emb,
text_model=clapConfig.text_model,
transformer_embed_dim=clapConfig.transformer_embed_dim,
d_proj=clapConfig.d_proj
)
model_repo = "microsoft/msclap"
model_name = {
'2022': 'CLAP_weights_2022.pth',
'2023': 'CLAP_weights_2023.pth',
'clapcap': 'clapcap_weights_2023.pth'
}
version = '2023'
model_fp = hf_hub_download(model_repo, model_name[version])
model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
clap.load_state_dict(model_state_dict, strict=False)
# clap.eval()
clap_audio_encoder = clap.audio_encoder.to(device)
# ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
# print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
# ==================================================================
# C L A P - L o R A - M O D E L
# ==================================================================
LoRAconfig = {
"peft_type": "LORA",
"task_type": "FEATURE_EXTRACTION",
"inference_mode": False,
"r": 16,
"target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
"lora_alpha": 32,
"lora_dropout": 0.05,
"fan_in_fan_out": False,
"bias": "all",
}
peft_config = get_peft_config(LoRAconfig)
peft_model = get_peft_model(clap_audio_encoder, peft_config)
peft_model.print_trainable_parameters()
peft_clap_audio_encoder = peft_model.base_model
# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
# print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
# ==================================================================
# O P E N - C L I P - M O D E L
# ==================================================================
# ==================================================================
# I M P O R T S
# ==================================================================
import os
import io
import sys
import math
import random
import collections
import collections.abc
import re
from itertools import repeat
from pathlib import Path
from typing import Optional, Tuple, Union, List, Dict
import csv
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
import torch.utils.checkpoint as checkpoint
import torchvision
from torchvision.transforms import v2
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
import torchaudio
import torchaudio.transforms as T
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from transformers import AutoModel, AutoTokenizer, logging
from huggingface_hub.file_download import hf_hub_download
from peft import get_peft_config, get_peft_model
from typing import Any, Dict, Optional, Tuple, Union
import numbers
import random
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop, ColorJitter, Grayscale
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
import collections.abc
from itertools import repeat
from typing import List, Optional, Tuple, Union
import torch
from torch import nn as nn
from torch import _assert
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
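# Usage sketch (illustrative): freezing the BatchNorm statistics of a small CNN.
# _cnn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# _cnn = freeze_batch_norm_2d(_cnn)
# type(_cnn[1])   # -> torchvision.ops.misc.FrozenBatchNorm2d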
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
# Replaces all linear layers with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, include_modules, copy_weights)
if isinstance(module, torch.nn.Linear) and name in include_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight.data.copy_(old_module.weight.data)
if model._modules[name].bias is not None:
model._modules[name].bias.data.copy_(old_module.bias)
return model
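# Hedged usage sketch (hypothetical helper, never called by the training flow):
# replace_linear swaps the named MLP projections for any Linear-compatible class.
# Plain nn.Linear stands in here just to show the call shape; in practice the
# replacement would be an int8 linear layer (e.g. from bitsandbytes).
def _demo_replace_linear() -> None:
    tiny = nn.Sequential(collections.OrderedDict([
        ("c_fc", nn.Linear(8, 16)),
        ("c_proj", nn.Linear(16, 8)),
    ]))
    replace_linear(tiny, nn.Linear, include_modules=["c_fc", "c_proj"], copy_weights=True)
    assert isinstance(tiny.c_fc, nn.Linear)  # modules were rebuilt with copied weights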
def convert_int8_model_to_inference_mode(model):
for m in model.modules():
if hasattr(m, 'prepare_for_eval'):
int8_original_dtype = m.weight.dtype
m.prepare_for_eval()
m.int8_original_dtype = int8_original_dtype
def feature_take_indices(
num_features: int,
indices: Optional[Union[int, List[int]]] = None,
as_set: bool = False,
) -> Tuple[List[int], int]:
""" Determine the absolute feature indices to 'take' from.
Note: This function can be called in forward() so must be torchscript compatible,
which requires some incomplete typing and workaround hacks.
Args:
num_features: total number of features to select from
indices: indices to select,
None -> select all
int -> select last n
list/tuple of int -> return specified (-ve indices specify from end)
as_set: return as a set
Returns:
List (or set) of absolute (from beginning) indices, Maximum index
"""
if indices is None:
indices = num_features # all features if None
if isinstance(indices, int):
# convert int -> last n indices
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
take_indices = [num_features - indices + i for i in range(indices)]
else:
take_indices: List[int] = []
for i in indices:
idx = num_features + i if i < 0 else i
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
take_indices.append(idx)
if not torch.jit.is_scripting() and as_set:
return set(take_indices), max(take_indices)
return take_indices, max(take_indices)
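# Minimal sanity sketch (hypothetical helper, not called anywhere): with 5 feature
# stages, an int selects the last n stages, negative list entries count from the
# end, and None selects every stage.
def _demo_feature_take_indices() -> None:
    assert feature_take_indices(5, 2) == ([3, 4], 4)              # last two stages
    assert feature_take_indices(5, [-1]) == ([4], 4)               # negative index from the end
    assert feature_take_indices(5, None) == ([0, 1, 2, 3, 4], 4)   # None -> all stages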
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
if isinstance(x, int):
# if indices is an int, take last N features
return tuple(range(-x, 0))
return tuple(x)
import copy
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Iterable, Optional, Union
from tqdm import tqdm
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
__version__ = '2.32.0'
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial
# from .hf_model import HFTextEncoder
# from .modified_resnet import ModifiedResNet
from collections import OrderedDict
import math
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
# from .utils import to_2tuple, feature_take_indices
# from .pos_embed import get_2d_sincos_pos_embed
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
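# Shape sketch (hypothetical helper, illustrative only): a 14x14 patch grid with a
# class token yields 14*14 + 1 = 197 positions, each embedded in `embed_dim`
# fixed sine/cosine channels.
def _demo_sincos_pos_embed() -> None:
    pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
    assert pe.shape == (14 * 14 + 1, 768)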
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
from collections import OrderedDict
from typing import Dict, List, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
# from .utils import freeze_batch_norm_2d, feature_take_indices
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(
self,
layers: List[int],
output_dim: int,
heads: int,
image_size: int = 224,
width: int = 64,
):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
output_fmt: str = 'NCHW',
output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize_intermediates: Apply final norm layer to all intermediates
intermediates_only: Only return intermediate features
output_fmt: Shape of intermediate feature outputs
output_extra_tokens: Also return the extra (class / eot) tokens
Returns:
    Dict with 'image_intermediates' and, unless intermediates_only is set, 'image_features'.
"""
assert output_fmt in ('NCHW',), 'Output format must be == NCHW.'
# NOTE normalize_intermediates and return_extra_tokens don't apply
take_indices, max_index = feature_take_indices(5, indices)
output = {}
intermediates = []
blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
intermediates.append(x)
output['image_intermediates'] = intermediates
if intermediates_only:
return output
x = self.attnpool(x)
output['image_features'] = x
return output
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
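# Hedged construction sketch (hypothetical helper, not part of the training flow):
# a CLIP RN50-style ModifiedResNet with layer counts [3, 4, 6, 3] maps a 224x224
# image to a single output_dim embedding through the attention-pooling head.
def _demo_modified_resnet() -> None:
    rn = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=512, heads=32, image_size=224, width=64)
    feats = rn(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 512)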
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
# from .hf_configs import arch_dict
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/bert
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
},
"pooler": "cls_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/m2m_100
"m2m_100": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "encoder_attention_heads",
"layers": "encoder_layers",
},
"pooler": "cls_pooler",
},
}
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
# fill padding positions (attention_mask == 0) with -inf so they never win the max
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
if (self.use_pooler_output and
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
(x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
"""CLS token pooling
NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
"""
def __init__(self):
super().__init__()
self.cls_token_position = 0
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
return x.last_hidden_state[:, self.cls_token_position, :]
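# Naming sketch (hypothetical helper, illustrative only): register_pooler keys each
# pooler class by its snake_case name, so config strings such as 'mean_pooler' or
# 'cls_pooler' resolve to the classes defined above.
def _demo_pooler_registry() -> None:
    assert _camel2snake('ClsLastHiddenStatePooler') == 'cls_last_hidden_state_pooler'
    assert {'mean_pooler', 'max_pooler', 'cls_pooler'} <= set(_POOLERS)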
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
output_tokens: torch.jit.Final[bool]
def __init__(
self,
model_name_or_path: str,
output_dim: int,
config: PretrainedConfig = None,
pooler_type: str = None,
proj_type: str = None,
pretrained: bool = True,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = (pooler_type == "cls_pooler")
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
pooler_type = (arch_dict[self.config.model_type]["pooler"])
# FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
self.vocab_size = getattr(self.config, 'vocab_size', 0)
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj_type == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj_type == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
def forward(self, x: TensorType):
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
return projected
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
try:
import timm
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
from timm.layers import Mlp, to_2tuple
except ImportError:
timm = None
class TimmModel(nn.Module):
""" timm model adapter
"""
def __init__(
self,
model_name: str,
embed_dim: int,
image_size: Union[int, Tuple[int, int]] = 224,
pool: str = 'avg',
proj: str = 'linear',
proj_bias: bool = False,
drop: float = 0.,
drop_path: Optional[float] = None,
patch_drop: Optional[float] = None,
pretrained: bool = False,
):
super().__init__()
if timm is None:
raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.")
self.image_size = to_2tuple(image_size)
# setup kwargs that may not be common across all models
timm_kwargs = {}
if drop_path is not None:
timm_kwargs['drop_path_rate'] = drop_path
if patch_drop is not None:
timm_kwargs['patch_drop_rate'] = patch_drop
custom_pool = pool in ('abs_attn', 'rot_attn')
if proj:
assert proj in ("linear", "mlp", "none")
extra_proj = proj in ("linear", "mlp")
if not extra_proj and not custom_pool:
# use network classifier head as projection if no proj specified and no custom pooling used
# if projection is explicitly set to "none", the trunk output passes through unchanged
proj_dim = 0 if proj == 'none' else embed_dim
self.trunk = timm.create_model(
model_name,
num_classes=proj_dim,
global_pool=pool,
pretrained=pretrained,
**timm_kwargs,
)
prev_chs = embed_dim
else:
self.trunk = timm.create_model(
model_name,
pretrained=pretrained,
**timm_kwargs,
)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if custom_pool:
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
# Add custom pooling to head
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception as e:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
output_fmt: str = 'NCHW',
output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize_intermediates: Apply norm layer to all intermediates
intermediates_only: Only return intermediate features
output_fmt: Shape of intermediate feature outputs
output_extra_tokens: Return both prefix and spatial intermediate tokens
Returns:
    Dict with 'image_intermediates' (plus 'image_intermediates_prefix' if requested) and, unless intermediates_only is set, 'image_features'.
"""
extra_args = {}
if output_extra_tokens:
extra_args['return_prefix_tokens'] = True
trunk_output = self.trunk.forward_intermediates(
x,
indices=indices,
intermediates_only=intermediates_only,
norm=normalize_intermediates,
stop_early=stop_early,
output_fmt=output_fmt,
**extra_args,
)
return_dict = {}
intermediates = trunk_output if intermediates_only else trunk_output[1]
if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):
intermediates_prefix = [xi[1] for xi in intermediates]
intermediates = [xi[0] for xi in intermediates]
return_dict['image_intermediates_prefix'] = intermediates_prefix
return_dict['image_intermediates'] = intermediates
if intermediates_only:
return return_dict
image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection
image_features = self.head(image_features) # run through adapter pooling / projection
return_dict['image_features'] = image_features
return return_dict
def forward(self, x):
x = self.trunk(x)
x = self.head(x)
return x
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(
self,
prob: float = 0.5,
exclude_first_token: bool = True
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
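# Shape sketch (hypothetical helper, never called by the training loop): with
# prob=0.5 and the CLS token excluded, roughly half of 196 patch tokens survive
# in training mode; in eval mode the module is a no-op.
def _demo_patch_dropout() -> None:
    pd = PatchDropout(prob=0.5, exclude_first_token=True)
    pd.train()
    x = torch.randn(2, 197, 64)   # [batch, 1 CLS + 196 patches, width]
    assert pd(x).shape == (2, 1 + max(1, int(196 * 0.5)), 64)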
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
scaled_cosine: bool = False,
scale_heads: bool = False,
logit_scale_max: float = math.log(1. / 0.01),
batch_first: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
self.batch_first = batch_first
self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
if self.batch_first:
x = x.transpose(0, 1)
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
else:
if self.use_fsdpa:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
# x is (N * num_heads, L, head_dim) here, so reshape per head (not per full dim) before scaling
x = x.view(N, self.num_heads, L, self.head_dim) * self.head_scale
x = x.view(-1, L, self.head_dim)
x = x.transpose(0, 1).reshape(L, N, C)
if self.batch_first:
x = x.transpose(0, 1)
x = self.out_proj(x)
x = self.out_drop(x)
return x
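# Shape sketch (hypothetical helper, illustrative only): the custom Attention block
# is batch-first by default, so a [batch, seq, dim] input maps to the same shape.
def _demo_attention_block() -> None:
    attn = Attention(dim=64, num_heads=4)
    x = torch.randn(2, 10, 64)
    assert attn(x).shape == (2, 10, 64)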
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
N = x.shape[0]
x = self.ln_k(x)
q = self.ln_q(self.query)
out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
return out
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
is_cross_attention: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
)[0]
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = Attention(
d_model,
n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
batch_first=batch_first,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def get_reference_weight(self):
return self.mlp.c_fc.weight
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class CustomTransformer(nn.Module):
""" A custom transformer that can use different block types. """
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
):
super().__init__()
self.width = width
self.layers = layers
self.batch_first = batch_first # run transformer stack in batch first (N, L, D)
self.grad_checkpointing = False
if isinstance(block_types, str):
block_types = [block_types] * layers
assert len(block_types) == layers
def _create_block(bt: str):
if bt == 'CustomResidualAttentionBlock':
return CustomResidualAttentionBlock(
width,
heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
else:
assert False
self.resblocks = nn.ModuleList([
_create_block(bt)
for bt in block_types
])
def get_cast_dtype(self) -> torch.dtype:
weight = self.resblocks[0].get_reference_weight()
if hasattr(weight, 'int8_original_dtype'):
return weight.int8_original_dtype
return weight.dtype
def forward_intermediates(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
):
take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
if not self.batch_first:
x = x.transpose(0, 1).contiguous() # NLD -> LND
intermediates = []
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.resblocks
else:
blocks = self.resblocks[:max_index + 1]
for i, blk in enumerate(blocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)
else:
x = blk(x, attn_mask=attn_mask)
if i in take_indices:
intermediates.append(x.transpose(0, 1) if not self.batch_first else x)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x, intermediates
def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
self.resblocks = self.resblocks[:max_index + 1] # truncate blocks
return take_indices
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1) # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
x = x.transpose(0, 1)    # LND -> NLD
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
):
super().__init__()
self.width = width
self.layers = layers
self.batch_first = batch_first
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
return self.resblocks[0].mlp.c_fc.int8_original_dtype
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward_intermediates(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
):
take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
if not self.batch_first:
x = x.transpose(0, 1).contiguous() # NLD -> LND
intermediates = []
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.resblocks
else:
blocks = self.resblocks[:max_index + 1]
for i, blk in enumerate(blocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)
else:
x = blk(x, attn_mask=attn_mask)
if i in take_indices:
intermediates.append(x.transpose(0, 1) if not self.batch_first else x)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x, intermediates
def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
self.resblocks = self.resblocks[:max_index + 1] # truncate blocks
return take_indices
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1).contiguous() # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x
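# Shape sketch (hypothetical helper, illustrative only): a small 2-layer, 4-head
# stack over 64-d tokens preserves the [batch, seq, width] shape; an optional
# additive attn_mask biases the attention logits.
def _demo_transformer_stack() -> None:
    tr = Transformer(width=64, layers=2, heads=4)
    x = torch.randn(2, 10, 64)
    assert tr(x).shape == (2, 10, 64)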
def _expand_token(token, batch_size: int):
return token.view(1, 1, -1).expand(batch_size, -1, -1)
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
attentional_pool: bool = False,
attn_pooler_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
no_ln_pre: bool = False,
pos_embed_type: str = 'learnable',
pool_type: str = 'tok',
final_ln_after_pool: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('tok', 'avg', 'none')
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
# class embeddings and positional embeddings
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
if pos_embed_type == 'learnable':
self.positional_embedding = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
elif pos_embed_type == 'sin_cos_2d':
# fixed sin-cos embedding
assert self.grid_size[0] == self.grid_size[1],\
'currently sin cos 2d pos embedding only supports square input'
self.positional_embedding = nn.Parameter(
torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())
else:
raise ValueError
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
if attentional_pool:
if isinstance(attentional_pool, str):
self.attn_pool_type = attentional_pool
self.pool_type = 'none'
if attentional_pool in ('parallel', 'cascade'):
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=1,
)
else:
assert False
else:
self.attn_pool_type = ''
self.pool_type = pool_type
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = None
pool_dim = output_dim
else:
self.attn_pool = None
pool_dim = width
self.pool_type = pool_type
self.ln_post = norm_layer(pool_dim)
self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
self.init_parameters()
def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def init_parameters(self):
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
# TODO experiment if default PyTorch init, below, or alternate init is best.
# nn.init.normal_(self.class_embedding, std=self.scale)
# nn.init.normal_(self.positional_embedding, std=self.scale)
#
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# attn_std = self.transformer.width ** -0.5
# fc_std = (2 * self.transformer.width) ** -0.5
# for block in self.transformer.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#
# if self.text_projection is not None:
# nn.init.normal_(self.text_projection, std=self.scale)
pass
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding', 'class_embedding'}
return no_wd
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pool_type == 'avg':
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
elif self.pool_type == 'tok':
pooled, tokens = x[:, 0], x[:, 1:]
else:
pooled = tokens = x
return pooled, tokens
def _embeds(self, x:torch.Tensor) -> torch.Tensor:
x = self.conv1(x) # shape = [*, dim, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
# patch dropout (if active)
x = self.patch_dropout(x)
# apply norm before transformer
x = self.ln_pre(x)
return x
def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.attn_pool is not None:
if self.attn_pool_contrastive is not None:
# This is untested, WIP pooling that should match paper
x = self.ln_post(x) # TBD LN first or separate one after each pool?
tokens = self.attn_pool(x)
if self.attn_pool_type == 'parallel':
pooled = self.attn_pool_contrastive(x)
else:
assert self.attn_pool_type == 'cascade'
pooled = self.attn_pool_contrastive(tokens)
else:
# this is the original OpenCLIP CoCa setup, does not match paper
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
elif self.final_ln_after_pool:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
else:
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
return pooled, tokens
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
output_fmt: str = 'NCHW',
output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
intermediates_only: Only return intermediate features
normalize_intermediates: Apply final norm layer to all intermediates
output_fmt: Shape of intermediate feature outputs
output_extra_tokens: Also return the extra prefix (class) tokens
Returns:
    Dict with 'image_intermediates' (plus 'image_intermediates_prefix' if requested) and, unless intermediates_only is set, 'image_features'.
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
# forward pass
B, _, height, width = x.shape
x = self._embeds(x)
x, intermediates = self.transformer.forward_intermediates(
x,
indices=indices,
stop_early=stop_early,
)
# process intermediates
if normalize_intermediates:
# apply final norm to all intermediates
intermediates = [self.ln_post(xi) for xi in intermediates]
num_prefix_tokens = 1 # one class token that's always there (as of now)
if num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]
intermediates = [y[:, num_prefix_tokens:] for y in intermediates]
else:
prefix_tokens = None
if reshape:
# reshape to BCHW output format
H, W = height // self.patch_size[0], width // self.patch_size[1]
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
output = {'image_intermediates': intermediates}
if prefix_tokens is not None and output_extra_tokens:
output['image_intermediates_prefix'] = prefix_tokens
if intermediates_only:
return output
pooled, _ = self._pool(x)
if self.proj is not None:
pooled = pooled @ self.proj
output['image_features'] = pooled
return output
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices = self.transformer.prune_intermediate_layers(indices)
if prune_norm:
self.ln_post = nn.Identity()
if prune_head:
self.proj = None
return take_indices
def forward(self, x: torch.Tensor):
x = self._embeds(x)
x = self.transformer(x)
pooled, tokens = self._pool(x)
if self.proj is not None:
pooled = pooled @ self.proj
if self.output_tokens:
return pooled, tokens
return pooled
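# Shape sketch (hypothetical helper, much shallower than ViT-B/32 to stay cheap):
# a 224x224 image is patchified into 7x7 = 49 tokens plus a class token and
# pooled to a single output_dim embedding.
def _demo_vision_transformer() -> None:
    vit = VisionTransformer(
        image_size=224, patch_size=32, width=64, layers=2, heads=4,
        mlp_ratio=4.0, output_dim=32,
    )
    assert vit(torch.randn(1, 3, 224, 224)).shape == (1, 32)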
def text_global_pool(
x: torch.Tensor,
text: Optional[torch.Tensor] = None,
pool_type: str = 'argmax',
) -> torch.Tensor:
if pool_type == 'first':
pooled = x[:, 0]
elif pool_type == 'last':
pooled = x[:, -1]
elif pool_type == 'argmax':
# take features from the eot embedding (eot_token is the highest number in each sequence)
assert text is not None
pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
else:
pooled = x
return pooled
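# Pooling sketch (hypothetical helper, illustrative only): under the default
# 'argmax' pooling, the feature at the position of the largest token id (the EOT
# token in CLIP's tokenizer) is taken for each sequence.
def _demo_text_global_pool() -> None:
    x = torch.randn(2, 5, 8)
    text = torch.tensor([[5, 3, 9, 0, 0], [7, 9, 2, 1, 0]])
    pooled = text_global_pool(x, text, pool_type='argmax')
    assert pooled.shape == (2, 8)   # one 8-d feature per sequence, taken at the argmax position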
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
output_dim: Optional[int] = 512,
embed_cls: bool = False,
no_causal_mask: bool = False,
pad_id: int = 0,
pool_type: str = 'argmax',
proj_type: str = 'linear',
proj_bias: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('first', 'last', 'argmax', 'none')
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.pool_type = pool_type
self.token_embedding = nn.Embedding(vocab_size, width)
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
if no_causal_mask:
self.attn_mask = None
else:
self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
if proj_type == 'none' or not output_dim:
self.text_projection = None
else:
if proj_bias:
self.text_projection = nn.Linear(width, output_dim)
else:
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
if self.text_projection.bias is not None:
nn.init.zeros_(self.text_projection.bias)
else:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding'}
if self.cls_emb is not None:
no_wd.add('cls_emb')
return no_wd
def build_causal_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
if attn_mask is not None:
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
return x, attn_mask
def forward_intermediates(
self,
text: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
output_fmt: str = 'NCHW',
output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
text: Input text ids
indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize_intermediates: Apply norm layer to all intermediates
intermediates_only: Only return intermediate features
output_fmt: Shape of intermediate feature outputs
output_extra_tokens: Return both prefix and intermediate tokens
Returns:
    Dict with 'text_intermediates' (plus 'text_intermediates_suffix' if requested) and, unless intermediates_only is set, 'text_features'.
"""
assert output_fmt in ('NLC',), 'Output format must be NLC.'
# forward pass
x, attn_mask = self._embeds(text)
x, intermediates = self.transformer.forward_intermediates(
x,
attn_mask=attn_mask,
indices=indices,
stop_early=stop_early,
)
# process intermediates
if normalize_intermediates:
# apply final norm to all intermediates
intermediates = [self.ln_final(xi) for xi in intermediates]
output = {}
if self.cls_emb is not None:
seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence
if output_extra_tokens:
# return suffix class tokens separately
cls_intermediates = [xi[:, -1:] for xi in intermediates]
output['text_intermediates_suffix'] = cls_intermediates
intermediates = seq_intermediates
output['text_intermediates'] = intermediates
if intermediates_only:
return output
if self.cls_emb is not None:
# presence of appended cls embed (CoCa) overrides pool_type, always take last token
pooled = text_global_pool(x, pool_type='last')
pooled = self.ln_final(pooled) # final LN applied after pooling in this case
else:
x = self.ln_final(x)
pooled = text_global_pool(x, text, pool_type=self.pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
pooled = self.text_projection(pooled)
else:
pooled = pooled @ self.text_projection
output['text_features'] = pooled
return output
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices = self.transformer.prune_intermediate_layers(indices)
if prune_norm:
self.ln_final = nn.Identity()
if prune_head:
self.text_projection = None
return take_indices
def forward(self, text):
x, attn_mask = self._embeds(text)
x = self.transformer(x, attn_mask=attn_mask)
# x.shape = [batch_size, n_ctx, transformer.width]
if self.cls_emb is not None:
# presence of appended cls embed (CoCa) overrides pool_type, always take last token
pooled = text_global_pool(x, pool_type='last')
pooled = self.ln_final(pooled) # final LN applied after pooling in this case
tokens = x[:, :-1]
else:
x = self.ln_final(x)
pooled = text_global_pool(x, text, pool_type=self.pool_type)
tokens = x
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
pooled = self.text_projection(pooled)
else:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
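# Pooling note for the text tower above (descriptive only): with a CoCa-style
# cls_emb the appended last token is taken as the pooled feature and ln_final
# is applied after pooling; otherwise ln_final is applied to the full sequence
# first and text_global_pool selects the pooled feature (the default 'argmax'
# pool_type picks the EOT position, i.e. the highest token id).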
class MultimodalTransformer(Transformer):
def __init__(
self,
width: int,
layers: int,
heads: int,
context_length: int = 77,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_dim: int = 512,
batch_first: bool = True,
):
super().__init__(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
self.context_length = context_length
self.cross_attn = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
is_cross_attention=True,
batch_first=batch_first,
)
for _ in range(layers)
])
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
def init_parameters(self):
# NOTE: MultimodalTransformer extends Transformer, so width/layers/resblocks/cross_attn live directly on self
proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
attn_std = self.width ** -0.5
fc_std = (2 * self.width) ** -0.5
for block in self.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
for block in self.cross_attn:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward_intermediates(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
):
assert False, "Not currently implemented for MultimodalTransformer w/ xattn"
def forward(self, image_embs, text_embs):
seq_len = text_embs.shape[1]
if not self.batch_first:
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(
resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len], use_reentrant=False)
text_embs = checkpoint(
cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
if not self.batch_first:
text_embs = text_embs.permute(1, 0, 2) # LND -> NLD
out = self.ln_final(text_embs)
if self.text_projection is not None:
out = out @ self.text_projection
return out
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
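# Note on the multimodal decoder above (descriptive only): each layer applies
# causal self-attention over the text tokens and then cross-attention with the
# image tokens as keys/values (CoCa-style); ln_final and text_projection map
# the decoded sequence from `width` to `output_dim`.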
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
attn_pooler_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
no_ln_pre: bool = False # disable pre transformer LayerNorm
pos_embed_type: str = 'learnable'
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
pool_type: str = 'tok'
output_tokens: bool = False
act_kwargs: Optional[dict] = None
norm_kwargs: Optional[dict] = None
timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
hf_tokenizer_name: Optional[str] = None
tokenizer_kwargs: Optional[dict] = None
width: int = 512
heads: int = 8
layers: int = 12
mlp_ratio: float = 4.0
ls_init_value: Optional[float] = None # layer scale initial value
embed_cls: bool = False
pad_id: int = 0
no_causal_mask: bool = False # disable causal masking
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
pool_type: str = 'argmax'
proj_bias: bool = False
proj_type: str = 'linear' # control final text projection, 'none' forces no projection
output_tokens: bool = False
act_kwargs: Optional[dict] = None
norm_kwargs: Optional[dict] = None
# HuggingFace specific text tower config
hf_model_name: Optional[str] = None
hf_model_pretrained: bool = True
hf_proj_type: str = 'mlp'
hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def get_input_dtype(precision: str):
input_dtype = None
if precision in ('bf16', 'pure_bf16'):
input_dtype = torch.bfloat16
elif precision in ('fp16', 'pure_fp16'):
input_dtype = torch.float16
return input_dtype
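# ------------------------------------------------------------------
# Illustrative sketch (defined but never called here): how the precision
# strings handled by the two helpers above map to torch dtypes. Any other
# precision string simply yields None, i.e. tensors are left in fp32.
def _example_precision_to_dtype():
    assert get_cast_dtype('bf16') is torch.bfloat16
    assert get_cast_dtype('fp16') is torch.float16
    assert get_cast_dtype('fp32') is None
    assert get_input_dtype('pure_bf16') is torch.bfloat16
    assert get_input_dtype('pure_fp16') is torch.float16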
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
drop=vision_cfg.timm_drop,
drop_path=vision_cfg.timm_drop_path,
patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if vision_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
if vision_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
attentional_pool=vision_cfg.attentional_pool,
attn_pooler_queries=vision_cfg.attn_pooler_queries,
attn_pooler_heads=vision_cfg.attn_pooler_heads,
pos_embed_type=vision_cfg.pos_embed_type,
no_ln_pre=vision_cfg.no_ln_pre,
final_ln_after_pool=vision_cfg.final_ln_after_pool,
pool_type=vision_cfg.pool_type,
output_tokens=vision_cfg.output_tokens,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
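# ------------------------------------------------------------------
# Minimal sketch (defined but never called) of the three dispatch paths of
# _build_vision_tower: a timm_model_name builds a TimmModel, a tuple/list of
# layers builds a ModifiedResNet, and an int layer count builds a
# VisionTransformer. The numbers below are illustrative placeholders only.
def _example_build_vision_tower():
    vit_cfg = CLIPVisionCfg(layers=12, width=768, patch_size=32, image_size=224)
    rn_cfg = CLIPVisionCfg(layers=(3, 4, 6, 3), width=64, image_size=224)
    vit_tower = _build_vision_tower(embed_dim=512, vision_cfg=vit_cfg)   # VisionTransformer path
    rn_tower = _build_vision_tower(embed_dim=1024, vision_cfg=rn_cfg)    # ModifiedResNet path
    return vit_tower, rn_tower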
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
proj_type=text_cfg.hf_proj_type,
pooler_type=text_cfg.hf_pooler_type,
pretrained=text_cfg.hf_model_pretrained,
output_tokens=text_cfg.output_tokens,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if text_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
if text_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **text_cfg.act_kwargs)
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
mlp_ratio=text_cfg.mlp_ratio,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
no_causal_mask=text_cfg.no_causal_mask,
pad_id=text_cfg.pad_id,
pool_type=text_cfg.pool_type,
proj_type=text_cfg.proj_type,
proj_bias=text_cfg.proj_bias,
output_tokens=text_cfg.output_tokens,
act_layer=act_layer,
norm_layer=norm_layer,
)
return text
class CLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.context_length = text.context_length
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.text_pool_type = text.pool_type
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding'}
if hasattr(self.visual, 'no_weight_decay'):
for n in self.visual.no_weight_decay():
no_wd.add('visual.' + n)
return no_wd
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
x = text_global_pool(x, text, self.text_pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
x = self.text_projection(x)
else:
x = x @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits
def forward_intermediates(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
image_indices: Optional[Union[int, List[int]]] = None,
text_indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize: bool = True,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
image_output_fmt: str = 'NCHW',
image_output_extra_tokens: bool = False,
text_output_fmt: str = 'NLC',
text_output_extra_tokens: bool = False,
output_logits: bool = False,
output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
image: Input image tensor
text: Input text tensor
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize_intermediates: Apply final norm layer to all intermediates
normalize: L2 Normalize final features
intermediates_only: Only return intermediate features, do not return final features
image_output_fmt: Shape of intermediate image feature outputs
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
output_logits: Include logits in output
output_logit_scale_bias: Include the logit scale bias in the output
Returns:
    Dict of image/text intermediates and final features, plus image/text logits,
    logit_scale and logit_bias when requested via the output_* flags.
"""
output = {}
if intermediates_only:
# intermediates-only mode disables final feature normalization and skips logit computation
normalize = False
output_logits = False
if output_logits:
assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
if image is not None:
image_output = self.visual.forward_intermediates(
image,
indices=image_indices,
stop_early=stop_early,
normalize_intermediates=normalize_intermediates,
intermediates_only=intermediates_only,
output_fmt=image_output_fmt,
output_extra_tokens=image_output_extra_tokens,
)
if normalize and "image_features" in image_output:
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
output.update(image_output)
if text is not None:
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x, intermediates = self.transformer.forward_intermediates(
x,
attn_mask=self.attn_mask,
indices=text_indices
)
if normalize_intermediates:
intermediates = [self.ln_final(xi) for xi in intermediates]
# NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
output["text_intermediates"] = intermediates
if not intermediates_only:
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
x = text_global_pool(x, text, self.text_pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
x = self.text_projection(x)
else:
x = x @ self.text_projection
if normalize:
x = F.normalize(x, dim=-1)
output["text_features"] = x
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
if output_logits:
image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
output["image_logits"] = image_logits
output["text_logits"] = text_logits
if output_logit_scale_bias:
output["logit_scale"] = logit_scale_exp
if self.logit_bias is not None:
output['logit_bias'] = self.logit_bias
return output
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict
if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()
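# ------------------------------------------------------------------
# Hedged usage sketch for the CLIP tower above (defined but never called).
# `model` is assumed to be an already-built CLIP instance; shapes are
# illustrative: images [B, 3, H, W], tokens [B, context_length] of int ids.
def _example_clip_logits(model: CLIP, images: torch.Tensor, tokens: torch.Tensor):
    image_logits, text_logits = model.get_logits(images, tokens)
    # image_logits[i, j] = logit_scale * <normalized image_i, normalized text_j>
    # (plus logit_bias if configured); text_logits is simply its transpose.
    probs = image_logits.softmax(dim=-1)  # per-image distribution over text candidates
    return probs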
class CustomTextCLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.context_length = self.text.context_length
self.vocab_size = self.text.vocab_size
lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = set()
if hasattr(self.visual, 'no_weight_decay'):
for n in self.visual.no_weight_decay():
no_wd.add('visual.' + n)
if hasattr(self.text, 'no_weight_decay'):
for n in self.text.no_weight_decay():
no_wd.add('text.' + n)
return no_wd
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits
def forward_intermediates(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
image_indices: Optional[Union[int, List[int]]] = None,
text_indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize: bool = True,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
image_output_fmt: str = 'NCHW',
image_output_extra_tokens: bool = False,
text_output_fmt: str = 'NLC',
text_output_extra_tokens: bool = False,
output_logits: bool = False,
output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
image: Input image tensor
text: Input text tensor
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize: L2 Normalize final image and text features (if present)
normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
intermediates_only: Only return intermediate features, do not return final features
image_output_fmt: Shape of intermediate image feature outputs
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
text_output_fmt: Shape of intermediate text feature outputs
text_output_extra_tokens: Return both prefix and spatial intermediate tokens
output_logits: Include logits in output
output_logit_scale_bias: Include the logit scale bias in the output
Returns:
    Dict of image/text intermediates and final features, plus image/text logits,
    logit_scale and logit_bias when requested via the output_* flags.
"""
output = {}
if intermediates_only:
# intermediates-only mode disables final feature normalization and skips logit computation
normalize = False
output_logits = False
if output_logits:
assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
if image is not None:
image_output = self.visual.forward_intermediates(
image,
indices=image_indices,
stop_early=stop_early,
normalize_intermediates=normalize_intermediates,
intermediates_only=intermediates_only,
output_fmt=image_output_fmt,
output_extra_tokens=image_output_extra_tokens,
)
if normalize and "image_features" in image_output:
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
output.update(image_output)
if text is not None:
text_output = self.text.forward_intermediates(
text,
indices=text_indices,
stop_early=stop_early,
normalize_intermediates=normalize_intermediates,
intermediates_only=intermediates_only,
output_fmt=text_output_fmt,
output_extra_tokens=text_output_extra_tokens,
)
if normalize and "text_features" in text_output:
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
output.update(text_output)
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
if output_logits:
image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
output["image_logits"] = image_logits
output["text_logits"] = text_logits
if output_logit_scale_bias:
output["logit_scale"] = logit_scale_exp
if self.logit_bias is not None:
output['logit_bias'] = self.logit_bias
return output
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict
if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
if isinstance(l, (CLIP, TextTransformer)):
# convert text nn.Parameter projections
attr = getattr(l, "text_projection", None)
if attr is not None:
attr.data = attr.data.to(dtype)
if isinstance(l, VisionTransformer):
# convert vision nn.Parameter projections
attr = getattr(l, "proj", None)
if attr is not None:
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
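# Usage sketch for the converter above (comment only, illustrative):
#   convert_weights_to_lp(model, dtype=torch.bfloat16)   # or torch.float16
# converts Linear/Conv/attention weights and the text/vision projection
# parameters in-place, while LayerNorms and embeddings stay in fp32, matching
# the partially-converted layout of the original OpenAI checkpoints.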
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
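# Example of the remapping performed above (illustrative keys):
#   'token_embedding.weight'                      -> 'text.token_embedding.weight'
#   'transformer.resblocks.0.attn.in_proj_weight' -> 'text.transformer.resblocks.0.attn.in_proj_weight'
#   'visual.conv1.weight'                         -> unchanged (vision keys keep their prefix)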
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers,
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
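# ------------------------------------------------------------------
# Hedged sketch (defined but never called): loading an original OpenAI
# checkpoint with the builder above. The path is a placeholder; OpenAI .pt
# files are TorchScript archives, so the state_dict is taken from the
# scripted module before rebuilding a regular nn.Module.
def _example_load_openai_checkpoint(ckpt_path: str = "/path/to/ViT-B-32.pt"):
    jit_model = torch.jit.load(ckpt_path, map_location="cpu")
    model = build_model_from_openai_state_dict(jit_model.state_dict())
    return model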
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['visual.positional_embedding'] = new_pos_embed
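# Worked example for resize_pos_embed (comment only): adapting a ViT-B/16
# checkpoint trained at 224px (grid 14x14, 1 + 196 positions) to a 384px model
# (grid 24x24, 1 + 576 positions) keeps the class-token embedding and
# bicubically interpolates the 196 patch embeddings up to 576.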
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
old_pos_embed = state_dict.get('positional_embedding', None)
if old_pos_embed is None:
return
# FIXME add support for text cls_token
model_pos_embed = getattr(model, 'positional_embedding', None)
if model_pos_embed is None:
model_pos_embed = getattr(model.text, 'positional_embedding', None)
old_num_pos = old_pos_embed.shape[0]
old_width = old_pos_embed.shape[1]
num_pos = model_pos_embed.shape[0]
width = model_pos_embed.shape[1]
assert old_width == width, 'text pos_embed width changed!'
if old_num_pos == num_pos:
return
logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
old_pos_embed = F.interpolate(
old_pos_embed,
size=num_pos,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
new_pos_embed = old_pos_embed
state_dict['positional_embedding'] = new_pos_embed
def get_model_preprocess_cfg(model):
module = getattr(model, 'visual', model)
preprocess_cfg = getattr(module, 'preprocess_cfg', {})
if not preprocess_cfg:
# use separate legacy attributes if preprocess_cfg dict not found
size = getattr(module, 'image_size', None)
if size is not None:
preprocess_cfg['size'] = size
mean = getattr(module, 'image_mean', None)
if mean is not None:
preprocess_cfg['mean'] = mean
std = getattr(module, 'image_std', None)
if std is not None:
preprocess_cfg['std'] = std
return preprocess_cfg
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
module = getattr(model, 'visual', model)
module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
def get_model_tokenize_cfg(model):
module = getattr(model, 'text', model)
cfg = {}
context_length = getattr(module, 'context_length', None)
if context_length is not None:
cfg['context_length'] = context_length
vocab_size = getattr(module, 'vocab_size', None)
if vocab_size is not None:
cfg['vocab_size'] = vocab_size
return cfg
try:
from huggingface_hub import hf_hub_download
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', **kwargs):
# OpenAI / OpenCLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': OPENAI_DATASET_MEAN,
'std': OPENAI_DATASET_STD,
'interpolation': 'bicubic',
'resize_mode': 'shortest',
**kwargs,
}
def _slpcfg(url='', hf_hub='', **kwargs):
# SiGLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': INCEPTION_MEAN,
'std': INCEPTION_STD,
'interpolation': 'bicubic',
'resize_mode': 'squash',
**kwargs,
}
def _apcfg(url='', hf_hub='', **kwargs):
# CLIPA defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': IMAGENET_MEAN,
'std': IMAGENET_STD,
'interpolation': 'bilinear',
'resize_mode': 'squash',
**kwargs,
}
def _mccfg(url='', hf_hub='', **kwargs):
# MobileCLIP
return {
'url': url,
'hf_hub': hf_hub,
'mean': (0., 0., 0.),
'std': (1., 1., 1.),
'interpolation': 'bilinear',
'resize_mode': 'shortest',
**kwargs,
}
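# Illustrative result of the _pcfg-style helpers above, e.g.
#   _pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
# -> {'url': '', 'hf_hub': 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K/',
#     'mean': OPENAI_DATASET_MEAN, 'std': OPENAI_DATASET_STD,
#     'interpolation': 'bicubic', 'resize_mode': 'shortest'}
# _slpcfg/_apcfg/_mccfg differ only in the normalization stats, interpolation
# and resize_mode defaults.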
_RN50 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
hf_hub="timm/resnet50_clip.openai/",
quick_gelu=True,
),
yfcc15m=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
hf_hub="timm/resnet50_clip.yfcc15m/",
quick_gelu=True,
),
cc12m=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
hf_hub="timm/resnet50_clip.cc12m/",
quick_gelu=True,
),
)
_RN101 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
hf_hub="timm/resnet101_clip.openai/",
quick_gelu=True,
),
yfcc15m=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
hf_hub="timm/resnet101_clip.yfcc15m/",
quick_gelu=True,
),
)
_RN50x4 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
hf_hub="timm/resnet50x4_clip.openai/",
quick_gelu=True,
),
)
_RN50x16 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
hf_hub="timm/resnet50x16_clip.openai/",
quick_gelu=True,
),
)
_RN50x64 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
hf_hub="timm/resnet50x64_clip.openai/",
quick_gelu=True,
),
)
_VITB32 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
hf_hub="timm/vit_base_patch32_clip_224.openai/",
quick_gelu=True,
),
# LAION 400M (quick gelu)
laion400m_e31=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/",
quick_gelu=True,
),
laion400m_e32=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/",
quick_gelu=True,
),
# LAION 2B-en
laion2b_e16=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/",
),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
# DataComp-M models
datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
# DataComp-S models
datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
# MetaClip models (NOTE quick-gelu activation used)
metaclip_400m=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt",
hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/",
quick_gelu=True,
),
metaclip_fullcc=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt",
hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/",
quick_gelu=True,
),
)
_VITB32_256 = dict(
datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
)
_VITB16 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
hf_hub="timm/vit_base_patch16_clip_224.openai/",
quick_gelu=True,
),
# LAION-400M
laion400m_e31=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/",
),
laion400m_e32=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/",
),
# LAION-2B
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
# DataComp-L models
datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
# DFN
dfn2b=_pcfg(
hf_hub='apple/DFN2B-CLIP-ViT-B-16/',
quick_gelu=True,
),
# MetaCLIP (these are quick-gelu)
metaclip_400m=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt",
hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/",
quick_gelu=True,
),
metaclip_fullcc=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt",
hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/",
quick_gelu=True,
),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
),
laion400m_e32=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
),
)
_VITL14 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
hf_hub="timm/vit_large_patch14_clip_224.openai/",
quick_gelu=True,
),
# LAION-400M
laion400m_e31=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt",
hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/",
),
laion400m_e32=_pcfg(
url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt",
hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/",
),
# LAION-2B-en
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=INCEPTION_MEAN, std=INCEPTION_STD),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
# MetaCLIP
metaclip_400m=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt",
hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/",
quick_gelu=True,
),
metaclip_fullcc=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt",
hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/",
quick_gelu=True,
),
# DFN-2B (quick-gelu)
dfn2b=_pcfg(
hf_hub='apple/DFN2B-CLIP-ViT-L-14/',
quick_gelu=True,
),
# DFN-2B 39B SS
dfn2b_s39b=_pcfg(
hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/',
),
)
_VITL14_336 = dict(
openai=_pcfg(
url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
hf_hub="timm/vit_large_patch14_clip_336.openai/",
quick_gelu=True,
),
)
_VITH14 = dict(
# LAION-2B-en
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
# MetaCLIP (quick-gelu)
metaclip_fullcc=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt",
hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/",
quick_gelu=True,
),
metaclip_altogether=_pcfg(
url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt",
hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/",
# NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay!
),
# DFN-5B (quick-gelu)
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
quick_gelu=True,
interpolation="bicubic",
resize_mode="squash"
),
)
_VITH14_378 = dict(
# DFN-5B (quick-gelu)
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
quick_gelu=True,
interpolation="bicubic",
resize_mode="squash"
),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_VITbigG14 = dict(
# LAION-2B-en
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
# MetaCLIP (quick-gelu)
metaclip_fullcc=_pcfg(
url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',
hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/",
quick_gelu=True,
),
)
_robertaViTB32 = dict(
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)
_xlmRobertaBaseViTB32 = dict(
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)
_xlmRobertaLargeFrozenViTH14 = dict(
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)
_convnext_base = dict(
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)
_convnext_base_w = dict(
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)
_convnext_base_w_320 = dict(
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)
_convnext_large_d = dict(
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)
_convnext_large_d_320 = dict(
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)
_convnext_xxlarge = dict(
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)
_coca_VITB32 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)
_coca_VITL14 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
_PRETRAINED = {
"RN50": _RN50,
"RN101": _RN101,
"RN50x4": _RN50x4,
"RN50x16": _RN50x16,
"RN50x64": _RN50x64,
"ViT-B-32": _VITB32,
"ViT-B-32-256": _VITB32_256,
"ViT-B-16": _VITB16,
"ViT-B-16-plus-240": _VITB16_PLUS_240,
"ViT-L-14": _VITL14,
"ViT-L-14-336": _VITL14_336,
"ViT-H-14": _VITH14,
"ViT-H-14-378": _VITH14_378,
"ViT-g-14": _VITg14,
"ViT-bigG-14": _VITbigG14,
"roberta-ViT-B-32": _robertaViTB32,
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
"convnext_base": _convnext_base,
"convnext_base_w": _convnext_base_w,
"convnext_base_w_320": _convnext_base_w_320,
"convnext_large_d": _convnext_large_d,
"convnext_large_d_320": _convnext_large_d_320,
"convnext_xxlarge": _convnext_xxlarge,
"coca_ViT-B-32": _coca_VITB32,
"coca_ViT-L-14": _coca_VITL14,
"EVA01-g-14": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
),
"EVA01-g-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
),
"EVA02-B-16": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
),
"EVA02-L-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
),
"EVA02-L-14-336": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
),
"EVA02-E-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
),
"EVA02-E-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
),
"ViT-B-16-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
),
"ViT-B-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
),
"ViT-B-16-SigLIP-i18n-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
),
"ViT-B-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
),
"ViT-B-16-SigLIP-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
),
"ViT-L-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
),
"ViT-L-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
),
"ViT-SO400M-14-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
),
"ViT-SO400M-16-SigLIP-i18n-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),
),
"ViT-SO400M-14-SigLIP-378": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used
),
"ViT-SO400M-14-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
),
"ViT-B-32-SigLIP2-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'),
),
"ViT-B-16-SigLIP2": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'),
),
"ViT-B-16-SigLIP2-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'),
),
"ViT-B-16-SigLIP2-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'),
),
"ViT-B-16-SigLIP2-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'),
),
"ViT-L-16-SigLIP2-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'),
),
"ViT-L-16-SigLIP2-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'),
),
"ViT-L-16-SigLIP2-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'),
),
"ViT-SO400M-14-SigLIP2": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'),
),
"ViT-SO400M-14-SigLIP2-378": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'),
),
"ViT-SO400M-16-SigLIP2-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'),
),
"ViT-SO400M-16-SigLIP2-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'),
),
"ViT-SO400M-16-SigLIP2-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'),
),
"ViT-gopt-16-SigLIP2-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'),
),
"ViT-gopt-16-SigLIP2-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'),
),
"ViT-L-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
),
"ViT-L-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
),
"ViT-H-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
),
"ViT-H-14-CLIPA-336": dict(
laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
),
"ViT-bigG-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
),
"ViT-bigG-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
),
"nllb-clip-base": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
),
"nllb-clip-large": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
),
"nllb-clip-base-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
),
"nllb-clip-large-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
),
"MobileCLIP-S1": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
"MobileCLIP-S2": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
"MobileCLIP-B": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
),
"ViTamin-S": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
),
"ViTamin-S-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
),
"ViTamin-B": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
),
"ViTamin-B-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
),
"ViTamin-L": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
),
"ViTamin-L-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
),
"ViTamin-L-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
),
"ViTamin-L-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
),
"ViTamin-L2": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
),
"ViTamin-L2-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
),
"ViTamin-L2-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
),
"ViTamin-L2-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
),
"ViTamin-XL-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
),
"ViTamin-XL-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
),
"ViTamin-XL-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
),
}
_PRETRAINED_quickgelu = {}
for k, v in _PRETRAINED.items():
quick_gelu_tags = {}
for tk, tv in v.items():
if tv.get('quick_gelu', False):
quick_gelu_tags[tk] = copy.deepcopy(tv)
if quick_gelu_tags:
_PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags
_PRETRAINED.update(_PRETRAINED_quickgelu)
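# The loop above derives '-quickgelu' aliases from the registry: e.g.
# 'ViT-B-32-quickgelu' exposes only the tags whose entries set quick_gelu=True
# ('openai', 'laion400m_e31', ...), so the activation always matches the
# checkpoint it was trained with.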
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
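# Illustrative queries against the registry above (comments only, not executed):
#   list_pretrained(as_str=True)               -> ['RN50:openai', 'RN50:yfcc15m', ...]
#   list_pretrained_tags_by_model('ViT-B-32')  -> ['openai', 'laion400m_e31', ...]
#   get_pretrained_cfg('ViT-B-32', 'openai')['hf_hub']
#                                              -> 'timm/vit_base_patch32_clip_224.openai/'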
def download_pretrained_from_url(
url: str,
cache_dir: Optional[str] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
Use case:
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")):
yield filename[:-4] + ".safetensors"
def download_pretrained_from_hf(
model_id: str,
filename: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
):
has_hf_hub(True)
filename = filename or HF_WEIGHTS_NAME
# Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_file = hf_hub_download(
repo_id=model_id,
filename=safe_filename,
revision=revision,
cache_dir=cache_dir,
)
return cached_file
except Exception:
pass
try:
# Attempt to download the file
cached_file = hf_hub_download(
repo_id=model_id,
filename=filename,
revision=revision,
cache_dir=cache_dir,
)
return cached_file # Return the path to the downloaded file if successful
except Exception as e:
raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}")
def download_pretrained(
cfg: Dict,
prefer_hf_hub: bool = True,
cache_dir: Optional[str] = None,
):
target = ''
if not cfg:
return target
if 'file' in cfg:
return cfg['file']
has_hub = has_hf_hub()
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if has_hub and prefer_hf_hub and download_hf_hub:
# prefer to use HF hub, remove url info
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
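# Sketch of the combined resolver, using one of the hf_hub entries registered above
# (requires huggingface_hub plus network or cache access):
# >>> cfg = dict(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin')
# >>> ckpt_path = download_pretrained(cfg)   # local cached path to the checkpoint file
# A trailing-slash entry ('org/model_name/') would instead fall back to the default
# 'open_clip_pytorch_model.bin' filename, as noted in the comment above.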
# ==================================================================
def merge_preprocess_dict(
base: Union[PreprocessCfg, Dict],
overlay: Dict,
):
""" Merge overlay key-value pairs on top of base preprocess cfg or dict.
Input dicts are filtered based on PreprocessCfg fields.
"""
if isinstance(base, PreprocessCfg):
base_clean = asdict(base)
else:
base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
if overlay:
overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
base_clean.update(overlay_clean)
return base_clean
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
return merge_preprocess_dict(base, kwargs)
@dataclass
class PreprocessCfg:
size: Union[int, Tuple[int, int]] = 224
mode: str = 'RGB'
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
std: Tuple[float, ...] = OPENAI_DATASET_STD
interpolation: str = 'bicubic'
resize_mode: str = 'shortest'
fill_color: int = 0
def __post_init__(self):
assert self.mode in ('RGB',)
@property
def num_channels(self):
return 3
@property
def input_size(self):
return (self.num_channels,) + to_2tuple(self.size)
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
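# Sketch: overlaying fields onto the default preprocess config; keys outside
# _PREPROCESS_KEYS and None values in the overlay are filtered out.
# >>> cfg = merge_preprocess_dict(PreprocessCfg(), {'size': 336, 'interpolation': 'bilinear', 'bogus': 1})
# >>> cfg['size'], cfg['interpolation']   # -> (336, 'bilinear'); 'bogus' is dropped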
@dataclass
class AugmentationCfg:
scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
re_prob: Optional[float] = None
re_count: Optional[int] = None
use_timm: bool = False
# params for simclr_jitter_gray
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
class ResizeKeepRatio:
""" Resize and Keep Ratio
Copy & paste from `timm`
"""
def __init__(
self,
size,
longest=0.,
interpolation=InterpolationMode.BICUBIC,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
self.size = (size, size)
self.interpolation = interpolation
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
@staticmethod
def get_params(
img,
target_size,
longest,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
"""Get parameters
"""
source_size = img.size[::-1] # h, w
h, w = source_size
target_h, target_w = target_size
ratio_h = h / target_h
ratio_w = w / target_w
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
return size
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
"""
size = self.get_params(
img, self.size, self.longest,
self.random_scale_prob, self.random_scale_range,
self.random_aspect_prob, self.random_aspect_range
)
img = F.resize(img, size, self.interpolation)
return img
    def __repr__(self):
        format_string = self.__class__.__name__ + f'(size={self.size}'
        format_string += f', interpolation={self.interpolation}'
        format_string += f', longest={self.longest:.3f})'
        return format_string
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
"""Center crops and/or pads the given image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
it is used for both directions.
fill (int, Tuple[int]): Padding color
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
_, image_height, image_width = F.get_dimensions(img)
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = F.pad(img, padding_ltrb, fill=fill)
_, image_height, image_width = F.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
class CenterCropOrPad(torch.nn.Module):
"""Crops the given image at the center.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size, fill=0):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.fill = fill
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
return center_crop_or_pad(img, self.size, fill=self.fill)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def _convert_to_rgb(image):
return image.convert('RGB')
class color_jitter(object):
"""
Apply Color Jitter to the PIL image with a specified probability.
"""
def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
assert 0. <= p <= 1.
self.p = p
self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
class gray_scale(object):
"""
Apply Gray Scale to the PIL image with a specified probability.
"""
def __init__(self, p=0.2):
assert 0. <= p <= 1.
self.p = p
self.transf = Grayscale(num_output_channels=3)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
def image_transform(
image_size: Union[int, Tuple[int, int]],
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_mode: Optional[str] = None,
interpolation: Optional[str] = None,
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
interpolation = interpolation or 'bicubic'
assert interpolation in ['bicubic', 'bilinear', 'random']
# NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set
interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
resize_mode = resize_mode or 'shortest'
assert resize_mode in ('shortest', 'longest', 'squash')
if isinstance(aug_cfg, dict):
aug_cfg = AugmentationCfg(**aug_cfg)
else:
aug_cfg = aug_cfg or AugmentationCfg()
normalize = Normalize(mean=mean, std=std)
if is_train:
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
use_timm = aug_cfg_dict.pop('use_timm', False)
if use_timm:
from timm.data import create_transform # timm can still be optional
if isinstance(image_size, (tuple, list)):
assert len(image_size) >= 2
input_size = (3,) + image_size[-2:]
else:
input_size = (3, image_size, image_size)
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
# drop extra non-timm items
aug_cfg_dict.pop('color_jitter_prob', None)
aug_cfg_dict.pop('gray_scale_prob', None)
train_transform = create_transform(
input_size=input_size,
is_training=True,
hflip=0.,
mean=mean,
std=std,
re_mode='pixel',
interpolation=interpolation,
**aug_cfg_dict,
)
else:
train_transform = [
RandomResizedCrop(
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
]
if aug_cfg.color_jitter_prob:
assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
train_transform.extend([
color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
])
if aug_cfg.gray_scale_prob:
train_transform.extend([
gray_scale(aug_cfg.gray_scale_prob)
])
train_transform.extend([
ToTensor(),
normalize,
])
train_transform = Compose(train_transform)
if aug_cfg_dict:
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
return train_transform
else:
if resize_mode == 'longest':
transforms = [
ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
CenterCropOrPad(image_size, fill=fill_color)
]
elif resize_mode == 'squash':
if isinstance(image_size, int):
image_size = (image_size, image_size)
transforms = [
Resize(image_size, interpolation=interpolation_mode),
]
else:
assert resize_mode == 'shortest'
if not isinstance(image_size, (tuple, list)):
image_size = (image_size, image_size)
if image_size[0] == image_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
transforms = [
Resize(image_size[0], interpolation=interpolation_mode)
]
else:
# resize shortest edge to matching target dim for non-square target
transforms = [ResizeKeepRatio(image_size)]
transforms += [CenterCrop(image_size)]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)
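# Sketch: building transforms directly with this helper (sizes/values are illustrative;
# the training code below obtains its transforms via create_model_and_transforms instead):
# >>> val_tf = image_transform(image_size=224, is_train=False, resize_mode='shortest')
# >>> train_tf = image_transform(image_size=224, is_train=True, aug_cfg={'scale': (0.8, 1.0)})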
def image_transform_v2(
cfg: PreprocessCfg,
is_train: bool,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
return image_transform(
image_size=cfg.size,
is_train=is_train,
mean=cfg.mean,
std=cfg.std,
interpolation=cfg.interpolation,
resize_mode=cfg.resize_mode,
fill_color=cfg.fill_color,
aug_cfg=aug_cfg,
)
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
module = getattr(model, 'visual', model)
module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit: bool = True):
def _convert_timm_img(state_dict):
if fastvit:
from timm.models.fastvit import checkpoint_filter_fn
else:
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
return timm_state_dict
def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
text_dict = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = k.replace(prefix, '')
k = k.replace('projection_layer', 'text_projection')
k = k.replace('embedding_layer', 'token_embedding')
if k.startswith('positional_embedding.pos_embed.pos_embed'):
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
v = v.squeeze()
k = k.replace('final_layer_norm', 'ln_final')
k = k.replace('pre_norm_mha.0', 'ln_1')
k = k.replace('pre_norm_mha.1', 'attn')
k = k.replace('pre_norm_ffn.0', 'ln_2')
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
k = k.replace('qkv_proj.weight', 'in_proj_weight')
k = k.replace('qkv_proj.bias', 'in_proj_bias')
k = k.replace('transformer.', 'transformer.resblocks.')
text_dict['text.' + k] = v
return text_dict
image_dict = _convert_timm_img(state_dict)
text_dict = _convert_openclip_txt(state_dict)
out_dict = {**image_dict, **text_dict}
out_dict['logit_scale'] = state_dict['logit_scale']
return out_dict
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
state_dict = convert_mobile_clip_state_dict(model, state_dict)
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
# convert b model
state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
return state_dict
def load_state_dict(
checkpoint_path: str,
device='cpu',
weights_only=True,
):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
from safetensors.torch import load_file
checkpoint = load_file(checkpoint_path, device=device)
else:
try:
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location=device)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, torch.jit.ScriptModule):
state_dict = checkpoint.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
return state_dict
def load_checkpoint(
model: Union[CLIP, CustomTextCLIP],
checkpoint_path: str,
strict: bool = True,
weights_only: bool = True,
device='cpu',
):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
# Separate path loading numpy big_vision (SigLIP) weights
from open_clip.convert import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}
state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
# Detect & convert 3rd party state_dicts -> open_clip
state_dict = convert_state_dict(model, state_dict)
# Detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
# correct if logit_scale differs in being scaler vs 1d param
if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
# correct if logit_bias differs in being scaler vs 1d param
if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
if 'logit_bias' not in state_dict and model.logit_bias is not None:
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
# Certain text transformers no longer expect position_ids after transformers==4.31
position_id_key = 'text.transformer.embeddings.position_ids'
if position_id_key in state_dict and not hasattr(model, position_id_key):
del state_dict[position_id_key]
resize_pos_embed(state_dict, model)
resize_text_pos_embed(state_dict, model)
# Finally, load the massaged state_dict into model
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
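# Sketch: load_checkpoint accepts .safetensors, .pt/.bin and (for SigLIP) .npz weights and
# massages third-party state_dicts into open_clip naming before calling load_state_dict.
# The path below is hypothetical:
# >>> incompatible = load_checkpoint(model, '/path/to/open_clip_pytorch_model.bin', strict=False)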
# /home/IITB/ai-at-ieor/23m1521/.conda/envs/openclip2/lib/python3.11/site-packages/open_clip/factory.py
HF_HUB_PREFIX = 'hf-hub:'
# _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIG_PATHS = [Path("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs")]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
import json
def _get_hf_config(
model_id: str,
cache_dir: Optional[str] = None,
):
""" Fetch model config from HuggingFace Hub.
"""
config_path = download_pretrained_from_hf(
model_id,
filename='open_clip_config.json',
cache_dir=cache_dir,
)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def get_model_config(model_name):
""" Fetch model config from builtin (local library) configs.
"""
if model_name in _MODEL_CONFIGS:
return copy.deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
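# Sketch: the registry is populated from the local model_configs directory set in
# _MODEL_CONFIG_PATHS above, so these lookups work offline:
# >>> 'ViT-H-14' in list_models()         # expected True if ViT-H-14.json exists in that directory
# >>> cfg = get_model_config('ViT-H-14')  # deep-copied dict with embed_dim / vision_cfg / text_cfg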
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StopStringCriteria,
EosTokenCriteria,
StoppingCriteriaList
)
GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
if not isinstance(token_id, torch.Tensor):
if isinstance(token_id, int):
token_id = [token_id]
token_id = torch.tensor(token_id, device=device)
return token_id
def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return decoder
class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)
self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
self.context_length = multimodal_cfg.context_length
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
def _encode_image(self, images, normalize: bool = True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs
def _encode_text(self, text, normalize: bool = True):
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
def encode_image(self, images, normalize: bool = True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent
def encode_text(self, text, normalize: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent
def forward_intermediates(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
image_indices: Optional[Union[int, List[int]]] = None,
text_indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize: bool = True,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
image_output_fmt: str = 'NCHW',
image_output_extra_tokens: bool = False,
text_output_fmt: str = 'NLC',
text_output_extra_tokens: bool = False,
output_logits: bool = False,
output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
image: Input image tensor
text: Input text tensor
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
stop_early: Stop iterating over blocks when last desired intermediate hit
normalize: L2 Normalize final image and text features (if present)
normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
intermediates_only: Only return intermediate features, do not return final features
image_output_fmt: Shape of intermediate image feature outputs
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
text_output_fmt: Shape of intermediate text feature outputs
text_output_extra_tokens: Return both prefix and spatial intermediate tokens
output_logits: Include logits in output
output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict of the requested image/text features and intermediates; keys depend on the inputs and flags above.
        """
output = {}
if intermediates_only:
            # intermediates_only disables final feature normalization and logits output
normalize = False
output_logits = False
if output_logits:
assert False, 'FIXME, needs implementing'
if image is not None:
image_output = self.visual.forward_intermediates(
image,
indices=image_indices,
stop_early=stop_early,
normalize_intermediates=normalize_intermediates,
intermediates_only=intermediates_only,
output_fmt=image_output_fmt,
output_extra_tokens=image_output_extra_tokens,
)
if normalize and "image_features" in image_output:
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
output.update(image_output)
if text is not None:
text_output = self.text.forward_intermediates(
text,
indices=text_indices,
stop_early=stop_early,
normalize_intermediates=normalize_intermediates,
intermediates_only=intermediates_only,
output_fmt=text_output_fmt,
output_extra_tokens=text_output_extra_tokens,
)
if normalize and "text_features" in text_output:
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
output.update(text_output)
# FIXME text decoder
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
if output_logit_scale_bias:
output["logit_scale"] = logit_scale_exp
if self.logit_bias is not None:
output['logit_bias'] = self.logit_bias
return output
def forward(
self,
image,
text: Optional[torch.Tensor] = None,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
output_labels: bool = True,
):
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
if text is None:
return {"image_features": image_latent, "image_embs": image_embs}
text_latent, token_embs = self._encode_text(text)
# FIXME this isn't an ideal solution, would like to improve -RW
labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
if output_labels:
# align text_embs and thus logits with labels for teacher-forcing caption loss
token_embs = token_embs[:, :-1]
logits = self.text_decoder(image_embs, token_embs)
out_dict = {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"logit_scale": self.logit_scale.exp()
}
if labels is not None:
out_dict["labels"] = labels
if self.logit_bias is not None:
out_dict["logit_bias"] = self.logit_bias
return out_dict
def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
device = image.device
with torch.no_grad():
sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
stopping_criteria = StoppingCriteriaList(stopping_criteria)
if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs=image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
pad_len = seq_len - output.shape[1]
return torch.cat((
output,
torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
),
dim=1
)
return output
elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)
image_latent, image_embs = self._encode_image(image)
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)
if num_dims == 1:
text = text[None, :]
self.eval()
out = text
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(
image,
x,
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)
if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
cur_len += 1
if all(stopping_criteria(out, None)):
break
if num_dims == 1:
out = out.squeeze(0)
self.train(was_training)
return out
def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
                # select outputs of beams of current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
break
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
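# Sketch: CoCa is not used by this training script (the vendored open_clip code is simply kept
# complete). A typical generation call, assuming a CoCa model and preprocessed images, would be:
# >>> tokens = coca_model.generate(images, generation_type='beam_search', num_beams=6, num_beam_groups=3)
# >>> tokens = coca_model.generate(images, generation_type='top_k', top_k=5, seq_len=30)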
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
load_weights_only: bool = True,
**model_kwargs,
):
"""Creates and configures a contrastive vision-language model.
Args:
model_name: Name of the model architecture to create. Can be a local model name
or a Hugging Face model ID prefixed with 'hf-hub:'.
pretrained: Tag/path for pretrained model weights. Can be:
- A pretrained tag name (e.g., 'openai')
- A path to local weights
- None to initialize with random weights
precision: Model precision/AMP configuration. Options:
- 'fp32': 32-bit floating point
- 'fp16'/'bf16': Mixed precision with FP32 for certain layers
- 'pure_fp16'/'pure_bf16': Pure 16-bit precision
device: Device to load the model on ('cpu', 'cuda', or torch.device object)
jit: If True, JIT compile the model
force_quick_gelu: Force use of QuickGELU activation
force_custom_text: Force use of custom text encoder
force_patch_dropout: Override default patch dropout value
force_image_size: Override default image size for vision encoder
force_preprocess_cfg: Override default preprocessing configuration
pretrained_image: Load pretrained weights for timm vision models
pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
cache_dir: Override default cache directory for downloaded model files
output_dict: If True and model supports it, return dictionary of features
require_pretrained: Raise error if pretrained weights cannot be loaded
load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
**model_kwargs: Additional keyword arguments passed to model constructor
Returns:
Created and configured model instance
Raises:
RuntimeError: If model config is not found or required pretrained weights
cannot be loaded
Examples:
# Create basic CLIP model
model = create_model('ViT-B/32')
# Create CLIP model with mixed precision on GPU
model = create_model('ViT-B/32', precision='fp16', device='cuda')
# Load pretrained OpenAI weights
model = create_model('ViT-B/32', pretrained='openai')
# Load Hugging Face model
model = create_model('hf-hub:organization/model-name')
"""
force_preprocess_cfg = force_preprocess_cfg or {}
preprocess_cfg = asdict(PreprocessCfg())
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config = _get_hf_config(model_id, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
model_cfg = config['model_cfg']
pretrained_hf = False # override, no need to load original HF text weights
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
model_cfg = None
if isinstance(device, str):
device = torch.device(device)
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_image_size is not None:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if is_timm_model:
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model:
# load pretrained weights for HF text model IFF no CLIP weights being loaded
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
if custom_text:
if "multimodal_cfg" in model_cfg:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
# manual mixed precision that matches original OpenAI behaviour
if is_timm_model:
# FIXME this is a bit janky, create timm based model in low-precision and
# then cast only LayerNormFp32 instances back to float32 so they don't break.
# Why? The convert_weights_to_lp fn only works with native models.
model.to(device=device, dtype=dtype)
# from .transformer import LayerNormFp32
def _convert_ln(m):
if isinstance(m, LayerNormFp32):
m.weight.data = m.weight.data.to(torch.float32)
m.bias.data = m.bias.data.to(torch.float32)
model.apply(_convert_ln)
else:
model.to(device=device)
convert_weights_to_lp(model, dtype=dtype)
elif precision in ("pure_fp16", "pure_bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
model.to(device=device, dtype=dtype)
else:
model.to(device=device)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
model_quick_gelu = model_cfg.get('quick_gelu', False)
if pretrained_quick_gelu and not model_quick_gelu:
warnings.warn(
f'These pretrained weights were trained with QuickGELU activation but the model config does '
f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
elif not pretrained_quick_gelu and model_quick_gelu:
warnings.warn(
f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
f'model config, consider using a model config without QuickGELU or disable override flags.')
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f' Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
# set image preprocessing configuration in model attributes for convenience
if getattr(model.visual, 'image_size', None) is not None:
# use image_size set on model creation (via config or force_image_size arg)
force_preprocess_cfg['size'] = model.visual.image_size
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
load_weights_only: bool = True,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{},
mean=image_mean,
std=image_std,
interpolation=image_interpolation,
resize_mode=image_resize_mode,
)
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
load_weights_only=load_weights_only,
**model_kwargs,
)
pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
preprocess_train = image_transform_v2(
pp_cfg,
is_train=True,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform_v2(
pp_cfg,
is_train=False,
)
return model, preprocess_train, preprocess_val
open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms(
model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device
)
# print("ashish 1")
# exit()
# ==================================================================
# C S I P - M O D U L E
# ==================================================================
class CSIP(nn.Module):
def __init__(self, image_encoder, audio_encoder,
dim_img=None, dim_audio=1024, dim_emb=1024):
super(CSIP, self).__init__()
self.image_encoder = image_encoder # CLIPVisionModel
self.audio_encoder = audio_encoder # CLAP_audio_encoder
for param in self.image_encoder.parameters():
param.requires_grad = False
# self.image_proj = nn.Linear(dim_img, dim_emb)
self.audio_proj = nn.Linear(dim_audio, dim_emb)
# Learnable temperature parameter
self.log_temp = nn.Parameter(torch.tensor(0.07).log())
def forward(self, images, audios):
image_features = self.image_encoder(images) # shape: [n, dim_img]
audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]
# Step 2: Project and normalize
image_embeds = F.normalize(image_features, dim=1) # [n, dim_emb]
audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb]
# Step 3: Cosine similarity with temperature
logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]
probs = logits.softmax(dim=1)
# Step 4: Symmetric cross-entropy loss
labels = torch.arange(len(images), device=images.device)
loss_i = F.cross_entropy(logits, labels)
loss_a = F.cross_entropy(logits.T, labels)
loss = (loss_i + loss_a) / 2
# Step 5: Similarity metric (average cosine similarity on matched pairs)
similarity_scores = (image_embeds * audio_embeds).sum(dim=1) # Cosine similarity of matching pairs
avg_similarity = similarity_scores.mean()
return loss, loss_i, loss_a, logits, probs, avg_similarity
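# Shape sketch for a batch of n image-audio pairs: logits/probs are [n, n] and the diagonal
# holds the matched pairs, so the symmetric cross-entropy above is the standard CLIP-style
# contrastive (InfoNCE) objective. A smoke test on random tensors might look like this
# (csip_model is instantiated further below; assumes the CLAP encoder accepts raw waveform
# tensors of shape [n, 1, 44100]):
# >>> imgs = torch.randn(4, 3, 224, 224, device=device)
# >>> auds = torch.randn(4, 1, 44100, device=device)
# >>> loss, loss_i, loss_a, logits, probs, sim = csip_model(imgs, auds)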
# ==================================================================
# I M A G E - A U D I O - D A T A S E T
# ==================================================================
class VaaniImageAudioDataset(torch.utils.data.Dataset):
def __init__(self, df):
self.image_paths = df.image_path.tolist()
self.audio_paths = df.audio_path.tolist()
def __len__(self):
return len(self.audio_paths)
def __getitem__(self, idx):
return {
'image_path': self.image_paths[idx],
'audio_path': self.audio_paths[idx]
}
def collate_fn(batch):
image_tensor = [open_clip_imgaug(Image.open(item['image_path'])) for item in batch]
image_paths = [item['image_path'] for item in batch]
audio_paths = [item['audio_path'] for item in batch]
audio_tensor = CLAPAudioProcessor(audio_paths, resample=True)
return {
'image_tensor': torch.stack(image_tensor),
'image_paths': image_paths,
'audio_tensor': audio_tensor,
'audio_paths': audio_paths
}
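# Note: open_clip_imgaug is the training transform returned by create_model_and_transforms
# above, and CLAPAudioProcessor is assumed to be defined earlier in this script (CLAP section);
# it is expected to return a batched waveform tensor of shape [B, 1, 44100], as printed below.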
train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv")
test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv")
train_dataset = VaaniImageAudioDataset(train_df)
test_dataset = VaaniImageAudioDataset(test_df)
print('Train Dataset:', len(train_dataset))
print('Test Dataset:', len(test_dataset))
BATCH_SIZE = 64
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=48,
collate_fn=collate_fn,
pin_memory=True,
drop_last=False,
persistent_workers=True
)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=48,
collate_fn=collate_fn,
pin_memory=True,
drop_last=False,
persistent_workers=True
)
batch = next(iter(train_dataloader))
image_tensor_batch = batch['image_tensor'].to(device=device)
audio_tensor_batch = batch['audio_tensor'].to(device=device)
image_paths_batch = batch['image_paths']
audio_paths_batch = batch['audio_paths']
print("Image batch shape:", image_tensor_batch.shape) # [BATCH_SIZE, 3, 224, 224]
print("Audio batch shape:", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]
csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)
# csip_model = nn.DataParallel(CSIP(open_clip_model.visual, peft_clap_audio_encoder), device_ids=[0, 1]).to(device)
from torchinfo import summary
import subprocess
summary(model=csip_model,
input_data=((image_tensor_batch.to(device)), (audio_tensor_batch.to(device))),
# input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
dtypes=[torch.long],
col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
col_width=20,
row_settings=["var_names"],
depth = 2,
# verbose=2,
# device=device
)
# loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
# loss, logits, probs, logits.shape, probs.shape
def train_batch(model, images, audio, optimizer):
model.train()
optimizer.zero_grad()
loss, loss_i, loss_a, logits, probs, avg_similarity = model(images, audio)
loss.backward()
optimizer.step()
return loss.item(), loss_i.item(), loss_a.item(), logits, probs, avg_similarity.item()
@torch.no_grad()
def evaluate_batch(model, images, audio):
model.eval()
loss, loss_i, loss_a, logits, probs, avg_similarity = model(images, audio)
return loss.item(), loss_i.item(), loss_a.item(), logits, probs, avg_similarity.item()
def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
filename = f"csip_best_epoch_{epoch+1}.pt"
path = os.path.join(checkpoint_dir, filename)
torch.save(state, path)
checkpoints = sorted(
[f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
key=lambda x: int(x.split("_")[-1].split(".")[0])
)
while len(checkpoints) > max_checkpoints:
to_delete = checkpoints.pop(0)
os.remove(os.path.join(checkpoint_dir, to_delete))
def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
checkpoints = sorted(
[f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
key=lambda x: int(x.split("_")[-1].split(".")[0])
)
if not checkpoints:
print("No checkpoint found to resume from.")
return 0, float("inf")
best_ckpt = checkpoints[-1]
path = os.path.join(checkpoint_dir, best_ckpt)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state'])
# optimizer.load_state_dict(checkpoint['optimizer_state'])
# scheduler.load_state_dict(checkpoint['scheduler_state'])
start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']
print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
return start_epoch, best_loss
def fig_to_tensor(fig):
"""Convert a Matplotlib figure to a tensor suitable for TensorBoard."""
buf = io.BytesIO()
fig.savefig(buf, format='png')
buf.seek(0)
image = Image.open(buf).convert("RGB")
tensor = torchvision.transforms.functional.to_tensor(image)
buf.close()
plt.close(fig)
return tensor
def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
    os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)
    # --- Raw logits heatmap ---
    logits_np = logits.detach().cpu().numpy()
    fig_logits = plt.figure(figsize=(15, 13))
    sns.heatmap(logits_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    raw_path = os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_logits.savefig(raw_path)
    writer.add_image("Heatmap/RawLogits", fig_to_tensor(fig_logits), global_step=epoch+1)
    # --- Softmax probs heatmap ---
    probs_np = logits.softmax(dim=1).cpu().numpy()
    fig_probs = plt.figure(figsize=(15, 13))
    sns.heatmap(probs_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    prob_path = os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_probs.savefig(prob_path)
    writer.add_image("Heatmap/SoftmaxProbs", fig_to_tensor(fig_probs), global_step=epoch+1)
def train_model(model, train_loader, test_loader,
                optimizer, scheduler, device, log_dir,
                checkpoint_dir, resume=False, epochs=10):
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    csv_path = os.path.join(log_dir, "training_log.csv")
    writer = SummaryWriter(log_dir=log_dir)
    start_epoch = 0
    best_loss = float("inf")
    best_epoch = -1
    if resume:
        start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)
    # If resuming, don't overwrite the existing CSV log
    if not (resume and os.path.exists(csv_path)):
        with open(csv_path, mode='w', newline='') as f:
            writer_csv = csv.writer(f)
            writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss",
                                 "Best Loss", "Train Sim", "Test Sim", "Learning Rate",
                                 "Train I-Loss", "Test I-Loss", "Train A-Loss", "Test A-Loss"])
    for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
        train_losses = []
        test_losses = []
        train_i_losses = []
        test_i_losses = []
        train_a_losses = []
        test_a_losses = []
        train_sim = []
        test_sim = []
        train_loop = tqdm(train_loader, desc=f"[TrainEp {epoch+1}]", colour='blue', dynamic_ncols=True)
        for batch in train_loop:
            images = batch['image_tensor'].to(device)
            audios = batch['audio_tensor'].to(device)
            loss, loss_i, loss_a, logits, probs, avg_similarity = train_batch(model, images, audios, optimizer)
            train_losses.append(loss)
            train_i_losses.append(loss_i)
            train_a_losses.append(loss_a)
            train_sim.append(avg_similarity)
            train_loop.set_postfix(trainLoss=loss)
        test_loop = tqdm(test_loader, desc=f"[TestEp {epoch+1}]", colour='red', dynamic_ncols=True)
        for batch in test_loop:
            images = batch['image_tensor'].to(device)
            audios = batch['audio_tensor'].to(device)
            loss, loss_i, loss_a, logits, probs, avg_similarity = evaluate_batch(model, images, audios)
            test_losses.append(loss)
            test_i_losses.append(loss_i)
            test_a_losses.append(loss_a)
            test_sim.append(avg_similarity)
            test_loop.set_postfix(testLoss=loss)
        avg_train_loss = sum(train_losses) / len(train_losses)
        avg_test_loss = sum(test_losses) / len(test_losses)
        avg_train_i_loss = sum(train_i_losses) / len(train_i_losses)
        avg_test_i_loss = sum(test_i_losses) / len(test_i_losses)
        avg_train_a_loss = sum(train_a_losses) / len(train_a_losses)
        avg_test_a_loss = sum(test_a_losses) / len(test_a_losses)
        avg_train_sim = sum(train_sim) / len(train_sim)
        avg_test_sim = sum(test_sim) / len(test_sim)
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
        writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
        writer.add_scalar("Loss/Train/Image", avg_train_i_loss, epoch + 1)
        writer.add_scalar("Loss/Test/Image", avg_test_i_loss, epoch + 1)
        writer.add_scalar("Loss/Train/Audio", avg_train_a_loss, epoch + 1)
        writer.add_scalar("Loss/Test/Audio", avg_test_a_loss, epoch + 1)
        writer.add_scalar("Similarity/Train", avg_train_sim, epoch + 1)
        writer.add_scalar("Similarity/Test", avg_test_sim, epoch + 1)
        writer.add_scalar("Learning Rate", current_lr, epoch + 1)
        print(f"\n\n | "
              f"Epoch {epoch+1} | Loss: ({avg_train_loss:.4f}, {avg_test_loss:.4f}, {best_loss:.4f}) | "
              f"LR: {current_lr:.2e} | "
              f"Similarity: ({avg_train_sim:.4f}, {avg_test_sim:.4f}) "
              f"| I-Loss: ({avg_train_i_loss:.4f}, {avg_test_i_loss:.4f}) "
              f"| A-Loss: ({avg_train_a_loss:.4f}, {avg_test_a_loss:.4f})"
              )
        if avg_test_loss < best_loss:
            save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
            best_loss = avg_test_loss
            best_epoch = epoch + 1
            save_checkpoint({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'best_loss': best_loss,
                'train_loss': avg_train_loss,
                'similarity': avg_test_sim,
                'train_similarity': avg_train_sim,
                'learning_rate': current_lr,
                'scheduler_state': scheduler.state_dict() if scheduler else None
            }, checkpoint_dir, epoch)
            print(f">>> Saved new best model at epoch {epoch+1}")
        scheduler.step()
        with open(csv_path, mode='a', newline='') as f:
            writer_csv = csv.writer(f)
            writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss,
                                 avg_train_sim, avg_test_sim, current_lr,
                                 avg_train_i_loss, avg_test_i_loss, avg_train_a_loss, avg_test_a_loss])
    writer.close()
    print(f"Training completed. Best epoch: {best_epoch}, Best loss: {best_loss:.4f}, Final test similarity: {avg_test_sim:.4f}")
    print(f"Training log saved to {csv_path}")
model_name = "csip_model_openClip_CLAP"
epochs = 500
learning_rate = 1e-5
optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-10)
# learning_rate = 1e-3
# optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)
# subprocess.run([
# "rm",
# "-rf",
# f"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/{model_name}",
# ])
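# Optional sanity check (a sketch, not part of the original run): preview how
# CosineAnnealingLR decays the learning rate over `epochs`, using a throwaway
# optimizer on a dummy parameter so the real optimizer/scheduler are untouched.
# _probe_opt = torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=learning_rate)
# _probe_sched = torch.optim.lr_scheduler.CosineAnnealingLR(_probe_opt, T_max=epochs, eta_min=1e-10)
# _lrs = []
# for _ in range(epochs):
#     _lrs.append(_probe_opt.param_groups[0]['lr'])
#     _probe_sched.step()
# plt.figure(); plt.plot(_lrs); plt.xlabel("epoch"); plt.ylabel("learning rate")
# plt.title("CosineAnnealingLR preview"); plt.savefig("lr_schedule_preview.png"); plt.close()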
train_model(
    model=csip_model,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    log_dir=f"{model_name}/runs/csip",
    checkpoint_dir=f"{model_name}/checkpoints/csip",
    resume=True,
    epochs=epochs
)
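# ------------------------------------------------------------------
# After training, the best checkpoint can be reloaded for inference or
# feature extraction. A minimal sketch using the state keys saved above
# (left commented so it does not run as part of the training script):
# ------------------------------------------------------------------
# ckpt_dir = f"{model_name}/checkpoints/csip"
# latest = sorted(
#     [f for f in os.listdir(ckpt_dir) if f.startswith("csip_best_epoch_")],
#     key=lambda x: int(x.split("_")[-1].split(".")[0])
# )[-1]
# state = torch.load(os.path.join(ckpt_dir, latest), map_location=device)
# csip_model.load_state_dict(state['model_state'])
# csip_model.eval()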
# tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
# 127.0.0.1:40697
# tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006 --host=0.0.0.0