{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda:0\n", "Author: Ashish\n", "\n", "Last updated: 2025-05-27T01:20:24.023831+05:30\n", "\n", "Python implementation: CPython\n", "Python version : 3.11.11\n", "IPython version : 9.1.0\n", "\n", "conda environment: clap\n", "\n", "Compiler : GCC 11.2.0\n", "OS : Linux\n", "Release : 4.18.0-513.5.1.el8_9.x86_64\n", "Machine : x86_64\n", "Processor : x86_64\n", "CPU cores : 48\n", "Architecture: 64bit\n", "\n", "Hostname: rmgpu013\n", "\n", "watermark : 2.5.0\n", "tqdm : 4.67.1\n", "matplotlib : 3.10.1\n", "torchvision : 0.16.2\n", "huggingface_hub: 0.31.2\n", "transformers : 4.51.3\n", "PIL : 11.1.0\n", "torch : 2.1.2\n", "pandas : 2.2.3\n", "csv : 1.0\n", "sys : 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]\n", "re : 2.2.1\n", "numpy : 1.26.0\n", "peft : 0.15.2\n", "seaborn : 0.13.2\n", "torchlibrosa : 0.1.0\n", "torchaudio : 2.1.2\n", "\n", "GPU Info: \n", " GPU 0: NVIDIA A100 80GB PCIe\n", " GPU 1: NVIDIA A100 80GB PCIe\n", "\n" ] } ], "source": [ "# ==================================================================\n", "# L A T E N T D I F F U S I O N M O D E L\n", "# ==================================================================\n", "# Author : Ashish Kumar Uchadiya\n", "# Created : May 11, 2025\n", "# Description: This script implements the training of a VQ-VAE model for\n", "# image reconstruction, integrated with Latent Diffusion Models (LDMs) and\n", "# audio conditioning. The VQ-VAE maps images to a discrete latent space, \n", "# which is then modeled by the LDM for learning a diffusion process over the \n", "# compressed representation. Audio features are used as conditioning inputs \n", "# to guide the generation process. The training minimizes a combination of \n", "# LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual \n", "# fidelity and PatchGAN loss to enforce local realism. 
This setup enables \n", "# efficient and semantically-aware generation of high-quality images driven \n", "# by audio cues.\n", "# ==================================================================\n", "# I M P O R T S\n", "# ==================================================================\n", "from __future__ import annotations\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "import os\n", "import io\n", "import sys\n", "import math\n", "import random\n", "import collections\n", "import collections.abc\n", "import re\n", "from itertools import repeat\n", "from pathlib import Path\n", "from typing import Optional, Tuple, Union, List, Dict\n", "\n", "import csv\n", "import numpy as np\n", "import pandas as pd\n", "from PIL import Image\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from tqdm import trange, tqdm\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.nn.init import _calculate_fan_in_and_fan_out\n", "import torch.utils.checkpoint as checkpoint\n", "\n", "import torchvision\n", "from torchvision.transforms import v2\n", "from torch.utils.tensorboard import SummaryWriter\n", "# from tensorboardX import SummaryWriter\n", "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "import torchaudio\n", "import torchaudio.transforms as T\n", "from torchlibrosa.stft import Spectrogram, LogmelFilterBank\n", "from torchlibrosa.augmentation import SpecAugmentation\n", "\n", "from transformers import AutoModel, AutoTokenizer, logging\n", "from huggingface_hub.file_download import hf_hub_download\n", "from huggingface_hub.file_download import hf_hub_download\n", "from peft import get_peft_config, get_peft_model\n", "from transformers import CLIPVisionModel, AutoProcessor\n", "\n", "from watermark import watermark\n", "print(watermark(\n", " author='Ashish',\n", " # email='ashish@example.com',\n", " current_date=True,\n", " datename=True,\n", " current_time=True,\n", " iso8601=True,\n", " timezone=True,\n", " updated=True,\n", " custom_time=None,\n", " python=True,\n", " # packages=\"torch,torchvision,numpy\",\n", " conda=True,\n", " hostname=True,\n", " machine=True,\n", " watermark=False,\n", " iversions=True,\n", " gpu=True,\n", " globals_=globals()\n", "))\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 1,289,494 || all params: 34,355,927 || trainable%: 3.7533\n" ] } ], "source": [ "# ==================================================================\n", "# H T S - A T\n", "# ==================================================================\n", "class HTSATConfig:\n", " # Ke Chen\n", " # knutchen@ucsd.edu\n", " # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION\n", " # The configuration for training the model\n", "\n", " exp_name = \"exp_htsat_pretrain\" # the saved ckpt prefix name of the model \n", " workspace = \"/home/kechen/Research/HTSAT\" # the folder of your code\n", " dataset_path = \"/home/Research/audioset\" # the dataset path\n", " desed_folder = \"/home/Research/DESED\" # the desed file\n", "\n", " dataset_type = \"audioset\" # \"audioset\" \"esc-50\" \"scv2\"\n", " index_type = \"full_train\" # only works for audioset\n", " balanced_data = True # only works for audioset\n", "\n", " loss_type = 
\"clip_bce\" # \n", " # AudioSet & SCV2: \"clip_bce\" | ESC-50: \"clip_ce\" \n", "\n", " # trained from a checkpoint, or evaluate a single model \n", " resume_checkpoint = None \n", " # \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\"\n", " \n", " esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation\n", "\n", "\n", " debug = False\n", "\n", " random_seed = 970131 # 19970318 970131 12412 127777 1009 34047\n", " batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128\n", " learning_rate = 1e-3 # 1e-4 also workable \n", " max_epoch = 100\n", " num_workers = 3\n", "\n", " lr_scheduler_epoch = [10,20,30]\n", " lr_rate = [0.02, 0.05, 0.1]\n", "\n", " # these data preparation optimizations do not bring many improvements, so deprecated\n", " enable_token_label = False # token label\n", " class_map_path = \"class_hier_map.npy\"\n", " class_filter = None \n", " retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, \n", " 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]\n", " token_label_range = [0.2,0.6]\n", " enable_time_shift = False # shift time\n", " enable_label_enhance = False # enhance hierarchical label\n", " enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram\n", "\n", "\n", "\n", " # for model's design\n", " enable_tscam = True # enbale the token-semantic layer\n", "\n", " # for signal processing\n", " sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50\n", " clip_samples = sample_rate * 10 # audio_set 10-sec clip\n", " window_size = 1024\n", " hop_size = 320 # 160 for scv2, 320 for audioset and esc-50\n", " mel_bins = 64\n", " fmin = 50\n", " fmax = 14000\n", " shift_max = int(clip_samples * 0.5)\n", "\n", " # for data collection\n", " classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35\n", " patch_size = (25, 4) # deprecated\n", " crop_size = None # int(clip_samples * 0.5) deprecated\n", "\n", " # for htsat hyperparamater\n", " htsat_window_size = 8\n", " htsat_spec_size = 256\n", " htsat_patch_size = 4 \n", " htsat_stride = (4, 4)\n", " htsat_num_head = [4,8,16,32]\n", " htsat_dim = 96 \n", " htsat_depth = [2,2,6,2]\n", "\n", " swin_pretrain_path = None\n", " # \"/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth\"\n", "\n", " # Some Deprecated Optimization in the model design, check the model code for details\n", " htsat_attn_heatmap = False\n", " htsat_hier_output = False \n", " htsat_use_max = False\n", "\n", "\n", " # for ensemble test \n", "\n", " ensemble_checkpoints = []\n", " ensemble_strides = []\n", "\n", "\n", " # weight average folder\n", " wa_folder = \"/home/version_0/checkpoints/\"\n", " # weight average output filename\n", " wa_model_path = \"HTSAT_AudioSet_Saved_x.ckpt\"\n", "\n", " esm_model_pathes = [\n", " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\",\n", " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt\",\n", " 
\"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt\",\n", " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt\",\n", " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt\",\n", " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt\"\n", " ]\n", "\n", " # for framewise localization\n", " heatmap_dir = \"/home/Research/heatmap_output\"\n", " test_file = \"htsat-test-ensemble\"\n", " fl_local = False # indicate if we need to use this dataset for the framewise detection\n", " fl_dataset = \"/home/Research/desed/desedim_embval.npy\" \n", " fl_class_num = [\n", " \"Speech\", \"Frying\", \"Dishes\", \"Running_water\",\n", " \"Blender\", \"Electric_shaver_toothbrush\", \"Alarm_bell_ringing\",\n", " \"Cat\", \"Dog\", \"Vacuum_cleaner\"\n", " ]\n", "\n", " # map 527 classes into 10 classes\n", " fl_audioset_mapping = [\n", " [0,1,2,3,4,5,6,7],\n", " [366, 367, 368],\n", " [364],\n", " [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],\n", " [369],\n", " [382],\n", " [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],\n", " [81, 82, 83, 84, 85],\n", " [74, 75, 76, 77, 78, 79],\n", " [377]\n", " ]\n", "\n", "\n", "\n", "def _ntuple(n):\n", " def parse(x):\n", " if isinstance(x, collections.abc.Iterable):\n", " return x\n", " return tuple(repeat(x, n))\n", " return parse\n", "\n", "to_1tuple = _ntuple(1)\n", "to_2tuple = _ntuple(2)\n", "to_3tuple = _ntuple(3)\n", "to_4tuple = _ntuple(4)\n", "to_ntuple = _ntuple\n", "\n", "def do_mixup(x, mixup_lambda):\n", " \"\"\"Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes \n", " (1, 3, 5, ...).\n", " Args:\n", " x: (batch_size * 2, ...)\n", " mixup_lambda: (batch_size * 2,)\n", " Returns:\n", " out: (batch_size, ...)\n", " \"\"\"\n", " out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \\\n", " x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)\n", " return out\n", "\n", "def interpolate(x, ratio):\n", " \"\"\"Interpolate data in time domain. This is used to compensate the \n", " resolution reduction in downsampling of a CNN.\n", " \n", " Args:\n", " x: (batch_size, time_steps, classes_num)\n", " ratio: int, ratio to interpolate\n", " Returns:\n", " upsampled: (batch_size, time_steps * ratio, classes_num)\n", " \"\"\"\n", " (batch_size, time_steps, classes_num) = x.shape\n", " upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)\n", " upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)\n", " return upsampled\n", "\n", "\n", "def drop_path(x, drop_prob: float = 0., training: bool = False):\n", " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n", " the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n", " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n", " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n", " 'survival rate' as the argument.\n", " \"\"\"\n", " if drop_prob == 0. 
or not training:\n", " return x\n", " keep_prob = 1 - drop_prob\n", " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n", " random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n", " random_tensor.floor_() # binarize\n", " output = x.div(keep_prob) * random_tensor\n", " return output\n", "\n", "\n", "class DropPath(nn.Module):\n", " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", " \"\"\"\n", " def __init__(self, drop_prob=None):\n", " super(DropPath, self).__init__()\n", " self.drop_prob = drop_prob\n", "\n", " def forward(self, x):\n", " return drop_path(x, self.drop_prob, self.training)\n", "\n", "class PatchEmbed(nn.Module):\n", " \"\"\" 2D Image to Patch Embedding\n", " \"\"\"\n", " def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):\n", " super().__init__()\n", " img_size = to_2tuple(img_size)\n", " patch_size = to_2tuple(patch_size)\n", " patch_stride = to_2tuple(patch_stride)\n", " self.img_size = img_size\n", " self.patch_size = patch_size\n", " self.patch_stride = patch_stride\n", " self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])\n", " self.num_patches = self.grid_size[0] * self.grid_size[1]\n", " self.flatten = flatten\n", " self.in_chans = in_chans\n", " self.embed_dim = embed_dim\n", " \n", " padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)\n", "\n", " self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)\n", " self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n", "\n", " def forward(self, x):\n", " B, C, H, W = x.shape\n", " assert H == self.img_size[0] and W == self.img_size[1], \\\n", " f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n", " x = self.proj(x)\n", " if self.flatten:\n", " x = x.flatten(2).transpose(1, 2) # BCHW -> BNC\n", " x = self.norm(x)\n", " return x\n", "\n", "class Mlp(nn.Module):\n", " \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n", " \"\"\"\n", " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n", " super().__init__()\n", " out_features = out_features or in_features\n", " hidden_features = hidden_features or in_features\n", " self.fc1 = nn.Linear(in_features, hidden_features)\n", " self.act = act_layer()\n", " self.fc2 = nn.Linear(hidden_features, out_features)\n", " self.drop = nn.Dropout(drop)\n", "\n", " def forward(self, x):\n", " x = self.fc1(x)\n", " x = self.act(x)\n", " x = self.drop(x)\n", " x = self.fc2(x)\n", " x = self.drop(x)\n", " return x\n", "\n", "def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):\n", " # Cut & paste from PyTorch official master until it's in a few official releases - RW\n", " # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n", " def norm_cdf(x):\n", " # Computes standard normal cumulative distribution function\n", " return (1. + math.erf(x / math.sqrt(2.))) / 2.\n", "\n", " if (mean < a - 2 * std) or (mean > b + 2 * std):\n", " warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
\"\n", " \"The distribution of values may be incorrect.\",\n", " stacklevel=2)\n", "\n", " with torch.no_grad():\n", " # Values are generated by using a truncated uniform distribution and\n", " # then using the inverse CDF for the normal distribution.\n", " # Get upper and lower cdf values\n", " l = norm_cdf((a - mean) / std)\n", " u = norm_cdf((b - mean) / std)\n", "\n", " # Uniformly fill tensor with values from [l, u], then translate to\n", " # [2l-1, 2u-1].\n", " tensor.uniform_(2 * l - 1, 2 * u - 1)\n", "\n", " # Use inverse cdf transform for normal distribution to get truncated\n", " # standard normal\n", " tensor.erfinv_()\n", "\n", " # Transform to proper mean, std\n", " tensor.mul_(std * math.sqrt(2.))\n", " tensor.add_(mean)\n", "\n", " # Clamp to ensure it's in the proper range\n", " tensor.clamp_(min=a, max=b)\n", " return tensor\n", "\n", "\n", "def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n", " # type: (Tensor, float, float, float, float) -> Tensor\n", " r\"\"\"Fills the input Tensor with values drawn from a truncated\n", " normal distribution. The values are effectively drawn from the\n", " normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n", " with values outside :math:`[a, b]` redrawn until they are within\n", " the bounds. The method used for generating the random values works\n", " best when :math:`a \\leq \\text{mean} \\leq b`.\n", " Args:\n", " tensor: an n-dimensional `torch.Tensor`\n", " mean: the mean of the normal distribution\n", " std: the standard deviation of the normal distribution\n", " a: the minimum cutoff value\n", " b: the maximum cutoff value\n", " Examples:\n", " >>> w = torch.empty(3, 5)\n", " >>> nn.init.trunc_normal_(w)\n", " \"\"\"\n", " return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)\n", "\n", "\n", "def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):\n", " fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n", " if mode == 'fan_in':\n", " denom = fan_in\n", " elif mode == 'fan_out':\n", " denom = fan_out\n", " elif mode == 'fan_avg':\n", " denom = (fan_in + fan_out) / 2\n", "\n", " variance = scale / denom\n", "\n", " if distribution == \"truncated_normal\":\n", " # constant is stddev of standard normal truncated to (-2, 2)\n", " trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)\n", " elif distribution == \"normal\":\n", " tensor.normal_(std=math.sqrt(variance))\n", " elif distribution == \"uniform\":\n", " bound = math.sqrt(3 * variance)\n", " tensor.uniform_(-bound, bound)\n", " else:\n", " raise ValueError(f\"invalid distribution {distribution}\")\n", "\n", "\n", "def lecun_normal_(tensor):\n", " variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')\n", "\n", "\n", "# below codes are based and referred from https://github.com/microsoft/Swin-Transformer\n", "# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf\n", "\n", "def window_partition(x, window_size):\n", " \"\"\"\n", " Args:\n", " x: (B, H, W, C)\n", " window_size (int): window size\n", " Returns:\n", " windows: (num_windows*B, window_size, window_size, C)\n", " \"\"\"\n", " B, H, W, C = x.shape\n", " x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n", " windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n", " return windows\n", "\n", "\n", "def window_reverse(windows, window_size, H, W):\n", " \"\"\"\n", " Args:\n", " windows: (num_windows*B, window_size, window_size, C)\n", " 
window_size (int): Window size\n", " H (int): Height of image\n", " W (int): Width of image\n", " Returns:\n", " x: (B, H, W, C)\n", " \"\"\"\n", " B = int(windows.shape[0] / (H * W / window_size / window_size))\n", " x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n", " x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n", " return x\n", "\n", "\n", "class WindowAttention(nn.Module):\n", " r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n", " It supports both of shifted and non-shifted window.\n", " Args:\n", " dim (int): Number of input channels.\n", " window_size (tuple[int]): The height and width of the window.\n", " num_heads (int): Number of attention heads.\n", " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n", " attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n", " proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n", " \"\"\"\n", "\n", " def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n", "\n", " super().__init__()\n", " self.dim = dim\n", " self.window_size = window_size # Wh, Ww\n", " self.num_heads = num_heads\n", " head_dim = dim // num_heads\n", " self.scale = qk_scale or head_dim ** -0.5\n", "\n", " # define a parameter table of relative position bias\n", " self.relative_position_bias_table = nn.Parameter(\n", " torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH\n", "\n", " # get pair-wise relative position index for each token inside the window\n", " coords_h = torch.arange(self.window_size[0])\n", " coords_w = torch.arange(self.window_size[1])\n", " coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww\n", " coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww\n", " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww\n", " relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2\n", " relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0\n", " relative_coords[:, :, 1] += self.window_size[1] - 1\n", " relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n", " relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww\n", " self.register_buffer(\"relative_position_index\", relative_position_index)\n", "\n", " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n", " self.attn_drop = nn.Dropout(attn_drop)\n", " self.proj = nn.Linear(dim, dim)\n", " self.proj_drop = nn.Dropout(proj_drop)\n", "\n", " trunc_normal_(self.relative_position_bias_table, std=.02)\n", " self.softmax = nn.Softmax(dim=-1)\n", "\n", " def forward(self, x, mask=None):\n", " \"\"\"\n", " Args:\n", " x: input features with shape of (num_windows*B, N, C)\n", " mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n", " \"\"\"\n", " B_, N, C = x.shape\n", " qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n", " q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)\n", "\n", " q = q * self.scale\n", " attn = (q @ k.transpose(-2, -1))\n", "\n", " relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n", " self.window_size[0] * self.window_size[1], self.window_size[0] * 
self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH\n", " relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww\n", " attn = attn + relative_position_bias.unsqueeze(0)\n", "\n", " if mask is not None:\n", " nW = mask.shape[0]\n", " attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n", " attn = attn.view(-1, self.num_heads, N, N)\n", " attn = self.softmax(attn)\n", " else:\n", " attn = self.softmax(attn)\n", "\n", " attn = self.attn_drop(attn)\n", "\n", " x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n", " x = self.proj(x)\n", " x = self.proj_drop(x)\n", " return x, attn\n", "\n", " def extra_repr(self):\n", " return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n", "\n", "\n", "# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model\n", "class SwinTransformerBlock(nn.Module):\n", " r\"\"\" Swin Transformer Block.\n", " Args:\n", " dim (int): Number of input channels.\n", " input_resolution (tuple[int]): Input resulotion.\n", " num_heads (int): Number of attention heads.\n", " window_size (int): Window size.\n", " shift_size (int): Shift size for SW-MSA.\n", " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", " drop (float, optional): Dropout rate. Default: 0.0\n", " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", " drop_path (float, optional): Stochastic depth rate. Default: 0.0\n", " act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n", " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", " \"\"\"\n", "\n", " def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n", " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n", " act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):\n", " super().__init__()\n", " self.dim = dim\n", " self.input_resolution = input_resolution\n", " self.num_heads = num_heads\n", " self.window_size = window_size\n", " self.shift_size = shift_size\n", " self.mlp_ratio = mlp_ratio\n", " self.norm_before_mlp = norm_before_mlp\n", " if min(self.input_resolution) <= self.window_size:\n", " # if window size is larger than input resolution, we don't partition windows\n", " self.shift_size = 0\n", " self.window_size = min(self.input_resolution)\n", " assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n", "\n", " self.norm1 = norm_layer(dim)\n", " self.attn = WindowAttention(\n", " dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n", " qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n", "\n", " self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n", " if self.norm_before_mlp == 'ln':\n", " self.norm2 = nn.LayerNorm(dim)\n", " elif self.norm_before_mlp == 'bn':\n", " self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)\n", " else:\n", " raise NotImplementedError\n", " mlp_hidden_dim = int(dim * mlp_ratio)\n", " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n", "\n", " if self.shift_size > 0:\n", " # calculate attention mask for SW-MSA\n", " H, W = self.input_resolution\n", " img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1\n", " h_slices = (slice(0, -self.window_size),\n", " slice(-self.window_size, -self.shift_size),\n", " slice(-self.shift_size, None))\n", " w_slices = (slice(0, -self.window_size),\n", " slice(-self.window_size, -self.shift_size),\n", " slice(-self.shift_size, None))\n", " cnt = 0\n", " for h in h_slices:\n", " for w in w_slices:\n", " img_mask[:, h, w, :] = cnt\n", " cnt += 1\n", "\n", " mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1\n", " mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n", " attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n", " attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n", " else:\n", " attn_mask = None\n", "\n", " self.register_buffer(\"attn_mask\", attn_mask)\n", "\n", " def forward(self, x):\n", " # pdb.set_trace()\n", " H, W = self.input_resolution\n", " # print(\"H: \", H)\n", " # print(\"W: \", W)\n", " # pdb.set_trace()\n", " B, L, C = x.shape\n", " # assert L == H * W, \"input feature has wrong size\"\n", "\n", " shortcut = x\n", " x = self.norm1(x)\n", " x = x.view(B, H, W, C)\n", "\n", " # cyclic shift\n", " if self.shift_size > 0:\n", " shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n", " else:\n", " shifted_x = x\n", "\n", " # partition windows\n", " x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C\n", " x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C\n", "\n", " # W-MSA/SW-MSA\n", " attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C\n", "\n", " # merge windows\n", " attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n", " shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C\n", "\n", " # reverse cyclic shift\n", " if self.shift_size > 0:\n", " x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n", " else:\n", " x = shifted_x\n", " x = x.view(B, H * W, C)\n", "\n", " # FFN\n", " x = shortcut + self.drop_path(x)\n", " x = x + self.drop_path(self.mlp(self.norm2(x)))\n", "\n", " return x, attn\n", "\n", " def extra_repr(self):\n", " return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n", " f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n", "\n", "\n", "\n", "class PatchMerging(nn.Module):\n", " r\"\"\" Patch Merging Layer.\n", " Args:\n", " input_resolution (tuple[int]): Resolution of input feature.\n", " dim (int): Number of input channels.\n", " norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n", " \"\"\"\n", "\n", " def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n", " super().__init__()\n", " self.input_resolution = input_resolution\n", " self.dim = dim\n", " self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n", " self.norm = norm_layer(4 * dim)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: B, H*W, C\n", " \"\"\"\n", " H, W = self.input_resolution\n", " B, L, C = x.shape\n", " assert L == H * W, \"input feature has wrong size\"\n", " assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n", "\n", " x = x.view(B, H, W, C)\n", "\n", " x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C\n", " x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C\n", " x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C\n", " x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C\n", " x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C\n", " x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C\n", "\n", " x = self.norm(x)\n", " x = self.reduction(x)\n", "\n", " return x\n", "\n", " def extra_repr(self):\n", " return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n", "\n", "\n", "class BasicLayer(nn.Module):\n", " \"\"\" A basic Swin Transformer layer for one stage.\n", " Args:\n", " dim (int): Number of input channels.\n", " input_resolution (tuple[int]): Input resolution.\n", " depth (int): Number of blocks.\n", " num_heads (int): Number of attention heads.\n", " window_size (int): Local window size.\n", " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", " drop (float, optional): Dropout rate. Default: 0.0\n", " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", " drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n", " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", " downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n", " use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n", " \"\"\"\n", "\n", " def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n", " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n", " drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n", " norm_before_mlp='ln'):\n", "\n", " super().__init__()\n", " self.dim = dim\n", " self.input_resolution = input_resolution\n", " self.depth = depth\n", " self.use_checkpoint = use_checkpoint\n", "\n", " # build blocks\n", " self.blocks = nn.ModuleList([\n", " SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n", " num_heads=num_heads, window_size=window_size,\n", " shift_size=0 if (i % 2 == 0) else window_size // 2,\n", " mlp_ratio=mlp_ratio,\n", " qkv_bias=qkv_bias, qk_scale=qk_scale,\n", " drop=drop, attn_drop=attn_drop,\n", " drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n", " norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)\n", " for i in range(depth)])\n", "\n", " # patch merging layer\n", " if downsample is not None:\n", " self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n", " else:\n", " self.downsample = None\n", "\n", " def forward(self, x):\n", " attns = []\n", " for blk in self.blocks:\n", " if self.use_checkpoint:\n", " x = checkpoint.checkpoint(blk, x)\n", " else:\n", " x, attn = blk(x)\n", " if not self.training:\n", " attns.append(attn.unsqueeze(0))\n", " if self.downsample is not None:\n", " x = self.downsample(x)\n", " if not self.training:\n", " attn = torch.cat(attns, dim = 0)\n", " attn = torch.mean(attn, dim = 0)\n", " return x, attn\n", "\n", " def extra_repr(self):\n", " return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n", "\n", "\n", "# The Core of HTSAT\n", "class HTSAT_Swin_Transformer(nn.Module):\n", " r\"\"\"HTSAT based on the Swin Transformer\n", " Args:\n", " spec_size (int | tuple(int)): Input Spectrogram size. Default 256\n", " patch_size (int | tuple(int)): Patch size. Default: 4\n", " path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4\n", " in_chans (int): Number of input image channels. Default: 1 (mono)\n", " num_classes (int): Number of classes for classification head. Default: 527\n", " embed_dim (int): Patch embedding dimension. Default: 96\n", " depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.\n", " num_heads (tuple(int)): Number of attention heads in different layers.\n", " window_size (int): Window size. Default: 8\n", " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n", " qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n", " qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n", " drop_rate (float): Dropout rate. Default: 0\n", " attn_drop_rate (float): Attention dropout rate. Default: 0\n", " drop_path_rate (float): Stochastic depth rate. Default: 0.1\n", " norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n", " ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n", " patch_norm (bool): If True, add normalization after patch embedding. Default: True\n", " use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n", " config (module): The configuration Module from config.py (HTSATConfig Class)\n", " \"\"\"\n", "\n", " def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), \n", " in_chans=1, num_classes=527,\n", " embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],\n", " window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n", " drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n", " norm_layer=nn.LayerNorm, \n", " ape=False, patch_norm=True,\n", " use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):\n", " super(HTSAT_Swin_Transformer, self).__init__()\n", "\n", " self.config = config\n", " self.spec_size = spec_size \n", " self.patch_stride = patch_stride\n", " self.patch_size = patch_size\n", " self.window_size = window_size\n", " self.embed_dim = embed_dim\n", " self.depths = depths\n", " self.ape = ape\n", " self.in_chans = in_chans\n", " self.num_classes = num_classes\n", " self.num_heads = num_heads\n", " self.num_layers = len(self.depths)\n", " self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))\n", " \n", " self.drop_rate = drop_rate\n", " self.attn_drop_rate = attn_drop_rate\n", " self.drop_path_rate = drop_path_rate\n", "\n", " self.qkv_bias = qkv_bias\n", " self.qk_scale = None\n", "\n", " self.patch_norm = patch_norm\n", " self.norm_layer = norm_layer if self.patch_norm else None\n", " self.norm_before_mlp = norm_before_mlp\n", " self.mlp_ratio = mlp_ratio\n", "\n", " self.use_checkpoint = use_checkpoint\n", "\n", " # process mel-spec ; used only once\n", " self.freq_ratio = self.spec_size // self.config.mel_bins\n", " window = 'hann'\n", " center = True\n", " pad_mode = 'reflect'\n", " ref = 1.0\n", " amin = 1e-10\n", " top_db = None\n", " self.interpolate_ratio = 32 # Downsampled ratio\n", " # Spectrogram extractor\n", " self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, \n", " win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, \n", " freeze_parameters=True)\n", " # Logmel feature extractor\n", " self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, \n", " n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, \n", " freeze_parameters=True)\n", " # Spec augmenter\n", " self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, \n", " freq_drop_width=8, freq_stripes_num=2) # 2 2\n", " self.bn0 = nn.BatchNorm2d(self.config.mel_bins)\n", "\n", "\n", " # split spctrogram into non-overlapping patches\n", " self.patch_embed = PatchEmbed(\n", " img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, \n", " embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)\n", "\n", " num_patches = self.patch_embed.num_patches\n", " patches_resolution = self.patch_embed.grid_size\n", " self.patches_resolution = patches_resolution\n", "\n", " # absolute position embedding\n", " if self.ape:\n", " self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))\n", " trunc_normal_(self.absolute_pos_embed, std=.02)\n", "\n", " self.pos_drop = nn.Dropout(p=self.drop_rate)\n", "\n", " # stochastic depth\n", " dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule\n", "\n", " # build layers\n", " self.layers = nn.ModuleList()\n", " for i_layer in range(self.num_layers):\n", " layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),\n", " 
input_resolution=(patches_resolution[0] // (2 ** i_layer),\n", " patches_resolution[1] // (2 ** i_layer)),\n", " depth=self.depths[i_layer],\n", " num_heads=self.num_heads[i_layer],\n", " window_size=self.window_size,\n", " mlp_ratio=self.mlp_ratio,\n", " qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,\n", " drop=self.drop_rate, attn_drop=self.attn_drop_rate,\n", " drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],\n", " norm_layer=self.norm_layer,\n", " downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n", " use_checkpoint=use_checkpoint,\n", " norm_before_mlp=self.norm_before_mlp)\n", " self.layers.append(layer)\n", "\n", " self.norm = self.norm_layer(self.num_features)\n", " self.avgpool = nn.AdaptiveAvgPool1d(1)\n", " self.maxpool = nn.AdaptiveMaxPool1d(1)\n", "\n", " if self.config.enable_tscam:\n", " SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio\n", " self.tscam_conv = nn.Conv2d(\n", " in_channels = self.num_features,\n", " out_channels = self.num_classes,\n", " kernel_size = (SF,3),\n", " padding = (0,1)\n", " )\n", " self.head = nn.Linear(num_classes, num_classes)\n", " else:\n", " self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n", "\n", " self.apply(self._init_weights)\n", "\n", " def _init_weights(self, m):\n", " if isinstance(m, nn.Linear):\n", " trunc_normal_(m.weight, std=.02)\n", " if isinstance(m, nn.Linear) and m.bias is not None:\n", " nn.init.constant_(m.bias, 0)\n", " elif isinstance(m, nn.LayerNorm):\n", " nn.init.constant_(m.bias, 0)\n", " nn.init.constant_(m.weight, 1.0)\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay(self):\n", " return {'absolute_pos_embed'}\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay_keywords(self):\n", " return {'relative_position_bias_table'}\n", "\n", " def forward_features(self, x):\n", " frames_num = x.shape[2] \n", " x = self.patch_embed(x)\n", " if self.ape:\n", " x = x + self.absolute_pos_embed\n", " x = self.pos_drop(x)\n", " for i, layer in enumerate(self.layers):\n", " x, attn = layer(x)\n", "\n", " if self.config.enable_tscam:\n", " # for x\n", " x = self.norm(x)\n", " B, N, C = x.shape\n", " SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]\n", " ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]\n", " x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)\n", " B, C, F, T = x.shape\n", " # group 2D CNN\n", " c_freq_bin = F // self.freq_ratio\n", " x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n", " x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n", "\n", " # get latent_output\n", " latent_output = self.avgpool(torch.flatten(x,2))\n", " latent_output = torch.flatten(latent_output, 1)\n", "\n", " # display the attention map, if needed\n", " if self.config.htsat_attn_heatmap:\n", " # for attn\n", " attn = torch.mean(attn, dim = 1)\n", " attn = torch.mean(attn, dim = 1)\n", " attn = attn.reshape(B, SF, ST)\n", " c_freq_bin = SF // self.freq_ratio\n", " attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) \n", " attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)\n", " attn = attn.mean(dim = 1)\n", " attn_max = torch.max(attn, dim = 1, keepdim = True)[0]\n", " attn_min = torch.min(attn, dim = 1, keepdim = True)[0]\n", " attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)\n", " attn = attn.unsqueeze(dim = 2)\n", "\n", " x = self.tscam_conv(x)\n", " x = 
torch.flatten(x, 2) # B, C, T\n", "\n", " if self.config.htsat_attn_heatmap:\n", " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) \n", " else: \n", " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n", " \n", " x = self.avgpool(x)\n", " x = torch.flatten(x, 1)\n", "\n", " if self.config.loss_type == \"clip_ce\":\n", " output_dict = {\n", " 'framewise_output': fpx, # already sigmoided\n", " 'clipwise_output': x,\n", " 'latent_output': latent_output\n", " }\n", " else:\n", " output_dict = {\n", " 'framewise_output': fpx, # already sigmoided\n", " 'clipwise_output': torch.sigmoid(x),\n", " 'latent_output': latent_output\n", " }\n", " \n", " else:\n", " x = self.norm(x) # B N C\n", " B, N, C = x.shape\n", " \n", " fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )\n", " B, C, F, T = fpx.shape\n", " c_freq_bin = F // self.freq_ratio\n", " fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n", " fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n", " fpx = torch.sum(fpx, dim = 2)\n", " fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n", " x = self.avgpool(x.transpose(1, 2)) # B C 1\n", " x = torch.flatten(x, 1)\n", " if self.num_classes > 0:\n", " x = self.head(x)\n", " fpx = self.head(fpx)\n", " output_dict = {'framewise_output': torch.sigmoid(fpx), \n", " 'clipwise_output': torch.sigmoid(x)}\n", " return output_dict\n", "\n", " def crop_wav(self, x, crop_size, spe_pos = None):\n", " time_steps = x.shape[2]\n", " tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)\n", " for i in range(len(x)):\n", " if spe_pos is None:\n", " crop_pos = random.randint(0, time_steps - crop_size - 1)\n", " else:\n", " crop_pos = spe_pos\n", " tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]\n", " return tx\n", "\n", " # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model\n", " def reshape_wav2img(self, x):\n", " B, C, T, F = x.shape\n", " target_T = int(self.spec_size * self.freq_ratio)\n", " target_F = self.spec_size // self.freq_ratio\n", " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n", " # to avoid bicubic zero error\n", " if T < target_T:\n", " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n", " if F < target_F:\n", " x = nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True)\n", " x = x.permute(0,1,3,2).contiguous()\n", " x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)\n", " # print(x.shape)\n", " x = x.permute(0,1,3,2,4).contiguous()\n", " x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])\n", " return x\n", " \n", " # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model\n", " def repeat_wat2img(self, x, cur_pos):\n", " B, C, T, F = x.shape\n", " target_T = int(self.spec_size * self.freq_ratio)\n", " target_F = self.spec_size // self.freq_ratio\n", " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n", " # to avoid bicubic zero error\n", " if T < target_T:\n", " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n", " if F < target_F:\n", " x = 
nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True) \n", " x = x.permute(0,1,3,2).contiguous() # B C F T\n", " x = x[:,:,:,cur_pos:cur_pos + self.spec_size]\n", " x = x.repeat(repeats = (1,1,4,1))\n", " return x\n", "\n", " def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):\n", " x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)\n", " x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)\n", " \n", " \n", " x = x.transpose(1, 3)\n", " x = self.bn0(x)\n", " x = x.transpose(1, 3)\n", " if self.training:\n", " x = self.spec_augmenter(x)\n", " if self.training and mixup_lambda is not None:\n", " x = do_mixup(x, mixup_lambda)\n", " \n", " if infer_mode:\n", " # in infer mode. we need to handle different length audio input\n", " frame_num = x.shape[2]\n", " target_T = int(self.spec_size * self.freq_ratio)\n", " repeat_ratio = math.floor(target_T / frame_num)\n", " x = x.repeat(repeats=(1,1,repeat_ratio,1))\n", " x = self.reshape_wav2img(x)\n", " output_dict = self.forward_features(x)\n", " elif self.config.enable_repeat_mode:\n", " if self.training:\n", " cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)\n", " x = self.repeat_wat2img(x, cur_pos)\n", " output_dict = self.forward_features(x)\n", " else:\n", " output_dicts = []\n", " for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):\n", " tx = x.clone()\n", " tx = self.repeat_wat2img(tx, cur_pos)\n", " output_dicts.append(self.forward_features(tx))\n", " clipwise_output = torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n", " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n", " for d in output_dicts:\n", " clipwise_output += d[\"clipwise_output\"]\n", " framewise_output += d[\"framewise_output\"]\n", " clipwise_output = clipwise_output / len(output_dicts)\n", " framewise_output = framewise_output / len(output_dicts)\n", "\n", " output_dict = {\n", " 'framewise_output': framewise_output, \n", " 'clipwise_output': clipwise_output\n", " }\n", " else:\n", " if x.shape[2] > self.freq_ratio * self.spec_size:\n", " if self.training:\n", " x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)\n", " x = self.reshape_wav2img(x)\n", " output_dict = self.forward_features(x)\n", " else:\n", " # Change: Hard code here\n", " overlap_size = 344 #(x.shape[2] - 1) // 4\n", " output_dicts = []\n", " crop_size = 689 #(x.shape[2] - 1) // 2\n", " for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):\n", " tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)\n", " tx = self.reshape_wav2img(tx)\n", " output_dicts.append(self.forward_features(tx))\n", " clipwise_output = torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n", " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n", " latent_output = torch.zeros_like(output_dicts[0][\"latent_output\"]).float().to(x.device)\n", " for d in output_dicts:\n", " clipwise_output += d[\"clipwise_output\"]\n", " framewise_output += d[\"framewise_output\"]\n", " latent_output += d[\"latent_output\"]\n", " clipwise_output = clipwise_output / len(output_dicts)\n", " framewise_output = framewise_output / len(output_dicts)\n", " latent_output = latent_output / len(output_dicts)\n", " output_dict = {\n", " 'framewise_output': framewise_output, \n", " 
'clipwise_output': clipwise_output,\n", " 'latent_output': latent_output,\n", " }\n", " else: # this part is typically used, and most easy one\n", " x = self.reshape_wav2img(x)\n", " output_dict = self.forward_features(x)\n", " # x = self.head(x)\n", " return output_dict\n", "\n", "class HTSATWrapper(nn.Module):\n", " def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, \n", " fmax, classes_num, out_emb):\n", " super().__init__()\n", "\n", " # print(\"parameters are being overidden when using HTSAT\")\n", " # print(\"HTSAT only support loading a pretrained model on AudioSet\")\n", " # @TODO later look at what parameters are same and can be merged\n", "\n", " self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())\n", "\n", " def forward(self, x):\n", " out_dict = self.htsat(x)\n", " out_dict['embedding'] = out_dict['latent_output']\n", " return out_dict\n", "\n", "\n", "def get_audio_encoder(name: str):\n", " if name == \"HTSAT\":\n", " return HTSATWrapper\n", " else:\n", " raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))\n", "\n", "class Projection(nn.Module):\n", " def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:\n", " super().__init__()\n", " self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)\n", " self.linear2 = nn.Linear(d_out, d_out, bias=False)\n", " self.layer_norm = nn.LayerNorm(d_out)\n", " self.drop = nn.Dropout(p)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " embed1 = self.linear1(x)\n", " embed2 = self.drop(self.linear2(F.gelu(embed1)))\n", " embeds = self.layer_norm(embed1 + embed2)\n", " return embeds\n", "\n", "class AudioEncoder(nn.Module):\n", " def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,\n", " hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:\n", " super().__init__()\n", "\n", " audio_encoder = get_audio_encoder(audioenc_name)\n", "\n", " self.base = audio_encoder(\n", " sample_rate, window_size,\n", " hop_size, mel_bins, fmin, fmax,\n", " classes_num, dim_imgn)\n", "\n", " self.projection = Projection(dim_imgn, d_out)\n", "\n", " def forward(self, x):\n", " out_dict = self.base(x)\n", " audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']\n", " projected_vec = self.projection(audio_features)\n", " return projected_vec, audio_classification_output\n", "\n", "class TextEncoder(nn.Module):\n", " def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:\n", " super().__init__()\n", " self.text_model = text_model\n", " self.base = AutoModel.from_pretrained(text_model)\n", "\n", " if 'clip' in text_model:\n", " self.clip_text_projection = self.base.text_projection\n", " self.base = self.base.text_model\n", " if 'base' in text_model:\n", " transformer_embed_dim = 512\n", " \n", " self.projection = Projection(transformer_embed_dim, d_out)\n", "\n", " def forward(self, x):\n", " if 'clip' in self.text_model:\n", " pooled_output = self.base(**x)[1] # get pooled output\n", " out = self.clip_text_projection(pooled_output) # get CLS token output\n", " elif 'gpt' in self.text_model:\n", " batch_size = x['input_ids'].shape[0]\n", " hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)\n", "\n", " sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])\n", " out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]\n", " 
else:\n", " out = self.base(**x)[0]\n", " out = out[:, 0, :] # get CLS token output\n", " \n", " projected_vec = self.projection(out)\n", "\n", " return projected_vec\n", "\n", "class CLAP(nn.Module):\n", " def __init__(self,\n", " # audio\n", " audioenc_name: str,\n", " sample_rate: int, \n", " window_size: int, \n", " hop_size: int, \n", " mel_bins: int, \n", " fmin: int, \n", " fmax: int, \n", " classes_num: int, \n", " out_emb: int,\n", " # text\n", " text_model: str,\n", " transformer_embed_dim: int,\n", " # common\n", " d_proj: int,\n", " ):\n", " super().__init__()\n", "\n", " \n", " self.audio_encoder = AudioEncoder(\n", " audioenc_name, out_emb, d_proj,\n", " sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)\n", "\n", " self.caption_encoder = TextEncoder(\n", " d_proj, text_model, transformer_embed_dim\n", " )\n", "\n", " self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n", "\n", " def forward(self, audio, text):\n", " audio_embed, _ = self.audio_encoder(audio)\n", " caption_embed = self.caption_encoder(text)\n", "\n", " return caption_embed, audio_embed, self.logit_scale.exp()\n", " \n", " \n", " \n", "# ==================================================================\n", "# A U D I O - P R E - P R O C E S S I N G\n", "# ==================================================================\n", "def read_audio(audio_path, resample=True):\n", " r\"\"\"Loads audio file or array and returns a torch tensor\"\"\"\n", " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n", " audio_time_series, sample_rate = torchaudio.load(audio_path)\n", "\n", " resample_rate = clapConfig.sample_rate\n", " if resample and resample_rate != sample_rate:\n", " resampler = T.Resample(sample_rate, resample_rate)\n", " audio_time_series = resampler(audio_time_series)\n", " return audio_time_series, resample_rate\n", "\n", "def load_audio_into_tensor(audio_path, audio_duration, resample=False):\n", " r\"\"\"Loads audio file and returns raw audio.\"\"\"\n", " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n", " audio_time_series, sample_rate = read_audio(audio_path, resample)\n", " audio_time_series = audio_time_series.reshape(-1)\n", "\n", " # audio_time_series is shorter than predefined audio duration,\n", " # so audio_time_series is extended\n", " if audio_duration*sample_rate >= audio_time_series.shape[0]:\n", " repeat_factor = int(np.ceil((audio_duration*sample_rate) /\n", " audio_time_series.shape[0]))\n", " # Repeat audio_time_series by repeat_factor to match audio_duration\n", " audio_time_series = audio_time_series.repeat(repeat_factor)\n", " # remove excess part of audio_time_series\n", " audio_time_series = audio_time_series[0:audio_duration*sample_rate]\n", " else:\n", " # audio_time_series is longer than predefined audio duration,\n", " # so audio_time_series is trimmed\n", " start_index = random.randrange(\n", " audio_time_series.shape[0] - audio_duration*sample_rate)\n", " audio_time_series = audio_time_series[start_index:start_index +\n", " audio_duration*sample_rate]\n", " return torch.FloatTensor(audio_time_series)\n", "\n", "np_str_obj_array_pattern = re.compile(r'[SaUO]')\n", "default_collate_err_msg_format = (\n", " \"default_collate: batch must contain tensors, numpy arrays, numbers, \"\n", " \"dicts or lists; found {}\")\n", "\n", "def default_collate(batch):\n", " r\"\"\"Puts each data field into a tensor with outer dimension batch size\"\"\"\n", " elem = batch[0]\n", " 
elem_type = type(elem)\n", " if isinstance(elem, torch.Tensor):\n", " out = None\n", " if torch.utils.data.get_worker_info() is not None:\n", " # If we're in a background process, concatenate directly into a\n", " # shared memory tensor to avoid an extra copy\n", " numel = sum([x.numel() for x in batch])\n", " storage = elem.storage()._new_shared(numel)\n", " out = elem.new(storage)\n", " return torch.stack(batch, 0, out=out)\n", " elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \\\n", " and elem_type.__name__ != 'string_':\n", " if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':\n", " # array of string classes and object\n", " if np_str_obj_array_pattern.search(elem.dtype.str) is not None:\n", " raise TypeError(\n", " default_collate_err_msg_format.format(elem.dtype))\n", "\n", " return default_collate([torch.as_tensor(b) for b in batch])\n", " elif elem.shape == (): # scalars\n", " return torch.as_tensor(batch)\n", " elif isinstance(elem, float):\n", " return torch.tensor(batch, dtype=torch.float64)\n", " elif isinstance(elem, int):\n", " return torch.tensor(batch)\n", " elif isinstance(elem, str):\n", " return batch\n", " elif isinstance(elem, collections.abc.Mapping):\n", " return {key: default_collate([d[key] for d in batch]) for key in elem}\n", " elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple\n", " return elem_type(*(default_collate(samples) for samples in zip(*batch)))\n", " elif isinstance(elem, collections.abc.Sequence):\n", " # check to make sure that the elements in batch have consistent size\n", " it = iter(batch)\n", " elem_size = len(next(it))\n", " if not all(len(elem) == elem_size for elem in it):\n", " raise RuntimeError(\n", " 'each element in list of batch should be of equal size')\n", " transposed = zip(*batch)\n", " return [default_collate(samples) for samples in transposed]\n", "\n", " raise TypeError(default_collate_err_msg_format.format(elem_type))\n", "\n", "def preprocess_audio(audio_files, resample):\n", " r\"\"\"Load list of audio files and return raw audio\"\"\"\n", " audio_tensors = []\n", " for audio_file in audio_files:\n", " audio_tensor = load_audio_into_tensor(\n", " audio_file, clapConfig.duration, resample)\n", " audio_tensor = audio_tensor.reshape(1, -1)\n", " audio_tensors.append(audio_tensor)\n", " return default_collate(audio_tensors)\n", "\n", "\n", "\n", "# ==================================================================\n", "# A U D I O - E M B E D D I N G S - H E L P E R\n", "# ==================================================================\n", "def CLAPAudioProcessor(audio_files: List[str], resample=True):\n", " preprocessed_audio = preprocess_audio(audio_files, resample)\n", " preprocessed_audio = preprocessed_audio.reshape(\n", " preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n", " preprocessed_audio = preprocessed_audio\n", " return preprocessed_audio\n", "\n", "def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):\n", " \"\"\"Load list of audio files and return audio embeddings\"\"\"\n", " # preprocessed_audio = preprocess_audio(audio_files, resample)\n", " # with torch.no_grad():\n", " # preprocessed_audio = preprocessed_audio.reshape(\n", " # preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n", " with torch.no_grad():\n", " preprocessed_audio = CLAPAudioProcessor(audio_files, resample)\n", " return audio_encoder(preprocessed_audio)[0]\n", "\n", "\n", "# ==================================================================\n", 
"# C L A P\n", "# ==================================================================\n", "class ClapConfig:\n", " # TEXT ENCODER CONFIG\n", " text_model = 'gpt2'\n", " text_len = 77\n", " transformer_embed_dim = 768\n", " freeze_text_encoder_weights = True\n", "\n", " # AUDIO ENCODER CONFIG\n", " audioenc_name = 'HTSAT'\n", " out_emb = 768\n", " sample_rate = 44100\n", " duration = 7\n", " fmin = 50\n", " fmax = 8000 # 14000\n", " n_fft = 1024 # 1028\n", " hop_size = 320\n", " mel_bins = 64\n", " window_size = 1024\n", "\n", " # PROJECTION SPACE CONFIG \n", " d_proj = 1024\n", " temperature = 0.003\n", "\n", " # TRAINING AND EVALUATION CONFIG\n", " num_classes = 527\n", " batch_size = 1024\n", " demo = False\n", " \n", "\n", "clapConfig = ClapConfig()\n", "clap = CLAP(\n", " audioenc_name=clapConfig.audioenc_name,\n", " sample_rate=clapConfig.sample_rate,\n", " window_size=clapConfig.window_size,\n", " hop_size=clapConfig.hop_size,\n", " mel_bins=clapConfig.mel_bins,\n", " fmin=clapConfig.fmin,\n", " fmax=clapConfig.fmax,\n", " classes_num=clapConfig.num_classes,\n", " out_emb=clapConfig.out_emb,\n", " text_model=clapConfig.text_model,\n", " transformer_embed_dim=clapConfig.transformer_embed_dim,\n", " d_proj=clapConfig.d_proj\n", " )\n", "\n", "model_repo = \"microsoft/msclap\"\n", "model_name = {\n", " '2022': 'CLAP_weights_2022.pth',\n", " '2023': 'CLAP_weights_2023.pth',\n", " 'clapcap': 'clapcap_weights_2023.pth'\n", "}\n", "\n", "version = '2023'\n", "model_fp = hf_hub_download(model_repo, model_name[version])\n", "\n", "model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']\n", "clap.load_state_dict(model_state_dict, strict=False)\n", "# clap.eval()\n", "\n", "clap_audio_encoder = clap.audio_encoder.to(device)\n", "\n", "# ENGLISH_AUDIO_DIR = r\"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English\"\n", "# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(\".wav\")]\n", "# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)\n", "# print(\"CLAP Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]\n", "\n", "\n", "# ==================================================================\n", "# C L A P - L o R A - M O D E L\n", "# ==================================================================\n", "LoRAconfig = {\n", " \"peft_type\": \"LORA\",\n", " \"task_type\": \"FEATURE_EXTRACTION\",\n", " \"inference_mode\": False,\n", " \"r\": 16,\n", " \"target_modules\": [\"qkv\", \"fc1\", \"fc2\", \"proj\", \"linear1\", \"linear2\"],\n", " \"lora_alpha\": 32,\n", " \"lora_dropout\": 0.05,\n", " \"fan_in_fan_out\": False,\n", " \"bias\": \"all\",\n", "}\n", "peft_config = get_peft_config(LoRAconfig)\n", "\n", "peft_model = get_peft_model(clap_audio_encoder, peft_config)\n", "\n", "peft_model.print_trainable_parameters()\n", "\n", "peft_clap_audio_encoder = peft_model.base_model\n", "# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)\n", "# print(\"CLAP LoRA Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]\n", "\n", "\n", "\n", "# ==================================================================\n", "# O P E N - C L I P - M O D E L\n", "# ==================================================================\n", "# ==================================================================\n", "# I M P O R T S\n", "# ==================================================================\n", "\n", "\n", "import os\n", "import io\n", "import sys\n", 
"import math\n", "import random\n", "import collections\n", "import collections.abc\n", "import re\n", "from itertools import repeat\n", "from pathlib import Path\n", "from typing import Optional, Tuple, Union, List, Dict\n", "\n", "import csv\n", "import numpy as np\n", "import pandas as pd\n", "from PIL import Image\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from tqdm import trange, tqdm\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.nn.init import _calculate_fan_in_and_fan_out\n", "import torch.utils.checkpoint as checkpoint\n", "\n", "import torchvision\n", "from torchvision.transforms import v2\n", "\n", "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "# print(f\"Using device: {device}\")\n", "\n", "import torchaudio\n", "import torchaudio.transforms as T\n", "from torchlibrosa.stft import Spectrogram, LogmelFilterBank\n", "from torchlibrosa.augmentation import SpecAugmentation\n", "\n", "from transformers import AutoModel, AutoTokenizer, logging\n", "from huggingface_hub.file_download import hf_hub_download\n", "from huggingface_hub.file_download import hf_hub_download\n", "from peft import get_peft_config, get_peft_model\n", "\n", "from typing import Any, Dict, Optional, Tuple, Union\n", "import numbers\n", "import random\n", "import warnings\n", "from dataclasses import dataclass, asdict\n", "from typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n", "\n", "import torch\n", "import torchvision.transforms.functional as F\n", "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n", " CenterCrop, ColorJitter, Grayscale\n", "\n", "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\n", "OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\n", "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n", "IMAGENET_STD = (0.229, 0.224, 0.225)\n", "INCEPTION_MEAN = (0.5, 0.5, 0.5)\n", "INCEPTION_STD = (0.5, 0.5, 0.5)\n", "\n", "# Default name for a weights file hosted on the Huggingface Hub.\n", "HF_WEIGHTS_NAME = \"open_clip_pytorch_model.bin\" # default pytorch pkl\n", "HF_SAFE_WEIGHTS_NAME = \"open_clip_model.safetensors\" # safetensors version\n", "HF_CONFIG_NAME = 'open_clip_config.json'\n", "\n", "\n", "import collections.abc\n", "from itertools import repeat\n", "from typing import List, Optional, Tuple, Union\n", "\n", "import torch\n", "from torch import nn as nn\n", "from torch import _assert\n", "from torchvision.ops.misc import FrozenBatchNorm2d\n", "\n", "\n", "def freeze_batch_norm_2d(module, module_match={}, name=''):\n", " \"\"\"\n", " Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\n", " itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\n", " returned. 
Otherwise, the module is walked recursively and submodules are converted in place.\n", "\n", " Args:\n", " module (torch.nn.Module): Any PyTorch module.\n", " module_match (dict): Dictionary of full module names to freeze (all if empty)\n", " name (str): Full module name (prefix)\n", "\n", " Returns:\n", " torch.nn.Module: Resulting module\n", "\n", " Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\n", " \"\"\"\n", " res = module\n", " is_match = True\n", " if module_match:\n", " is_match = name in module_match\n", " if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):\n", " res = FrozenBatchNorm2d(module.num_features)\n", " res.num_features = module.num_features\n", " res.affine = module.affine\n", " if module.affine:\n", " res.weight.data = module.weight.data.clone().detach()\n", " res.bias.data = module.bias.data.clone().detach()\n", " res.running_mean.data = module.running_mean.data\n", " res.running_var.data = module.running_var.data\n", " res.eps = module.eps\n", " else:\n", " for child_name, child in module.named_children():\n", " full_child_name = '.'.join([name, child_name]) if name else child_name\n", " new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\n", " if new_child is not child:\n", " res.add_module(child_name, new_child)\n", " return res\n", "\n", "\n", "# From PyTorch internals\n", "def _ntuple(n):\n", " def parse(x):\n", " if isinstance(x, collections.abc.Iterable):\n", " return x\n", " return tuple(repeat(x, n))\n", " return parse\n", "\n", "\n", "to_1tuple = _ntuple(1)\n", "to_2tuple = _ntuple(2)\n", "to_3tuple = _ntuple(3)\n", "to_4tuple = _ntuple(4)\n", "to_ntuple = lambda n, x: _ntuple(n)(x)\n", "\n", "# Replaces all linear layers with linear_replacement\n", "# TODO: add int8 support for other linear layers including attn and convnets\n", "def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):\n", " for name, module in model.named_children():\n", " if len(list(module.children())) > 0:\n", " replace_linear(module, linear_replacement, include_modules, copy_weights)\n", "\n", " if isinstance(module, torch.nn.Linear) and name in include_modules:\n", " old_module = model._modules[name]\n", " model._modules[name] = linear_replacement(\n", " module.in_features,\n", " module.out_features,\n", " module.bias is not None,\n", " )\n", " if copy_weights:\n", " model._modules[name].weight.data.copy_(old_module.weight.data)\n", " if model._modules[name].bias is not None:\n", " model._modules[name].bias.data.copy_(old_module.bias)\n", "\n", " return model\n", "\n", "def convert_int8_model_to_inference_mode(model):\n", " for m in model.modules():\n", " if hasattr(m, 'prepare_for_eval'):\n", " int8_original_dtype = m.weight.dtype\n", " m.prepare_for_eval()\n", " m.int8_original_dtype = int8_original_dtype\n", "\n", "\n", "def feature_take_indices(\n", " num_features: int,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " as_set: bool = False,\n", ") -> Tuple[List[int], int]:\n", " \"\"\" Determine the absolute feature indices to 'take' from.\n", "\n", " Note: This function can be called in forward() so must be torchscript compatible,\n", " which requires some incomplete typing and workaround hacks.\n", "\n", " Args:\n", " num_features: total number of features to select from\n", " indices: indices to select,\n", " None -> select all\n", " int -> select last n\n", " 
list/tuple of int -> return specified (-ve indices specify from end)\n", " as_set: return as a set\n", "\n", " Returns:\n", " List (or set) of absolute (from beginning) indices, Maximum index\n", " \"\"\"\n", " if indices is None:\n", " indices = num_features # all features if None\n", "\n", " if isinstance(indices, int):\n", " # convert int -> last n indices\n", " _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')\n", " take_indices = [num_features - indices + i for i in range(indices)]\n", " else:\n", " take_indices: List[int] = []\n", " for i in indices:\n", " idx = num_features + i if i < 0 else i\n", " _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')\n", " take_indices.append(idx)\n", "\n", " if not torch.jit.is_scripting() and as_set:\n", " return set(take_indices), max(take_indices)\n", "\n", " return take_indices, max(take_indices)\n", "\n", "\n", "def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:\n", " if isinstance(x, int):\n", " # if indices is an int, take last N features\n", " return tuple(range(-x, 0))\n", " return tuple(x)\n", "\n", "\n", "\n", "import copy\n", "import copy\n", "import hashlib\n", "import os\n", "import urllib\n", "import warnings\n", "from functools import partial\n", "from typing import Dict, Iterable, Optional, Union\n", "\n", "from tqdm import tqdm\n", "\n", "\n", "try:\n", " import safetensors.torch\n", " _has_safetensors = True\n", "except ImportError:\n", " _has_safetensors = False\n", "\n", "__version__ = '2.32.0'\n", "\n", "\n", "\"\"\" CLIP Model\n", "\n", "Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n", "\"\"\"\n", "import copy\n", "import logging\n", "import math\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Optional, Tuple, Union\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.utils.checkpoint import checkpoint\n", "from functools import partial\n", "\n", "# from .hf_model import HFTextEncoder\n", "# from .modified_resnet import ModifiedResNet\n", "from collections import OrderedDict\n", "import math\n", "from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union\n", "\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.checkpoint import checkpoint\n", "\n", "# from .utils import to_2tuple, feature_take_indices\n", "# from .pos_embed import get_2d_sincos_pos_embed\n", "# Copyright (c) Meta Platforms, Inc. 
and affiliates.\n", "# All rights reserved.\n", "\n", "# This source code is licensed under the license found in the\n", "# LICENSE file in the root directory of this source tree.\n", "# --------------------------------------------------------\n", "# Position embedding utils\n", "# --------------------------------------------------------\n", "\n", "import numpy as np\n", "\n", "import torch\n", "\n", "# --------------------------------------------------------\n", "# 2D sine-cosine position embedding\n", "# References:\n", "# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py\n", "# MoCo v3: https://github.com/facebookresearch/moco-v3\n", "# --------------------------------------------------------\n", "def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n", " \"\"\"\n", " grid_size: int of the grid height and width\n", " return:\n", " pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n", " \"\"\"\n", " grid_h = np.arange(grid_size, dtype=np.float32)\n", " grid_w = np.arange(grid_size, dtype=np.float32)\n", " grid = np.meshgrid(grid_w, grid_h) # here w goes first\n", " grid = np.stack(grid, axis=0)\n", "\n", " grid = grid.reshape([2, 1, grid_size, grid_size])\n", " pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n", " if cls_token:\n", " pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n", " return pos_embed\n", "\n", "\n", "def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n", " assert embed_dim % 2 == 0\n", "\n", " # use half of dimensions to encode grid_h\n", " emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)\n", " emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)\n", "\n", " emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)\n", " return emb\n", "\n", "\n", "def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n", " \"\"\"\n", " embed_dim: output dimension for each position\n", " pos: a list of positions to be encoded: size (M,)\n", " out: (M, D)\n", " \"\"\"\n", " assert embed_dim % 2 == 0\n", " omega = np.arange(embed_dim // 2, dtype=float)\n", " omega /= embed_dim / 2.\n", " omega = 1. 
/ 10000**omega # (D/2,)\n", "\n", " pos = pos.reshape(-1) # (M,)\n", " out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product\n", "\n", " emb_sin = np.sin(out) # (M, D/2)\n", " emb_cos = np.cos(out) # (M, D/2)\n", "\n", " emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)\n", " return emb\n", "\n", "\n", "# --------------------------------------------------------\n", "# Interpolate position embeddings for high-resolution\n", "# References:\n", "# DeiT: https://github.com/facebookresearch/deit\n", "# --------------------------------------------------------\n", "def interpolate_pos_embed(model, checkpoint_model):\n", " if 'pos_embed' in checkpoint_model:\n", " pos_embed_checkpoint = checkpoint_model['pos_embed']\n", " embedding_size = pos_embed_checkpoint.shape[-1]\n", " num_patches = model.patch_embed.num_patches\n", " num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n", " # height (== width) for the checkpoint position embedding\n", " orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n", " # height (== width) for the new position embedding\n", " new_size = int(num_patches ** 0.5)\n", " # class_token and dist_token are kept unchanged\n", " if orig_size != new_size:\n", " print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n", " extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n", " # only the position tokens are interpolated\n", " pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n", " pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n", " pos_tokens = torch.nn.functional.interpolate(\n", " pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n", " pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n", " new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n", " checkpoint_model['pos_embed'] = new_pos_embed\n", "\n", "\n", "\n", "from collections import OrderedDict\n", "from typing import Dict, List, Optional, Union\n", "\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "\n", "# from .utils import freeze_batch_norm_2d, feature_take_indices\n", "\n", "\n", "class Bottleneck(nn.Module):\n", " expansion = 4\n", "\n", " def __init__(self, inplanes, planes, stride=1):\n", " super().__init__()\n", "\n", " # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1\n", " self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(planes)\n", " self.act1 = nn.ReLU(inplace=True)\n", "\n", " self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(planes)\n", " self.act2 = nn.ReLU(inplace=True)\n", "\n", " self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n", "\n", " self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n", " self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n", " self.act3 = nn.ReLU(inplace=True)\n", "\n", " self.downsample = None\n", " self.stride = stride\n", "\n", " if stride > 1 or inplanes != planes * Bottleneck.expansion:\n", " # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n", " self.downsample = nn.Sequential(OrderedDict([\n", " (\"-1\", nn.AvgPool2d(stride)),\n", " (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n", " (\"1\", nn.BatchNorm2d(planes * self.expansion))\n", " ]))\n", "\n", " def forward(self, x: torch.Tensor):\n", " identity = x\n", "\n", " out = self.act1(self.bn1(self.conv1(x)))\n", " out = self.act2(self.bn2(self.conv2(out)))\n", " out = self.avgpool(out)\n", " out = self.bn3(self.conv3(out))\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.act3(out)\n", " return out\n", "\n", "\n", "class AttentionPool2d(nn.Module):\n", " def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n", " super().__init__()\n", " self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n", " self.k_proj = nn.Linear(embed_dim, embed_dim)\n", " self.q_proj = nn.Linear(embed_dim, embed_dim)\n", " self.v_proj = nn.Linear(embed_dim, embed_dim)\n", " self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n", " self.num_heads = num_heads\n", "\n", " def forward(self, x):\n", " x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC\n", " x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC\n", " x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC\n", " x, _ = F.multi_head_attention_forward(\n", " query=x, key=x, value=x,\n", " embed_dim_to_check=x.shape[-1],\n", " num_heads=self.num_heads,\n", " q_proj_weight=self.q_proj.weight,\n", " k_proj_weight=self.k_proj.weight,\n", " v_proj_weight=self.v_proj.weight,\n", " in_proj_weight=None,\n", " in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n", " bias_k=None,\n", " bias_v=None,\n", " add_zero_attn=False,\n", " dropout_p=0.,\n", " out_proj_weight=self.c_proj.weight,\n", " out_proj_bias=self.c_proj.bias,\n", " use_separate_proj_weight=True,\n", " training=self.training,\n", " need_weights=False\n", " )\n", "\n", " return x[0]\n", "\n", "\n", "class ModifiedResNet(nn.Module):\n", " \"\"\"\n", " A ResNet class that is similar to torchvision's but contains the following changes:\n", " - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n", " - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n", " - The final pooling layer is a QKV attention instead of an average pool\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " layers: List[int],\n", " output_dim: int,\n", " heads: 
int,\n", " image_size: int = 224,\n", " width: int = 64,\n", " ):\n", " super().__init__()\n", " self.output_dim = output_dim\n", " self.image_size = image_size\n", "\n", " # the 3-layer stem\n", " self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(width // 2)\n", " self.act1 = nn.ReLU(inplace=True)\n", " self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(width // 2)\n", " self.act2 = nn.ReLU(inplace=True)\n", " self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n", " self.bn3 = nn.BatchNorm2d(width)\n", " self.act3 = nn.ReLU(inplace=True)\n", " self.avgpool = nn.AvgPool2d(2)\n", "\n", " # residual layers\n", " self._inplanes = width # this is a *mutable* variable used during construction\n", " self.layer1 = self._make_layer(width, layers[0])\n", " self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n", " self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n", " self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n", "\n", " embed_dim = width * 32 # the ResNet feature dimension\n", " self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\n", "\n", " self.init_parameters()\n", "\n", " def _make_layer(self, planes, blocks, stride=1):\n", " layers = [Bottleneck(self._inplanes, planes, stride)]\n", "\n", " self._inplanes = planes * Bottleneck.expansion\n", " for _ in range(1, blocks):\n", " layers.append(Bottleneck(self._inplanes, planes))\n", "\n", " return nn.Sequential(*layers)\n", "\n", " def init_parameters(self):\n", " if self.attnpool is not None:\n", " std = self.attnpool.c_proj.in_features ** -0.5\n", " nn.init.normal_(self.attnpool.q_proj.weight, std=std)\n", " nn.init.normal_(self.attnpool.k_proj.weight, std=std)\n", " nn.init.normal_(self.attnpool.v_proj.weight, std=std)\n", " nn.init.normal_(self.attnpool.c_proj.weight, std=std)\n", "\n", " for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\n", " for name, param in resnet_block.named_parameters():\n", " if name.endswith(\"bn3.weight\"):\n", " nn.init.zeros_(param)\n", "\n", " def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n", " assert unlocked_groups == 0, 'partial locking not currently supported for this model'\n", " for param in self.parameters():\n", " param.requires_grad = False\n", " if freeze_bn_stats:\n", " freeze_batch_norm_2d(self)\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable=True):\n", " # FIXME support for non-transformer\n", " pass\n", "\n", " def stem(self, x):\n", " x = self.act1(self.bn1(self.conv1(x)))\n", " x = self.act2(self.bn2(self.conv2(x)))\n", " x = self.act3(self.bn3(self.conv3(x)))\n", " x = self.avgpool(x)\n", " return x\n", "\n", " def forward_intermediates(\n", " self,\n", " x: torch.Tensor,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " output_fmt: str = 'NCHW',\n", " output_extra_tokens: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " x: Input image tensor\n", " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize_intermediates: Apply final norm layer to all 
intermediates\n", " intermediates_only: Only return intermediate features\n", " output_fmt: Shape of intermediate feature outputs\n", " output_extra_tokens: Return both extra class, eot tokens\n", " Returns:\n", "\n", " \"\"\"\n", " assert output_fmt in ('NCHW',), 'Output format must be == NCHW.'\n", " # NOTE normalize_intermediates and return_extra_tokens don't apply\n", " take_indices, max_index = feature_take_indices(5, indices)\n", "\n", " output = {}\n", " intermediates = []\n", " blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]\n", " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", " blocks = blocks[:max_index + 1]\n", " for i, blk in enumerate(blocks):\n", " x = blk(x)\n", " if i in take_indices:\n", " intermediates.append(x)\n", "\n", " output['image_intermediates'] = intermediates\n", "\n", " if intermediates_only:\n", " return output\n", "\n", " x = self.attnpool(x)\n", " output['image_features'] = x\n", "\n", " return output\n", "\n", " def forward(self, x):\n", " x = self.stem(x)\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", " x = self.attnpool(x)\n", "\n", " return x\n", "\n", "\n", "\"\"\" huggingface model adapter\n", "\n", "Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.\n", "\"\"\"\n", "import re\n", "\n", "import torch\n", "import torch.nn as nn\n", "from torch import TensorType\n", "\n", "try:\n", " import transformers\n", " from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig\n", " from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \\\n", " BaseModelOutputWithPoolingAndCrossAttentions\n", "except ImportError as e:\n", " transformers = None\n", "\n", "\n", " class BaseModelOutput:\n", " pass\n", "\n", "\n", " class PretrainedConfig:\n", " pass\n", "\n", "# from .hf_configs import arch_dict\n", "# HF architecture dict:\n", "arch_dict = {\n", " # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n", " \"roberta\": {\n", " \"config_names\": {\n", " \"context_length\": \"max_position_embeddings\",\n", " \"vocab_size\": \"vocab_size\",\n", " \"width\": \"hidden_size\",\n", " \"heads\": \"num_attention_heads\",\n", " \"layers\": \"num_hidden_layers\",\n", " \"layer_attr\": \"layer\",\n", " \"token_embeddings_attr\": \"embeddings\"\n", " },\n", " \"pooler\": \"mean_pooler\",\n", " },\n", " # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig\n", " \"xlm-roberta\": {\n", " \"config_names\": {\n", " \"context_length\": \"max_position_embeddings\",\n", " \"vocab_size\": \"vocab_size\",\n", " \"width\": \"hidden_size\",\n", " \"heads\": \"num_attention_heads\",\n", " \"layers\": \"num_hidden_layers\",\n", " \"layer_attr\": \"layer\",\n", " \"token_embeddings_attr\": \"embeddings\"\n", " },\n", " \"pooler\": \"mean_pooler\",\n", " },\n", " # https://huggingface.co/docs/transformers/model_doc/mt5#mt5\n", " \"mt5\": {\n", " \"config_names\": {\n", " # unlimited seqlen\n", " # https://github.com/google-research/text-to-text-transfer-transformer/issues/273\n", " # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374\n", " \"context_length\": \"\",\n", " \"vocab_size\": \"vocab_size\",\n", " \"width\": \"d_model\",\n", " \"heads\": \"num_heads\",\n", " \"layers\": \"num_layers\",\n", " \"layer_attr\": \"block\",\n", " 
\"token_embeddings_attr\": \"embed_tokens\"\n", " },\n", " \"pooler\": \"mean_pooler\",\n", " },\n", " # https://huggingface.co/docs/transformers/model_doc/bert\n", " \"bert\": {\n", " \"config_names\": {\n", " \"context_length\": \"max_position_embeddings\",\n", " \"vocab_size\": \"vocab_size\",\n", " \"width\": \"hidden_size\",\n", " \"heads\": \"num_attention_heads\",\n", " \"layers\": \"num_hidden_layers\",\n", " },\n", " \"pooler\": \"cls_pooler\",\n", " },\n", " # https://huggingface.co/docs/transformers/model_doc/m2m_100\n", " \"m2m_100\": {\n", " \"config_names\": {\n", " \"context_length\": \"max_position_embeddings\",\n", " \"vocab_size\": \"vocab_size\",\n", " \"width\": \"d_model\",\n", " \"heads\": \"encoder_attention_heads\",\n", " \"layers\": \"encoder_layers\",\n", " },\n", " \"pooler\": \"cls_pooler\",\n", " },\n", "}\n", "\n", "\n", "\n", "# utils\n", "def _camel2snake(s):\n", " return re.sub(r'(? Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " x: Input image tensor\n", " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize_intermediates: Apply norm layer to all intermediates\n", " intermediates_only: Only return intermediate features\n", " output_fmt: Shape of intermediate feature outputs\n", " output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " Returns:\n", " \"\"\"\n", " extra_args = {}\n", " if output_extra_tokens:\n", " extra_args['return_prefix_tokens'] = True\n", " trunk_output = self.trunk.forward_intermediates(\n", " x,\n", " indices=indices,\n", " intermediates_only=intermediates_only,\n", " norm=normalize_intermediates,\n", " stop_early=stop_early,\n", " output_fmt=output_fmt,\n", " **extra_args,\n", " )\n", "\n", " return_dict = {}\n", " intermediates = trunk_output if intermediates_only else trunk_output[1]\n", " if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):\n", " intermediates_prefix = [xi[1] for xi in intermediates]\n", " intermediates = [xi[0] for xi in intermediates]\n", " return_dict['image_intermediates_prefix'] = intermediates_prefix\n", "\n", " return_dict['image_intermediates'] = intermediates\n", " if intermediates_only:\n", " return return_dict\n", "\n", " image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection\n", " image_features = self.head(image_features) # run through adapter pooling / projection\n", " return_dict['image_features'] = image_features\n", " return return_dict\n", "\n", " def forward(self, x):\n", " x = self.trunk(x)\n", " x = self.head(x)\n", " return x\n", "\n", "\n", "class LayerNormFp32(nn.LayerNorm):\n", " \"\"\"Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).\"\"\"\n", "\n", " def forward(self, x: torch.Tensor):\n", " orig_type = x.dtype\n", " x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)\n", " return x.to(orig_type)\n", "\n", "\n", "class LayerNorm(nn.LayerNorm):\n", " \"\"\"Subclass torch's LayerNorm (with cast back to input dtype).\"\"\"\n", "\n", " def forward(self, x: torch.Tensor):\n", " orig_type = x.dtype\n", " x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n", " return x.to(orig_type)\n", "\n", "\n", "class QuickGELU(nn.Module):\n", " # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU 
memory\n", " def forward(self, x: torch.Tensor):\n", " return x * torch.sigmoid(1.702 * x)\n", "\n", "\n", "class LayerScale(nn.Module):\n", " def __init__(self, dim, init_values=1e-5, inplace=False):\n", " super().__init__()\n", " self.inplace = inplace\n", " self.gamma = nn.Parameter(init_values * torch.ones(dim))\n", "\n", " def forward(self, x):\n", " return x.mul_(self.gamma) if self.inplace else x * self.gamma\n", "\n", "\n", "class PatchDropout(nn.Module):\n", " \"\"\"\n", " https://arxiv.org/abs/2212.00794\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " prob: float = 0.5,\n", " exclude_first_token: bool = True\n", " ):\n", " super().__init__()\n", " assert 0 <= prob < 1.\n", " self.prob = prob\n", " self.exclude_first_token = exclude_first_token # exclude CLS token\n", "\n", " def forward(self, x):\n", " if not self.training or self.prob == 0.:\n", " return x\n", "\n", " if self.exclude_first_token:\n", " cls_tokens, x = x[:, :1], x[:, 1:]\n", " else:\n", " cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])\n", "\n", " batch = x.size()[0]\n", " num_tokens = x.size()[1]\n", "\n", " batch_indices = torch.arange(batch)\n", " batch_indices = batch_indices[..., None]\n", "\n", " keep_prob = 1 - self.prob\n", " num_patches_keep = max(1, int(num_tokens * keep_prob))\n", "\n", " rand = torch.randn(batch, num_tokens)\n", " patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices\n", "\n", " x = x[batch_indices, patch_indices_keep]\n", "\n", " if self.exclude_first_token:\n", " x = torch.cat((cls_tokens, x), dim=1)\n", "\n", " return x\n", "\n", "\n", "class Attention(nn.Module):\n", " def __init__(\n", " self,\n", " dim: int,\n", " num_heads: int = 8,\n", " qkv_bias: bool = True,\n", " scaled_cosine: bool = False,\n", " scale_heads: bool = False,\n", " logit_scale_max: float = math.log(1. 
/ 0.01),\n", " batch_first: bool = True,\n", " attn_drop: float = 0.,\n", " proj_drop: float = 0.\n", " ):\n", " super().__init__()\n", " self.scaled_cosine = scaled_cosine\n", " self.scale_heads = scale_heads\n", " assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n", " self.num_heads = num_heads\n", " self.head_dim = dim // num_heads\n", " self.scale = self.head_dim ** -0.5\n", " self.logit_scale_max = logit_scale_max\n", " self.batch_first = batch_first\n", " self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')\n", "\n", " # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\n", " self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\n", " if qkv_bias:\n", " self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\n", " else:\n", " self.in_proj_bias = None\n", "\n", " if self.scaled_cosine:\n", " self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\n", " else:\n", " self.logit_scale = None\n", " self.attn_drop = nn.Dropout(attn_drop)\n", " if self.scale_heads:\n", " self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\n", " else:\n", " self.head_scale = None\n", " self.out_proj = nn.Linear(dim, dim)\n", " self.out_drop = nn.Dropout(proj_drop)\n", "\n", " def forward(self, x, attn_mask: Optional[torch.Tensor] = None):\n", " if self.batch_first:\n", " x = x.transpose(0, 1)\n", "\n", " L, N, C = x.shape\n", " q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)\n", " q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", " k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", " v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", "\n", " if attn_mask is not None and attn_mask.dtype == torch.bool:\n", " new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\n", " new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n", " attn_mask = new_attn_mask\n", "\n", " if self.logit_scale is not None:\n", " attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\n", " logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\n", " attn = attn.view(N, self.num_heads, L, L) * logit_scale\n", " attn = attn.view(-1, L, L)\n", " if attn_mask is not None:\n", " attn = attn + attn_mask\n", " attn = attn.softmax(dim=-1)\n", " attn = self.attn_drop(attn)\n", " x = torch.bmm(attn, v)\n", " else:\n", " if self.use_fsdpa:\n", " x = F.scaled_dot_product_attention(\n", " q, k, v,\n", " attn_mask=attn_mask,\n", " dropout_p=self.attn_drop.p if self.training else 0.,\n", " )\n", " else:\n", " q = q * self.scale\n", " attn = torch.bmm(q, k.transpose(-1, -2))\n", " if attn_mask is not None:\n", " attn += attn_mask\n", " attn = attn.softmax(dim=-1)\n", " attn = self.attn_drop(attn)\n", " x = torch.bmm(attn, v)\n", "\n", " if self.head_scale is not None:\n", " x = x.view(N, self.num_heads, L, C) * self.head_scale\n", " x = x.view(-1, L, C)\n", "\n", " x = x.transpose(0, 1).reshape(L, N, C)\n", "\n", " if self.batch_first:\n", " x = x.transpose(0, 1)\n", "\n", " x = self.out_proj(x)\n", " x = self.out_drop(x)\n", " return x\n", "\n", "\n", "class AttentionalPooler(nn.Module):\n", " def __init__(\n", " self,\n", " d_model: int,\n", " context_dim: int,\n", " n_head: int = 8,\n", " n_queries: int = 256,\n", " norm_layer: Callable = LayerNorm,\n", " ):\n", " super().__init__()\n", " self.query = nn.Parameter(torch.randn(n_queries, d_model))\n", " self.attn = nn.MultiheadAttention(d_model, n_head, 
kdim=context_dim, vdim=context_dim, batch_first=True)\n", " self.ln_q = norm_layer(d_model)\n", " self.ln_k = norm_layer(context_dim)\n", "\n", " def forward(self, x: torch.Tensor):\n", " N = x.shape[0]\n", " x = self.ln_k(x)\n", " q = self.ln_q(self.query)\n", " out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]\n", " return out\n", "\n", "\n", "class ResidualAttentionBlock(nn.Module):\n", " def __init__(\n", " self,\n", " d_model: int,\n", " n_head: int,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " is_cross_attention: bool = False,\n", " batch_first: bool = True,\n", " ):\n", " super().__init__()\n", "\n", " self.ln_1 = norm_layer(d_model)\n", " self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)\n", " self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", " if is_cross_attention:\n", " self.ln_1_kv = norm_layer(d_model)\n", "\n", " self.ln_2 = norm_layer(d_model)\n", " mlp_width = int(d_model * mlp_ratio)\n", " self.mlp = nn.Sequential(OrderedDict([\n", " (\"c_fc\", nn.Linear(d_model, mlp_width)),\n", " (\"gelu\", act_layer()),\n", " (\"c_proj\", nn.Linear(mlp_width, d_model))\n", " ]))\n", " self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", "\n", " def attention(\n", " self,\n", " q_x: torch.Tensor,\n", " k_x: Optional[torch.Tensor] = None,\n", " v_x: Optional[torch.Tensor] = None,\n", " attn_mask: Optional[torch.Tensor] = None,\n", " ):\n", " k_x = k_x if k_x is not None else q_x\n", " v_x = v_x if v_x is not None else q_x\n", "\n", " attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None\n", " return self.attn(\n", " q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask\n", " )[0]\n", "\n", " def forward(\n", " self,\n", " q_x: torch.Tensor,\n", " k_x: Optional[torch.Tensor] = None,\n", " v_x: Optional[torch.Tensor] = None,\n", " attn_mask: Optional[torch.Tensor] = None,\n", " ):\n", " k_x = self.ln_1_kv(k_x) if hasattr(self, \"ln_1_kv\") and k_x is not None else None\n", " v_x = self.ln_1_kv(v_x) if hasattr(self, \"ln_1_kv\") and v_x is not None else None\n", " x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))\n", " x = x + self.ls_2(self.mlp(self.ln_2(x)))\n", " return x\n", "\n", "\n", "class CustomResidualAttentionBlock(nn.Module):\n", " def __init__(\n", " self,\n", " d_model: int,\n", " n_head: int,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " scale_cosine_attn: bool = False,\n", " scale_heads: bool = False,\n", " scale_attn: bool = False,\n", " scale_fc: bool = False,\n", " batch_first: bool = True,\n", " ):\n", " super().__init__()\n", "\n", " self.ln_1 = norm_layer(d_model)\n", " self.attn = Attention(\n", " d_model,\n", " n_head,\n", " scaled_cosine=scale_cosine_attn,\n", " scale_heads=scale_heads,\n", " batch_first=batch_first,\n", " )\n", " self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()\n", " self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", "\n", " self.ln_2 = norm_layer(d_model)\n", " mlp_width = int(d_model * mlp_ratio)\n", " self.mlp = nn.Sequential(OrderedDict([\n", " (\"c_fc\", nn.Linear(d_model, mlp_width)),\n", " (\"gelu\", act_layer()),\n", " ('ln', norm_layer(mlp_width) if scale_fc else 
nn.Identity()),\n", " (\"c_proj\", nn.Linear(mlp_width, d_model))\n", " ]))\n", " self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", "\n", " def get_reference_weight(self):\n", " return self.mlp.c_fc.weight\n", "\n", " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", " x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))\n", " x = x + self.ls_2(self.mlp(self.ln_2(x)))\n", " return x\n", "\n", "\n", "class CustomTransformer(nn.Module):\n", " \"\"\" A custom transformer that can use different block types. \"\"\"\n", " def __init__(\n", " self,\n", " width: int,\n", " layers: int,\n", " heads: int,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " batch_first: bool = True,\n", " block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',\n", " ):\n", " super().__init__()\n", " self.width = width\n", " self.layers = layers\n", " self.batch_first = batch_first # run transformer stack in batch first (N, L, D)\n", " self.grad_checkpointing = False\n", "\n", " if isinstance(block_types, str):\n", " block_types = [block_types] * layers\n", " assert len(block_types) == layers\n", "\n", " def _create_block(bt: str):\n", " if bt == 'CustomResidualAttentionBlock':\n", " return CustomResidualAttentionBlock(\n", " width,\n", " heads,\n", " mlp_ratio=mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " batch_first=batch_first,\n", " )\n", " else:\n", " assert False\n", "\n", " self.resblocks = nn.ModuleList([\n", " _create_block(bt)\n", " for bt in block_types\n", " ])\n", "\n", " def get_cast_dtype(self) -> torch.dtype:\n", " weight = self.resblocks[0].get_reference_weight()\n", " if hasattr(weight, 'int8_original_dtype'):\n", " return weight.int8_original_dtype\n", " return weight.dtype\n", "\n", " def forward_intermediates(\n", " self,\n", " x: torch.Tensor,\n", " attn_mask: Optional[torch.Tensor] = None,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " ):\n", " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", "\n", " intermediates = []\n", " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", " blocks = self.resblocks\n", " else:\n", " blocks = self.resblocks[:max_index + 1]\n", " for i, blk in enumerate(blocks):\n", " if self.grad_checkpointing and not torch.jit.is_scripting():\n", " x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)\n", " else:\n", " x = blk(x, attn_mask=attn_mask)\n", "\n", " if i in take_indices:\n", " intermediates.append(x.transpose(0, 1) if not self.batch_first else x)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1) # LND -> NLD\n", "\n", " return x, intermediates\n", "\n", " def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):\n", " \"\"\" Prune layers not required for specified intermediates.\n", " \"\"\"\n", " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", " self.resblocks = self.resblocks[:max_index + 1] # truncate blocks\n", " return take_indices\n", "\n", " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", " if not self.batch_first:\n", " x = x.transpose(0, 1) # NLD -> LND\n", "\n", " for r in 
self.resblocks:\n", " if self.grad_checkpointing and not torch.jit.is_scripting():\n", " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", " x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)\n", " else:\n", " x = r(x, attn_mask=attn_mask)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1) # NLD -> LND\n", " return x\n", "\n", "\n", "class Transformer(nn.Module):\n", " def __init__(\n", " self,\n", " width: int,\n", " layers: int,\n", " heads: int,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " batch_first: bool = True,\n", " ):\n", " super().__init__()\n", " self.width = width\n", " self.layers = layers\n", " self.batch_first = batch_first\n", " self.grad_checkpointing = False\n", "\n", " self.resblocks = nn.ModuleList([\n", " ResidualAttentionBlock(\n", " width,\n", " heads,\n", " mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " batch_first=batch_first,\n", " )\n", " for _ in range(layers)\n", " ])\n", "\n", " def get_cast_dtype(self) -> torch.dtype:\n", " if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):\n", " return self.resblocks[0].mlp.c_fc.int8_original_dtype\n", " return self.resblocks[0].mlp.c_fc.weight.dtype\n", "\n", " def forward_intermediates(\n", " self,\n", " x: torch.Tensor,\n", " attn_mask: Optional[torch.Tensor] = None,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " ):\n", " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", "\n", " intermediates = []\n", " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", " blocks = self.resblocks\n", " else:\n", " blocks = self.resblocks[:max_index + 1]\n", " for i, blk in enumerate(blocks):\n", " if self.grad_checkpointing and not torch.jit.is_scripting():\n", " x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)\n", " else:\n", " x = blk(x, attn_mask=attn_mask)\n", "\n", " if i in take_indices:\n", " intermediates.append(x.transpose(0, 1) if not self.batch_first else x)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1) # LND -> NLD\n", "\n", " return x, intermediates\n", "\n", " def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):\n", " \"\"\" Prune layers not required for specified intermediates.\n", " \"\"\"\n", " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", " self.resblocks = self.resblocks[:max_index + 1] # truncate blocks\n", " return take_indices\n", "\n", " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", " if not self.batch_first:\n", " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", "\n", " for r in self.resblocks:\n", " if self.grad_checkpointing and not torch.jit.is_scripting():\n", " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", " x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)\n", " else:\n", " x = r(x, attn_mask=attn_mask)\n", "\n", " if not self.batch_first:\n", " x = x.transpose(0, 1) # LND -> NLD\n", " return x\n", "\n", "\n", "def _expand_token(token, batch_size: int):\n", " return token.view(1, 1, -1).expand(batch_size, -1, -1)\n", "\n", "\n", "class VisionTransformer(nn.Module):\n", 
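"    # ViT image tower (open_clip-style): conv patch embedding -> [CLS] token +\n",
"    # learnable or 2D sin-cos positional embedding -> optional PatchDropout ->\n",
"    # pre-LN -> Transformer blocks -> pooling ('tok' / 'avg' / attentional) ->\n",
"    # final LayerNorm and linear projection into the shared embedding space.\n",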
" output_tokens: torch.jit.Final[bool]\n", "\n", " def __init__(\n", " self,\n", " image_size: int,\n", " patch_size: int,\n", " width: int,\n", " layers: int,\n", " heads: int,\n", " mlp_ratio: float,\n", " ls_init_value: float = None,\n", " attentional_pool: bool = False,\n", " attn_pooler_queries: int = 256,\n", " attn_pooler_heads: int = 8,\n", " output_dim: int = 512,\n", " patch_dropout: float = 0.,\n", " no_ln_pre: bool = False,\n", " pos_embed_type: str = 'learnable',\n", " pool_type: str = 'tok',\n", " final_ln_after_pool: bool = False,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " output_tokens: bool = False,\n", " ):\n", " super().__init__()\n", " assert pool_type in ('tok', 'avg', 'none')\n", " self.output_tokens = output_tokens\n", " image_height, image_width = self.image_size = to_2tuple(image_size)\n", " patch_height, patch_width = self.patch_size = to_2tuple(patch_size)\n", " self.grid_size = (image_height // patch_height, image_width // patch_width)\n", " self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled\n", " self.output_dim = output_dim\n", "\n", " self.conv1 = nn.Conv2d(\n", " in_channels=3,\n", " out_channels=width,\n", " kernel_size=patch_size,\n", " stride=patch_size,\n", " bias=False,\n", " )\n", "\n", " # class embeddings and positional embeddings\n", " scale = width ** -0.5\n", " self.class_embedding = nn.Parameter(scale * torch.randn(width))\n", " if pos_embed_type == 'learnable':\n", " self.positional_embedding = nn.Parameter(\n", " scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))\n", " elif pos_embed_type == 'sin_cos_2d':\n", " # fixed sin-cos embedding\n", " assert self.grid_size[0] == self.grid_size[1],\\\n", " 'currently sin cos 2d pos embedding only supports square input'\n", " self.positional_embedding = nn.Parameter(\n", " torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)\n", " pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)\n", " self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())\n", " else:\n", " raise ValueError\n", "\n", " # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\n", " self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity()\n", "\n", " self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)\n", " self.transformer = Transformer(\n", " width,\n", " layers,\n", " heads,\n", " mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " )\n", "\n", " if attentional_pool:\n", " if isinstance(attentional_pool, str):\n", " self.attn_pool_type = attentional_pool\n", " self.pool_type = 'none'\n", " if attentional_pool in ('parallel', 'cascade'):\n", " self.attn_pool = AttentionalPooler(\n", " output_dim,\n", " width,\n", " n_head=attn_pooler_heads,\n", " n_queries=attn_pooler_queries,\n", " )\n", " self.attn_pool_contrastive = AttentionalPooler(\n", " output_dim,\n", " width,\n", " n_head=attn_pooler_heads,\n", " n_queries=1,\n", " )\n", " else:\n", " assert False\n", " else:\n", " self.attn_pool_type = ''\n", " self.pool_type = pool_type\n", " self.attn_pool = AttentionalPooler(\n", " output_dim,\n", " width,\n", " n_head=attn_pooler_heads,\n", " n_queries=attn_pooler_queries,\n", " )\n", " self.attn_pool_contrastive = None\n", " pool_dim = output_dim\n", " else:\n", " self.attn_pool = None\n", " pool_dim = width\n", " self.pool_type = pool_type\n", "\n", " self.ln_post = norm_layer(pool_dim)\n", " self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))\n", "\n", " self.init_parameters()\n", "\n", " def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):\n", " for param in self.parameters():\n", " param.requires_grad = False\n", "\n", " if unlocked_groups != 0:\n", " groups = [\n", " [\n", " self.conv1,\n", " self.class_embedding,\n", " self.positional_embedding,\n", " self.ln_pre,\n", " ],\n", " *self.transformer.resblocks[:-1],\n", " [\n", " self.transformer.resblocks[-1],\n", " self.ln_post,\n", " ],\n", " self.proj,\n", " ]\n", "\n", " def _unlock(x):\n", " if isinstance(x, Sequence):\n", " for g in x:\n", " _unlock(g)\n", " else:\n", " if isinstance(x, torch.nn.Parameter):\n", " x.requires_grad = True\n", " else:\n", " for p in x.parameters():\n", " p.requires_grad = True\n", "\n", " _unlock(groups[-unlocked_groups:])\n", "\n", " def init_parameters(self):\n", " # FIXME OpenAI CLIP did not define an init for the VisualTransformer\n", " # TODO experiment if default PyTorch init, below, or alternate init is best.\n", "\n", " # nn.init.normal_(self.class_embedding, std=self.scale)\n", " # nn.init.normal_(self.positional_embedding, std=self.scale)\n", " #\n", " # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", " # attn_std = self.transformer.width ** -0.5\n", " # fc_std = (2 * self.transformer.width) ** -0.5\n", " # for block in self.transformer.resblocks:\n", " # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", " # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", " # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", " # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", " #\n", " # if self.text_projection is not None:\n", " # nn.init.normal_(self.text_projection, std=self.scale)\n", " pass\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable: bool = True):\n", " self.transformer.grad_checkpointing = enable\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay(self):\n", " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", " no_wd = {'positional_embedding', 'class_embedding'}\n", " return no_wd\n", "\n", " def _global_pool(self, x: torch.Tensor) -> 
Tuple[torch.Tensor, torch.Tensor]:\n", " if self.pool_type == 'avg':\n", " pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]\n", " elif self.pool_type == 'tok':\n", " pooled, tokens = x[:, 0], x[:, 1:]\n", " else:\n", " pooled = tokens = x\n", "\n", " return pooled, tokens\n", "\n", " def _embeds(self, x:torch.Tensor) -> torch.Tensor:\n", " x = self.conv1(x) # shape = [*, dim, grid, grid]\n", " x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]\n", " x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]\n", "\n", " # class embeddings and positional embeddings\n", " x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)\n", " # shape = [*, grid ** 2 + 1, width]\n", " x = x + self.positional_embedding.to(x.dtype)\n", "\n", " # patch dropout (if active)\n", " x = self.patch_dropout(x)\n", "\n", " # apply norm before transformer\n", " x = self.ln_pre(x)\n", " return x\n", "\n", " def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", " if self.attn_pool is not None:\n", " if self.attn_pool_contrastive is not None:\n", " # This is untested, WIP pooling that should match paper\n", " x = self.ln_post(x) # TBD LN first or separate one after each pool?\n", " tokens = self.attn_pool(x)\n", " if self.attn_pool_type == 'parallel':\n", " pooled = self.attn_pool_contrastive(x)\n", " else:\n", " assert self.attn_pool_type == 'cascade'\n", " pooled = self.attn_pool_contrastive(tokens)\n", " else:\n", " # this is the original OpenCLIP CoCa setup, does not match paper\n", " x = self.attn_pool(x)\n", " x = self.ln_post(x)\n", " pooled, tokens = self._global_pool(x)\n", " elif self.final_ln_after_pool:\n", " pooled, tokens = self._global_pool(x)\n", " pooled = self.ln_post(pooled)\n", " else:\n", " x = self.ln_post(x)\n", " pooled, tokens = self._global_pool(x)\n", "\n", " return pooled, tokens\n", "\n", " def forward_intermediates(\n", " self,\n", " x: torch.Tensor,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " output_fmt: str = 'NCHW',\n", " output_extra_tokens: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " x: Input image tensor\n", " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " intermediates_only: Only return intermediate features\n", " normalize_intermediates: Apply final norm layer to all intermediates\n", " output_fmt: Shape of intermediate feature outputs\n", " output_extra_tokens: Return both extra prefix class tokens\n", " Returns:\n", "\n", " \"\"\"\n", " assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'\n", " reshape = output_fmt == 'NCHW'\n", "\n", " # forward pass\n", " B, _, height, width = x.shape\n", " x = self._embeds(x)\n", " x, intermediates = self.transformer.forward_intermediates(\n", " x,\n", " indices=indices,\n", " stop_early=stop_early,\n", " )\n", "\n", " # process intermediates\n", " if normalize_intermediates:\n", " # apply final norm to all intermediates\n", " intermediates = [self.ln_post(xi) for xi in intermediates]\n", " num_prefix_tokens = 1 # one class token that's always there (as of now)\n", " if num_prefix_tokens:\n", " # split prefix (e.g. 
class, distill) and spatial feature tokens\n", " prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]\n", " intermediates = [y[:, num_prefix_tokens:] for y in intermediates]\n", " else:\n", " prefix_tokens = None\n", " if reshape:\n", " # reshape to BCHW output format\n", " H, W = height // self.patch_size[0], width // self.patch_size[1]\n", " intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]\n", "\n", " output = {'image_intermediates': intermediates}\n", " if prefix_tokens is not None and output_extra_tokens:\n", " output['image_intermediates_prefix'] = prefix_tokens\n", "\n", " if intermediates_only:\n", " return output\n", "\n", " pooled, _ = self._pool(x)\n", "\n", " if self.proj is not None:\n", " pooled = pooled @ self.proj\n", "\n", " output['image_features'] = pooled\n", "\n", " return output\n", "\n", " def prune_intermediate_layers(\n", " self,\n", " indices: Union[int, List[int]] = 1,\n", " prune_norm: bool = False,\n", " prune_head: bool = True,\n", " ):\n", " \"\"\" Prune layers not required for specified intermediates.\n", " \"\"\"\n", " take_indices = self.transformer.prune_intermediate_layers(indices)\n", " if prune_norm:\n", " self.ln_post = nn.Identity()\n", " if prune_head:\n", " self.proj = None\n", " return take_indices\n", "\n", " def forward(self, x: torch.Tensor):\n", " x = self._embeds(x)\n", " x = self.transformer(x)\n", " pooled, tokens = self._pool(x)\n", "\n", " if self.proj is not None:\n", " pooled = pooled @ self.proj\n", "\n", " if self.output_tokens:\n", " return pooled, tokens\n", " \n", " return pooled\n", "\n", "\n", "def text_global_pool(\n", " x: torch.Tensor,\n", " text: Optional[torch.Tensor] = None,\n", " pool_type: str = 'argmax',\n", ") -> torch.Tensor:\n", " if pool_type == 'first':\n", " pooled = x[:, 0]\n", " elif pool_type == 'last':\n", " pooled = x[:, -1]\n", " elif pool_type == 'argmax':\n", " # take features from the eot embedding (eot_token is the highest number in each sequence)\n", " assert text is not None\n", " pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]\n", " else:\n", " pooled = x\n", "\n", " return pooled\n", "\n", "\n", "class TextTransformer(nn.Module):\n", " output_tokens: torch.jit.Final[bool]\n", "\n", " def __init__(\n", " self,\n", " context_length: int = 77,\n", " vocab_size: int = 49408,\n", " width: int = 512,\n", " heads: int = 8,\n", " layers: int = 12,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " output_dim: Optional[int] = 512,\n", " embed_cls: bool = False,\n", " no_causal_mask: bool = False,\n", " pad_id: int = 0,\n", " pool_type: str = 'argmax',\n", " proj_type: str = 'linear',\n", " proj_bias: bool = False,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " output_tokens: bool = False,\n", " ):\n", " super().__init__()\n", " assert pool_type in ('first', 'last', 'argmax', 'none')\n", " self.output_tokens = output_tokens\n", " self.num_pos = self.context_length = context_length\n", " self.vocab_size = vocab_size\n", " self.width = width\n", " self.output_dim = output_dim\n", " self.heads = heads\n", " self.pad_id = pad_id\n", " self.pool_type = pool_type\n", "\n", " self.token_embedding = nn.Embedding(vocab_size, width)\n", " if embed_cls:\n", " self.cls_emb = nn.Parameter(torch.empty(width))\n", " self.num_pos += 1\n", " else:\n", " self.cls_emb = None\n", " self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))\n", " self.transformer = Transformer(\n", " 
width=width,\n", " layers=layers,\n", " heads=heads,\n", " mlp_ratio=mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " )\n", " self.ln_final = norm_layer(width)\n", "\n", " if no_causal_mask:\n", " self.attn_mask = None\n", " else:\n", " self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)\n", "\n", " if proj_type == 'none' or not output_dim:\n", " self.text_projection = None\n", " else:\n", " if proj_bias:\n", " self.text_projection = nn.Linear(width, output_dim)\n", " else:\n", " self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n", "\n", " self.init_parameters()\n", "\n", " def init_parameters(self):\n", " nn.init.normal_(self.token_embedding.weight, std=0.02)\n", " nn.init.normal_(self.positional_embedding, std=0.01)\n", " if self.cls_emb is not None:\n", " nn.init.normal_(self.cls_emb, std=0.01)\n", "\n", " proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", " attn_std = self.transformer.width ** -0.5\n", " fc_std = (2 * self.transformer.width) ** -0.5\n", " for block in self.transformer.resblocks:\n", " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", " nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", "\n", " if self.text_projection is not None:\n", " if isinstance(self.text_projection, nn.Linear):\n", " nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)\n", " if self.text_projection.bias is not None:\n", " nn.init.zeros_(self.text_projection.bias)\n", " else:\n", " nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable=True):\n", " self.transformer.grad_checkpointing = enable\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay(self):\n", " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", " no_wd = {'positional_embedding'}\n", " if self.cls_emb is not None:\n", " no_wd.add('cls_emb')\n", " return no_wd\n", "\n", " def build_causal_mask(self):\n", " # lazily create causal attention mask, with full attention between the tokens\n", " # pytorch uses additive attention mask; fill with -inf\n", " mask = torch.empty(self.num_pos, self.num_pos)\n", " mask.fill_(float(\"-inf\"))\n", " mask.triu_(1) # zero out the lower diagonal\n", " return mask\n", "\n", " def build_cls_mask(self, text, cast_dtype: torch.dtype):\n", " cls_mask = (text != self.pad_id).unsqueeze(1)\n", " cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)\n", " additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)\n", " additive_mask.fill_(0)\n", " additive_mask.masked_fill_(~cls_mask, float(\"-inf\"))\n", " additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)\n", " return additive_mask\n", "\n", " def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n", " cast_dtype = self.transformer.get_cast_dtype()\n", " seq_len = text.shape[1]\n", " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", " attn_mask = self.attn_mask\n", " if self.cls_emb is not None:\n", " seq_len += 1\n", " x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)\n", " cls_mask = self.build_cls_mask(text, cast_dtype)\n", " if attn_mask is not None:\n", " attn_mask = 
attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]\n", " x = x + self.positional_embedding[:seq_len].to(cast_dtype)\n", " return x, attn_mask\n", "\n", " def forward_intermediates(\n", " self,\n", " text: torch.Tensor,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " output_fmt: str = 'NCHW',\n", " output_extra_tokens: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " text: Input text ids\n", " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize_intermediates: Apply norm layer to all intermediates\n", " intermediates_only: Only return intermediate features\n", " output_fmt: Shape of intermediate feature outputs\n", " output_extra_tokens: Return both prefix and intermediate tokens\n", " Returns:\n", "\n", " \"\"\"\n", " assert output_fmt in ('NLC',), 'Output format must be NLC.'\n", " # forward pass\n", " x, attn_mask = self._embeds(text)\n", " x, intermediates = self.transformer.forward_intermediates(\n", " x,\n", " attn_mask=attn_mask,\n", " indices=indices,\n", " stop_early=stop_early,\n", " )\n", "\n", " # process intermediates\n", " if normalize_intermediates:\n", " # apply final norm to all intermediates\n", " intermediates = [self.ln_final(xi) for xi in intermediates]\n", "\n", " output = {}\n", "\n", " if self.cls_emb is not None:\n", " seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence\n", " if output_extra_tokens:\n", " # return suffix class tokens separately\n", " cls_intermediates = [xi[:, -1:] for xi in intermediates]\n", " output['text_intermediates_suffix'] = cls_intermediates\n", " intermediates = seq_intermediates\n", " output['text_intermediates'] = intermediates\n", "\n", " if intermediates_only:\n", " return output\n", "\n", " if self.cls_emb is not None:\n", " # presence of appended cls embed (CoCa) overrides pool_type, always take last token\n", " pooled = text_global_pool(x, pool_type='last')\n", " pooled = self.ln_final(pooled) # final LN applied after pooling in this case\n", " else:\n", " x = self.ln_final(x)\n", " pooled = text_global_pool(x, text, pool_type=self.pool_type)\n", "\n", " if self.text_projection is not None:\n", " if isinstance(self.text_projection, nn.Linear):\n", " pooled = self.text_projection(pooled)\n", " else:\n", " pooled = pooled @ self.text_projection\n", "\n", " output['text_features'] = pooled\n", "\n", " return output\n", "\n", " def prune_intermediate_layers(\n", " self,\n", " indices: Union[int, List[int]] = 1,\n", " prune_norm: bool = False,\n", " prune_head: bool = True,\n", " ):\n", " \"\"\" Prune layers not required for specified intermediates.\n", " \"\"\"\n", " take_indices = self.transformer.prune_intermediate_layers(indices)\n", " if prune_norm:\n", " self.ln_final = nn.Identity()\n", " if prune_head:\n", " self.text_projection = None\n", " return take_indices\n", "\n", " def forward(self, text):\n", " x, attn_mask = self._embeds(text)\n", "\n", " x = self.transformer(x, attn_mask=attn_mask)\n", "\n", " # x.shape = [batch_size, n_ctx, transformer.width]\n", " if self.cls_emb is not None:\n", " # presence of appended cls embed (CoCa) overrides pool_type, always take last token\n", " pooled = 
text_global_pool(x, pool_type='last')\n", " pooled = self.ln_final(pooled) # final LN applied after pooling in this case\n", " tokens = x[:, :-1]\n", " else:\n", " x = self.ln_final(x)\n", " pooled = text_global_pool(x, text, pool_type=self.pool_type)\n", " tokens = x\n", "\n", " if self.text_projection is not None:\n", " if isinstance(self.text_projection, nn.Linear):\n", " pooled = self.text_projection(pooled)\n", " else:\n", " pooled = pooled @ self.text_projection\n", "\n", " if self.output_tokens:\n", " return pooled, tokens\n", "\n", " return pooled\n", "\n", "\n", "class MultimodalTransformer(Transformer):\n", " def __init__(\n", " self,\n", " width: int,\n", " layers: int,\n", " heads: int,\n", " context_length: int = 77,\n", " mlp_ratio: float = 4.0,\n", " ls_init_value: float = None,\n", " act_layer: Callable = nn.GELU,\n", " norm_layer: Callable = LayerNorm,\n", " output_dim: int = 512,\n", " batch_first: bool = True,\n", " ):\n", " super().__init__(\n", " width=width,\n", " layers=layers,\n", " heads=heads,\n", " mlp_ratio=mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " batch_first=batch_first,\n", " )\n", " self.context_length = context_length\n", " self.cross_attn = nn.ModuleList([\n", " ResidualAttentionBlock(\n", " width,\n", " heads,\n", " mlp_ratio,\n", " ls_init_value=ls_init_value,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " is_cross_attention=True,\n", " batch_first=batch_first,\n", " )\n", " for _ in range(layers)\n", " ])\n", "\n", " self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\n", "\n", " self.ln_final = norm_layer(width)\n", " self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n", "\n", " def init_parameters(self):\n", " proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", " attn_std = self.transformer.width ** -0.5\n", " fc_std = (2 * self.transformer.width) ** -0.5\n", " for block in self.transformer.resblocks:\n", " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", " nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", " for block in self.transformer.cross_attn:\n", " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", " nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", "\n", " if self.text_projection is not None:\n", " nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n", "\n", " def build_attention_mask(self):\n", " # lazily create causal attention mask, with full attention between the tokens\n", " # pytorch uses additive attention mask; fill with -inf\n", " mask = torch.empty(self.context_length, self.context_length)\n", " mask.fill_(float(\"-inf\"))\n", " mask.triu_(1) # zero out the lower diagonal\n", " return mask\n", "\n", " def forward_intermediates(\n", " self,\n", " x: torch.Tensor,\n", " attn_mask: Optional[torch.Tensor] = None,\n", " indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " ):\n", " assert False, \"Not currently implemented for MultimodalTransformer w/ xattn\"\n", "\n", " def forward(self, image_embs, text_embs):\n", " seq_len = text_embs.shape[1]\n", " if not self.batch_first:\n", " image_embs = image_embs.permute(1, 0, 2) # NLD -> LND\n", " 
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND\n", "\n", " for resblock, cross_attn in zip(self.resblocks, self.cross_attn):\n", " if self.grad_checkpointing and not torch.jit.is_scripting():\n", " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", " text_embs = checkpoint(\n", " resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len], use_reentrant=False)\n", " text_embs = checkpoint(\n", " cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)\n", " else:\n", " text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])\n", " text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)\n", "\n", " if not self.batch_first:\n", " text_embs = text_embs.permute(1, 0, 2) # LND -> NLD\n", "\n", " out = self.ln_final(text_embs)\n", " if self.text_projection is not None:\n", " out = out @ self.text_projection\n", "\n", " return out\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable=True):\n", " self.grad_checkpointing = enable\n", "\n", "\n", "\n", "@dataclass\n", "class CLIPVisionCfg:\n", " layers: Union[Tuple[int, int, int, int], int] = 12\n", " width: int = 768\n", " head_width: int = 64\n", " mlp_ratio: float = 4.0\n", " patch_size: int = 16\n", " image_size: Union[Tuple[int, int], int] = 224\n", "\n", " ls_init_value: Optional[float] = None # layer scale initial value\n", " patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\n", " attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)\n", " attn_pooler_queries: int = 256 # n_queries for attentional pooler\n", " attn_pooler_heads: int = 8 # n heads for attentional_pooling\n", " no_ln_pre: bool = False # disable pre transformer LayerNorm\n", " pos_embed_type: str = 'learnable'\n", " final_ln_after_pool: bool = False # apply final LayerNorm after pooling\n", " pool_type: str = 'tok'\n", " output_tokens: bool = False\n", " act_kwargs: Optional[dict] = None\n", " norm_kwargs: Optional[dict] = None\n", "\n", " timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size\n", " timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model\n", " timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n", " timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')\n", " timm_proj_bias: bool = False # enable bias final projection\n", " timm_drop: float = 0. 
# head dropout\n", " timm_drop_path: Optional[float] = None # backbone stochastic depth\n", "\n", "\n", "@dataclass\n", "class CLIPTextCfg:\n", " context_length: int = 77\n", " vocab_size: int = 49408\n", " hf_tokenizer_name: Optional[str] = None\n", " tokenizer_kwargs: Optional[dict] = None\n", "\n", " width: int = 512\n", " heads: int = 8\n", " layers: int = 12\n", " mlp_ratio: float = 4.0\n", " ls_init_value: Optional[float] = None # layer scale initial value\n", " embed_cls: bool = False\n", " pad_id: int = 0\n", " no_causal_mask: bool = False # disable causal masking\n", " final_ln_after_pool: bool = False # apply final LayerNorm after pooling\n", " pool_type: str = 'argmax'\n", " proj_bias: bool = False\n", " proj_type: str = 'linear' # control final text projection, 'none' forces no projection\n", " output_tokens: bool = False\n", " act_kwargs: dict = None\n", " norm_kwargs: dict = None\n", "\n", " # HuggingFace specific text tower config\n", " hf_model_name: Optional[str] = None\n", " hf_model_pretrained: bool = True\n", " hf_proj_type: str = 'mlp'\n", " hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models\n", "\n", "\n", "def get_cast_dtype(precision: str):\n", " cast_dtype = None\n", " if precision == 'bf16':\n", " cast_dtype = torch.bfloat16\n", " elif precision == 'fp16':\n", " cast_dtype = torch.float16\n", " return cast_dtype\n", "\n", "\n", "def get_input_dtype(precision: str):\n", " input_dtype = None\n", " if precision in ('bf16', 'pure_bf16'):\n", " input_dtype = torch.bfloat16\n", " elif precision in ('fp16', 'pure_fp16'):\n", " input_dtype = torch.float16\n", " return input_dtype\n", "\n", "\n", "def _build_vision_tower(\n", " embed_dim: int,\n", " vision_cfg: CLIPVisionCfg,\n", " quick_gelu: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None\n", "):\n", " if isinstance(vision_cfg, dict):\n", " vision_cfg = CLIPVisionCfg(**vision_cfg)\n", "\n", " # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\n", " # memory efficient in recent PyTorch releases (>= 1.10).\n", " # NOTE: timm models always use native GELU regardless of quick_gelu flag.\n", " act_layer = QuickGELU if quick_gelu else nn.GELU\n", "\n", " if vision_cfg.timm_model_name:\n", " visual = TimmModel(\n", " vision_cfg.timm_model_name,\n", " pretrained=vision_cfg.timm_model_pretrained,\n", " pool=vision_cfg.timm_pool,\n", " proj=vision_cfg.timm_proj,\n", " proj_bias=vision_cfg.timm_proj_bias,\n", " drop=vision_cfg.timm_drop,\n", " drop_path=vision_cfg.timm_drop_path,\n", " patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,\n", " embed_dim=embed_dim,\n", " image_size=vision_cfg.image_size,\n", " )\n", " elif isinstance(vision_cfg.layers, (tuple, list)):\n", " vision_heads = vision_cfg.width * 32 // vision_cfg.head_width\n", " visual = ModifiedResNet(\n", " layers=vision_cfg.layers,\n", " output_dim=embed_dim,\n", " heads=vision_heads,\n", " image_size=vision_cfg.image_size,\n", " width=vision_cfg.width,\n", " )\n", " else:\n", " vision_heads = vision_cfg.width // vision_cfg.head_width\n", " norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", " if vision_cfg.norm_kwargs:\n", " norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)\n", " if vision_cfg.act_kwargs is not None:\n", " act_layer = partial(act_layer, **vision_cfg.act_kwargs)\n", "\n", " visual = VisionTransformer(\n", " image_size=vision_cfg.image_size,\n", " patch_size=vision_cfg.patch_size,\n", " 
width=vision_cfg.width,\n", " layers=vision_cfg.layers,\n", " heads=vision_heads,\n", " mlp_ratio=vision_cfg.mlp_ratio,\n", " ls_init_value=vision_cfg.ls_init_value,\n", " patch_dropout=vision_cfg.patch_dropout,\n", " attentional_pool=vision_cfg.attentional_pool,\n", " attn_pooler_queries=vision_cfg.attn_pooler_queries,\n", " attn_pooler_heads=vision_cfg.attn_pooler_heads,\n", " pos_embed_type=vision_cfg.pos_embed_type,\n", " no_ln_pre=vision_cfg.no_ln_pre,\n", " final_ln_after_pool=vision_cfg.final_ln_after_pool,\n", " pool_type=vision_cfg.pool_type,\n", " output_tokens=vision_cfg.output_tokens,\n", " output_dim=embed_dim,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " )\n", "\n", " return visual\n", "\n", "\n", "def _build_text_tower(\n", " embed_dim: int,\n", " text_cfg: CLIPTextCfg,\n", " quick_gelu: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None,\n", "):\n", " if isinstance(text_cfg, dict):\n", " text_cfg = CLIPTextCfg(**text_cfg)\n", "\n", " if text_cfg.hf_model_name:\n", " text = HFTextEncoder(\n", " text_cfg.hf_model_name,\n", " output_dim=embed_dim,\n", " proj_type=text_cfg.hf_proj_type,\n", " pooler_type=text_cfg.hf_pooler_type,\n", " pretrained=text_cfg.hf_model_pretrained,\n", " output_tokens=text_cfg.output_tokens,\n", " )\n", " else:\n", " act_layer = QuickGELU if quick_gelu else nn.GELU\n", " norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", " if text_cfg.norm_kwargs:\n", " norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)\n", " if text_cfg.act_kwargs is not None:\n", " act_layer = partial(act_layer, **text_cfg.act_kwargs)\n", "\n", " text = TextTransformer(\n", " context_length=text_cfg.context_length,\n", " vocab_size=text_cfg.vocab_size,\n", " width=text_cfg.width,\n", " heads=text_cfg.heads,\n", " layers=text_cfg.layers,\n", " mlp_ratio=text_cfg.mlp_ratio,\n", " ls_init_value=text_cfg.ls_init_value,\n", " output_dim=embed_dim,\n", " embed_cls=text_cfg.embed_cls,\n", " no_causal_mask=text_cfg.no_causal_mask,\n", " pad_id=text_cfg.pad_id,\n", " pool_type=text_cfg.pool_type,\n", " proj_type=text_cfg.proj_type,\n", " proj_bias=text_cfg.proj_bias,\n", " output_tokens=text_cfg.output_tokens,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " )\n", " return text\n", "\n", "\n", "class CLIP(nn.Module):\n", " output_dict: torch.jit.Final[bool]\n", "\n", " def __init__(\n", " self,\n", " embed_dim: int,\n", " vision_cfg: CLIPVisionCfg,\n", " text_cfg: CLIPTextCfg,\n", " quick_gelu: bool = False,\n", " init_logit_scale: float = np.log(1 / 0.07),\n", " init_logit_bias: Optional[float] = None,\n", " nonscalar_logit_scale: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None,\n", " output_dict: bool = False,\n", " ):\n", " super().__init__()\n", " self.output_dict = output_dict\n", "\n", " self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n", "\n", " text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n", " self.transformer = text.transformer\n", " self.context_length = text.context_length\n", " self.vocab_size = text.vocab_size\n", " self.token_embedding = text.token_embedding\n", " self.positional_embedding = text.positional_embedding\n", " self.ln_final = text.ln_final\n", " self.text_projection = text.text_projection\n", " self.text_pool_type = text.pool_type\n", " self.register_buffer('attn_mask', text.attn_mask, persistent=False)\n", "\n", " lshape = [1] if nonscalar_logit_scale else []\n", " self.logit_scale = 
nn.Parameter(torch.ones(lshape) * init_logit_scale)\n", " if init_logit_bias is not None:\n", " self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n", " else:\n", " self.logit_bias = None\n", "\n", " def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n", " # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n", " self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable=True):\n", " self.visual.set_grad_checkpointing(enable)\n", " self.transformer.grad_checkpointing = enable\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay(self):\n", " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", " no_wd = {'positional_embedding'}\n", " if hasattr(self.visual, 'no_weight_decay'):\n", " for n in self.visual.no_weight_decay():\n", " no_wd.add('visual.' + n)\n", " return no_wd\n", "\n", " def encode_image(self, image, normalize: bool = False):\n", " features = self.visual(image)\n", " return F.normalize(features, dim=-1) if normalize else features\n", "\n", " def encode_text(self, text, normalize: bool = False):\n", " cast_dtype = self.transformer.get_cast_dtype()\n", "\n", " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", "\n", " x = x + self.positional_embedding.to(cast_dtype)\n", " x = self.transformer(x, attn_mask=self.attn_mask)\n", " x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]\n", " x = text_global_pool(x, text, self.text_pool_type)\n", " if self.text_projection is not None:\n", " if isinstance(self.text_projection, nn.Linear):\n", " x = self.text_projection(x)\n", " else:\n", " x = x @ self.text_projection\n", "\n", " return F.normalize(x, dim=-1) if normalize else x\n", "\n", " def get_logits(self, image, text):\n", " image_features = self.encode_image(image, normalize=True)\n", " text_features = self.encode_text(text, normalize=True)\n", " image_logits = self.logit_scale.exp() * image_features @ text_features.T\n", " if self.logit_bias is not None:\n", " image_logits += self.logit_bias\n", " text_logits = image_logits.T\n", " return image_logits, text_logits\n", "\n", " def forward_intermediates(\n", " self,\n", " image: Optional[torch.Tensor] = None,\n", " text: Optional[torch.Tensor] = None,\n", " image_indices: Optional[Union[int, List[int]]] = None,\n", " text_indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize: bool = True,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " image_output_fmt: str = 'NCHW',\n", " image_output_extra_tokens: bool = False,\n", " text_output_fmt: str = 'NLC',\n", " text_output_extra_tokens: bool = False,\n", " output_logits: bool = False,\n", " output_logit_scale_bias: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " image: Input image tensor\n", " text: Input text tensor\n", " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize_intermediates: Apply final norm layer to all intermediates\n", " normalize: L2 Normalize final features\n", " intermediates_only: Only 
return intermediate features, do not return final features\n", " image_output_fmt: Shape of intermediate image feature outputs\n", " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)\n", " text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)\n", " output_logits: Include logits in output\n", " output_logit_scale_bias: Include the logit scale bias in the output\n", " Returns:\n", "\n", " \"\"\"\n", " output = {}\n", " if intermediates_only:\n", " # intermediates only disables final feature normalization, and include logits\n", " normalize = False\n", " output_logits = False\n", " if output_logits:\n", " assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'\n", "\n", " if image is not None:\n", " image_output = self.visual.forward_intermediates(\n", " image,\n", " indices=image_indices,\n", " stop_early=stop_early,\n", " normalize_intermediates=normalize_intermediates,\n", " intermediates_only=intermediates_only,\n", " output_fmt=image_output_fmt,\n", " output_extra_tokens=image_output_extra_tokens,\n", " )\n", " if normalize and \"image_features\" in image_output:\n", " image_output[\"image_features\"] = F.normalize(image_output[\"image_features\"], dim=-1)\n", " output.update(image_output)\n", "\n", " if text is not None:\n", " cast_dtype = self.transformer.get_cast_dtype()\n", " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", " x = x + self.positional_embedding.to(cast_dtype)\n", " x, intermediates = self.transformer.forward_intermediates(\n", " x,\n", " attn_mask=self.attn_mask,\n", " indices=text_indices\n", " )\n", " if normalize_intermediates:\n", " intermediates = [self.ln_final(xi) for xi in intermediates]\n", "\n", " # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens\n", " output[\"text_intermediates\"] = intermediates\n", "\n", " if not intermediates_only:\n", " x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]\n", " x = text_global_pool(x, text, self.text_pool_type)\n", " if self.text_projection is not None:\n", " if isinstance(self.text_projection, nn.Linear):\n", " x = self.text_projection(x)\n", " else:\n", " x = x @ self.text_projection\n", " if normalize:\n", " x = F.normalize(x, dim=-1)\n", " output[\"text_features\"] = x\n", "\n", " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", "\n", " if output_logits:\n", " image_logits = logit_scale_exp * output[\"image_features\"] @ output[\"text_features\"].T\n", " if self.logit_bias is not None:\n", " image_logits += self.logit_bias\n", " text_logits = image_logits.T\n", " output[\"image_logits\"] = image_logits\n", " output[\"text_logits\"] = text_logits\n", "\n", " if output_logit_scale_bias:\n", " output[\"logit_scale\"] = logit_scale_exp\n", " if self.logit_bias is not None:\n", " output['logit_bias'] = self.logit_bias\n", "\n", " return output\n", "\n", " def forward(\n", " self,\n", " image: Optional[torch.Tensor] = None,\n", " text: Optional[torch.Tensor] = None,\n", " ):\n", " image_features = self.encode_image(image, normalize=True) if image is not None else None\n", " text_features = self.encode_text(text, normalize=True) if text is not None else None\n", "\n", " if self.output_dict:\n", " out_dict = {\n", " \"image_features\": image_features,\n", " 
\"text_features\": text_features,\n", " \"logit_scale\": self.logit_scale.exp()\n", " }\n", " if self.logit_bias is not None:\n", " out_dict['logit_bias'] = self.logit_bias\n", " return out_dict\n", "\n", " if self.logit_bias is not None:\n", " return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n", " return image_features, text_features, self.logit_scale.exp()\n", "\n", "\n", "class CustomTextCLIP(nn.Module):\n", " output_dict: torch.jit.Final[bool]\n", "\n", " def __init__(\n", " self,\n", " embed_dim: int,\n", " vision_cfg: CLIPVisionCfg,\n", " text_cfg: CLIPTextCfg,\n", " quick_gelu: bool = False,\n", " init_logit_scale: float = np.log(1 / 0.07),\n", " init_logit_bias: Optional[float] = None,\n", " nonscalar_logit_scale: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None,\n", " output_dict: bool = False,\n", " ):\n", " super().__init__()\n", " self.output_dict = output_dict\n", " self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n", " self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n", " self.context_length = self.text.context_length\n", " self.vocab_size = self.text.vocab_size\n", "\n", " lshape = [1] if nonscalar_logit_scale else []\n", " self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)\n", " if init_logit_bias is not None:\n", " self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n", " else:\n", " self.logit_bias = None\n", "\n", " def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n", " # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n", " self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n", "\n", " def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n", " self.text.lock(unlocked_layers, freeze_layer_norm)\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable=True):\n", " self.visual.set_grad_checkpointing(enable)\n", " self.text.set_grad_checkpointing(enable)\n", "\n", " @torch.jit.ignore\n", " def no_weight_decay(self):\n", " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", " no_wd = set()\n", " if hasattr(self.visual, 'no_weight_decay'):\n", " for n in self.visual.no_weight_decay():\n", " no_wd.add('visual.' + n)\n", " if hasattr(self.text, 'no_weight_decay'):\n", " for n in self.visual.no_weight_decay():\n", " no_wd.add('text.' 
+ n)\n", " return no_wd\n", "\n", " def encode_image(self, image, normalize: bool = False):\n", " features = self.visual(image)\n", " return F.normalize(features, dim=-1) if normalize else features\n", "\n", " def encode_text(self, text, normalize: bool = False):\n", " features = self.text(text)\n", " return F.normalize(features, dim=-1) if normalize else features\n", "\n", " def get_logits(self, image, text):\n", " image_features = self.encode_image(image, normalize=True)\n", " text_features = self.encode_text(text, normalize=True)\n", " image_logits = self.logit_scale.exp() * image_features @ text_features.T\n", " if self.logit_bias is not None:\n", " image_logits += self.logit_bias\n", " text_logits = image_logits.T\n", " return image_logits, text_logits\n", "\n", " def forward_intermediates(\n", " self,\n", " image: Optional[torch.Tensor] = None,\n", " text: Optional[torch.Tensor] = None,\n", " image_indices: Optional[Union[int, List[int]]] = None,\n", " text_indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize: bool = True,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " image_output_fmt: str = 'NCHW',\n", " image_output_extra_tokens: bool = False,\n", " text_output_fmt: str = 'NLC',\n", " text_output_extra_tokens: bool = False,\n", " output_logits: bool = False,\n", " output_logit_scale_bias: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " image: Input image tensor\n", " text: Input text tensor\n", " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize: L2 Normalize final image and text features (if present)\n", " normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)\n", " intermediates_only: Only return intermediate features, do not return final features\n", " image_output_fmt: Shape of intermediate image feature outputs\n", " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " text_output_fmt: Shape of intermediate text feature outputs\n", " text_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " output_logits: Include logits in output\n", " output_logit_scale_bias: Include the logit scale bias in the output\n", " Returns:\n", "\n", " \"\"\"\n", " output = {}\n", " if intermediates_only:\n", " # intermediates only disables final feature normalization, and include logits\n", " normalize = False\n", " output_logits = False\n", " if output_logits:\n", " assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'\n", "\n", " if image is not None:\n", " image_output = self.visual.forward_intermediates(\n", " image,\n", " indices=image_indices,\n", " stop_early=stop_early,\n", " normalize_intermediates=normalize_intermediates,\n", " intermediates_only=intermediates_only,\n", " output_fmt=image_output_fmt,\n", " output_extra_tokens=image_output_extra_tokens,\n", " )\n", " if normalize and \"image_features\" in image_output:\n", " image_output[\"image_features\"] = F.normalize(image_output[\"image_features\"], dim=-1)\n", " output.update(image_output)\n", "\n", " if text is not None:\n", " text_output = 
self.text.forward_intermediates(\n", " text,\n", " indices=text_indices,\n", " stop_early=stop_early,\n", " normalize_intermediates=normalize_intermediates,\n", " intermediates_only=intermediates_only,\n", " output_fmt=text_output_fmt,\n", " output_extra_tokens=text_output_extra_tokens,\n", " )\n", " if normalize and \"text_features\" in text_output:\n", " text_output[\"text_features\"] = F.normalize(text_output[\"text_features\"], dim=-1)\n", " output.update(text_output)\n", "\n", " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", "\n", " if output_logits:\n", " image_logits = logit_scale_exp * output[\"image_features\"] @ output[\"text_features\"].T\n", " if self.logit_bias is not None:\n", " image_logits += self.logit_bias\n", " text_logits = image_logits.T\n", " output[\"image_logits\"] = image_logits\n", " output[\"text_logits\"] = text_logits\n", "\n", " if output_logit_scale_bias:\n", " output[\"logit_scale\"] = logit_scale_exp\n", " if self.logit_bias is not None:\n", " output['logit_bias'] = self.logit_bias\n", "\n", " return output\n", "\n", " def forward(\n", " self,\n", " image: Optional[torch.Tensor] = None,\n", " text: Optional[torch.Tensor] = None,\n", " ):\n", " image_features = self.encode_image(image, normalize=True) if image is not None else None\n", " text_features = self.encode_text(text, normalize=True) if text is not None else None\n", "\n", " if self.output_dict:\n", " out_dict = {\n", " \"image_features\": image_features,\n", " \"text_features\": text_features,\n", " \"logit_scale\": self.logit_scale.exp()\n", " }\n", " if self.logit_bias is not None:\n", " out_dict['logit_bias'] = self.logit_bias\n", " return out_dict\n", "\n", " if self.logit_bias is not None:\n", " return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n", " return image_features, text_features, self.logit_scale.exp()\n", "\n", "\n", "def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):\n", " \"\"\"Convert applicable model parameters to low-precision (bf16 or fp16)\"\"\"\n", "\n", " def _convert_weights(l):\n", " if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n", " l.weight.data = l.weight.data.to(dtype)\n", " if l.bias is not None:\n", " l.bias.data = l.bias.data.to(dtype)\n", "\n", " if isinstance(l, (nn.MultiheadAttention, Attention)):\n", " for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n", " tensor = getattr(l, attr)\n", " if tensor is not None:\n", " tensor.data = tensor.data.to(dtype)\n", "\n", " if isinstance(l, (CLIP, TextTransformer)):\n", " # convert text nn.Parameter projections\n", " attr = getattr(l, \"text_projection\", None)\n", " if attr is not None:\n", " attr.data = attr.data.to(dtype)\n", "\n", " if isinstance(l, VisionTransformer):\n", " # convert vision nn.Parameter projections\n", " attr = getattr(l, \"proj\", None)\n", " if attr is not None:\n", " attr.data = attr.data.to(dtype)\n", "\n", " model.apply(_convert_weights)\n", "\n", "\n", "convert_weights_to_fp16 = convert_weights_to_lp # backwards compat\n", "\n", "\n", "# used to maintain checkpoint compatibility\n", "def convert_to_custom_text_state_dict(state_dict: dict):\n", " if 'text_projection' in state_dict:\n", " # old format state_dict, move text tower -> .text\n", " new_state_dict = {}\n", " for k, v in state_dict.items():\n", " if any(k.startswith(p) for p in (\n", " 'text_projection',\n", " 'positional_embedding',\n", " 'token_embedding',\n", " 
'transformer',\n", " 'ln_final',\n", " )):\n", " k = 'text.' + k\n", " new_state_dict[k] = v\n", " return new_state_dict\n", " return state_dict\n", "\n", "\n", "def build_model_from_openai_state_dict(\n", " state_dict: dict,\n", " quick_gelu=True,\n", " cast_dtype=torch.float16,\n", "):\n", " vit = \"visual.proj\" in state_dict\n", "\n", " if vit:\n", " vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n", " vision_layers = len(\n", " [k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n", " vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n", " grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n", " image_size = vision_patch_size * grid_size\n", " else:\n", " counts: list = [\n", " len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n", " vision_layers = tuple(counts)\n", " vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n", " output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n", " vision_patch_size = None\n", " assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n", " image_size = output_width * 32\n", "\n", " embed_dim = state_dict[\"text_projection\"].shape[1]\n", " context_length = state_dict[\"positional_embedding\"].shape[0]\n", " vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n", " transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n", " transformer_heads = transformer_width // 64\n", " transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n", "\n", " vision_cfg = CLIPVisionCfg(\n", " layers=vision_layers,\n", " width=vision_width,\n", " patch_size=vision_patch_size,\n", " image_size=image_size,\n", " )\n", " text_cfg = CLIPTextCfg(\n", " context_length=context_length,\n", " vocab_size=vocab_size,\n", " width=transformer_width,\n", " heads=transformer_heads,\n", " layers=transformer_layers,\n", " )\n", " model = CLIP(\n", " embed_dim,\n", " vision_cfg=vision_cfg,\n", " text_cfg=text_cfg,\n", " quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU\n", " cast_dtype=cast_dtype,\n", " )\n", "\n", " for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n", " state_dict.pop(key, None)\n", " convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16\n", " model.load_state_dict(state_dict)\n", " return model.eval()\n", "\n", "\n", "def trace_model(model, batch_size=256, device=torch.device('cpu')):\n", " model.eval()\n", " image_size = model.visual.image_size\n", " example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\n", " example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)\n", " model = torch.jit.trace_module(\n", " model,\n", " inputs=dict(\n", " forward=(example_images, example_text),\n", " encode_text=(example_text,),\n", " encode_image=(example_images,)\n", " ))\n", " model.visual.image_size = image_size\n", " return model\n", "\n", "\n", "def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):\n", " # Rescale the grid of position embeddings when loading from state_dict\n", " old_pos_embed = state_dict.get('visual.positional_embedding', None)\n", " if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\n", " return\n", " grid_size = 
to_2tuple(model.visual.grid_size)\n", " extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)\n", " new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\n", " if new_seq_len == old_pos_embed.shape[0]:\n", " return\n", "\n", " if extra_tokens:\n", " pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\n", " else:\n", " pos_emb_tok, pos_emb_img = None, old_pos_embed\n", " old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\n", "\n", " logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\n", " pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\n", " pos_emb_img = F.interpolate(\n", " pos_emb_img,\n", " size=grid_size,\n", " mode=interpolation,\n", " antialias=antialias,\n", " align_corners=False,\n", " )\n", " pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\n", " if pos_emb_tok is not None:\n", " new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\n", " else:\n", " new_pos_embed = pos_emb_img\n", " state_dict['visual.positional_embedding'] = new_pos_embed\n", "\n", "\n", "def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):\n", " old_pos_embed = state_dict.get('positional_embedding', None)\n", " if old_pos_embed is None:\n", " return\n", " # FIXME add support for text cls_token\n", " model_pos_embed = getattr(model, 'positional_embedding', None)\n", " if model_pos_embed is None:\n", " model_pos_embed = getattr(model.text, 'positional_embedding', None)\n", "\n", " old_num_pos = old_pos_embed.shape[0]\n", " old_width = old_pos_embed.shape[1]\n", " num_pos = model_pos_embed.shape[0]\n", " width = model_pos_embed.shape[1]\n", " assert old_width == width, 'text pos_embed width changed!'\n", " if old_num_pos == num_pos:\n", " return\n", "\n", " logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)\n", " old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)\n", " old_pos_embed = F.interpolate(\n", " old_pos_embed,\n", " size=num_pos,\n", " mode=interpolation,\n", " antialias=antialias,\n", " align_corners=False,\n", " )\n", " old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]\n", " new_pos_embed = old_pos_embed\n", "\n", " state_dict['positional_embedding'] = new_pos_embed\n", "\n", "\n", "def get_model_preprocess_cfg(model):\n", " module = getattr(model, 'visual', model)\n", " preprocess_cfg = getattr(module, 'preprocess_cfg', {})\n", " if not preprocess_cfg:\n", " # use separate legacy attributes if preprocess_cfg dict not found\n", " size = getattr(module, 'image_size')\n", " if size is not None:\n", " preprocess_cfg['size'] = size\n", " mean = getattr(module, 'image_mean', None)\n", " if mean is not None:\n", " preprocess_cfg['mean'] = mean\n", " std = getattr(module, 'image_std', None)\n", " if std is not None:\n", " preprocess_cfg['std'] = std\n", " return preprocess_cfg\n", "\n", "\n", "def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):\n", " module = getattr(model, 'visual', model)\n", " module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat\n", " module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat\n", " module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict\n", "\n", "\n", "def get_model_tokenize_cfg(model):\n", " module = getattr(model, 'text', 
model)\n", " cfg = {}\n", " context_length = getattr(module, 'context_length', None)\n", " if context_length is not None:\n", " cfg['context_length'] = context_length\n", " vocab_size = getattr(module, 'vocab_size', None)\n", " if vocab_size is not None:\n", " cfg['vocab_size'] = vocab_size\n", " return cfg\n", "\n", "\n", "\n", "try:\n", " from huggingface_hub import hf_hub_download\n", " hf_hub_download = partial(hf_hub_download, library_name=\"open_clip\", library_version=__version__)\n", " _has_hf_hub = True\n", "except ImportError:\n", " hf_hub_download = None\n", " _has_hf_hub = False\n", "\n", "\n", "def _pcfg(url='', hf_hub='', **kwargs):\n", " # OpenAI / OpenCLIP defaults\n", " return {\n", " 'url': url,\n", " 'hf_hub': hf_hub,\n", " 'mean': OPENAI_DATASET_MEAN,\n", " 'std': OPENAI_DATASET_STD,\n", " 'interpolation': 'bicubic',\n", " 'resize_mode': 'shortest',\n", " **kwargs,\n", " }\n", "\n", "\n", "def _slpcfg(url='', hf_hub='', **kwargs):\n", " # SiGLIP defaults\n", " return {\n", " 'url': url,\n", " 'hf_hub': hf_hub,\n", " 'mean': INCEPTION_MEAN,\n", " 'std': INCEPTION_STD,\n", " 'interpolation': 'bicubic',\n", " 'resize_mode': 'squash',\n", " **kwargs,\n", " }\n", "\n", "\n", "def _apcfg(url='', hf_hub='', **kwargs):\n", " # CLIPA defaults\n", " return {\n", " 'url': url,\n", " 'hf_hub': hf_hub,\n", " 'mean': IMAGENET_MEAN,\n", " 'std': IMAGENET_STD,\n", " 'interpolation': 'bilinear',\n", " 'resize_mode': 'squash',\n", " **kwargs,\n", " }\n", "\n", "\n", "def _mccfg(url='', hf_hub='', **kwargs):\n", " # MobileCLIP\n", " return {\n", " 'url': url,\n", " 'hf_hub': hf_hub,\n", " 'mean': (0., 0., 0.),\n", " 'std': (1., 1., 1.),\n", " 'interpolation': 'bilinear',\n", " 'resize_mode': 'shortest',\n", " **kwargs,\n", " }\n", "\n", "\n", "\n", "_RN50 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n", " hf_hub=\"timm/resnet50_clip.openai/\",\n", " quick_gelu=True,\n", " ),\n", " yfcc15m=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\",\n", " hf_hub=\"timm/resnet50_clip.yfcc15m/\",\n", " quick_gelu=True,\n", " ),\n", " cc12m=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\",\n", " hf_hub=\"timm/resnet50_clip.cc12m/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_RN101 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n", " hf_hub=\"timm/resnet101_clip.openai/\",\n", " quick_gelu=True,\n", " ),\n", " yfcc15m=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\",\n", " hf_hub=\"timm/resnet101_clip.yfcc15m/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_RN50x4 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n", " hf_hub=\"timm/resnet50x4_clip.openai/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_RN50x16 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n", " hf_hub=\"timm/resnet50x16_clip.openai/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_RN50x64 = dict(\n", " 
openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt\",\n", " hf_hub=\"timm/resnet50x64_clip.openai/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_VITB32 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.openai/\",\n", " quick_gelu=True,\n", " ),\n", " # LAION 400M (quick gelu)\n", " laion400m_e31=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.laion400m_e31/\",\n", " quick_gelu=True,\n", " ),\n", " laion400m_e32=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.laion400m_e32/\",\n", " quick_gelu=True,\n", " ),\n", " # LAION 2B-en\n", " laion2b_e16=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.laion2b_e16/\",\n", " ),\n", " laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),\n", " # DataComp-XL models\n", " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),\n", " # DataComp-M models\n", " datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),\n", " commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),\n", " commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),\n", " commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),\n", " commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),\n", " commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),\n", " commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),\n", " # DataComp-S models\n", " datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),\n", " commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),\n", " commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),\n", " commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),\n", " commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),\n", " commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),\n", " commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),\n", " # MetaClip models (NOTE quick-gelu activation used)\n", " metaclip_400m=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.metaclip_400m/\",\n", " quick_gelu=True,\n", " ),\n", " metaclip_fullcc=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt\",\n", " hf_hub=\"timm/vit_base_patch32_clip_224.metaclip_2pt5b/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_VITB32_256 = dict(\n", " datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),\n", ")\n", "\n", "_VITB16 = dict(\n", " openai=_pcfg(\n", " 
url=\"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n", " hf_hub=\"timm/vit_base_patch16_clip_224.openai/\",\n", " quick_gelu=True,\n", " ),\n", " # LAION-400M\n", " laion400m_e31=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt\",\n", " hf_hub=\"timm/vit_base_patch16_clip_224.laion400m_e31/\",\n", " ),\n", " laion400m_e32=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt\",\n", " hf_hub=\"timm/vit_base_patch16_clip_224.laion400m_e32/\",\n", " ),\n", " # LAION-2B\n", " laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),\n", " # DataComp-XL models\n", " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),\n", " # DataComp-L models\n", " datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),\n", " commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),\n", " commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),\n", " commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),\n", " commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),\n", " commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),\n", " commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),\n", " # DFN\n", " dfn2b=_pcfg(\n", " hf_hub='apple/DFN2B-CLIP-ViT-B-16/',\n", " quick_gelu=True,\n", " ),\n", " # MetaCLIP (these are quick-gelu)\n", " metaclip_400m=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt\",\n", " hf_hub=\"timm/vit_base_patch16_clip_224.metaclip_400m/\",\n", " quick_gelu=True,\n", " ),\n", " metaclip_fullcc=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt\",\n", " hf_hub=\"timm/vit_base_patch16_clip_224.metaclip_2pt5b/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_VITB16_PLUS_240 = dict(\n", " laion400m_e31=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt\",\n", " hf_hub=\"timm/vit_base_patch16_plus_clip_240.laion400m_e31/\",\n", " ),\n", " laion400m_e32=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt\",\n", " hf_hub=\"timm/vit_base_patch16_plus_clip_240.laion400m_e31/\",\n", " ),\n", ")\n", "\n", "_VITL14 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_224.openai/\",\n", " quick_gelu=True,\n", " ),\n", " # LAION-400M\n", " laion400m_e31=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_224.laion400m_e31/\",\n", " ),\n", " laion400m_e32=_pcfg(\n", " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_224.laion400m_e32/\",\n", " ),\n", " # LAION-2B-en\n", " laion2b_s32b_b82k=_pcfg(\n", " hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',\n", " mean=INCEPTION_MEAN, std=INCEPTION_STD),\n", 
" # DataComp-XL models\n", " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),\n", " commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),\n", " commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),\n", " commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),\n", " # MetaCLIP\n", " metaclip_400m=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_224.metaclip_400m/\",\n", " quick_gelu=True,\n", " ),\n", " metaclip_fullcc=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_224.metaclip_2pt5b/\",\n", " quick_gelu=True,\n", " ),\n", " # DFN-2B (quick-gelu)\n", " dfn2b=_pcfg(\n", " hf_hub='apple/DFN2B-CLIP-ViT-L-14/',\n", " quick_gelu=True,\n", " ),\n", " # DFN-2B 39B SS\n", " dfn2b_s39b=_pcfg(\n", " hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/',\n", " ),\n", ")\n", "\n", "_VITL14_336 = dict(\n", " openai=_pcfg(\n", " url=\"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\",\n", " hf_hub=\"timm/vit_large_patch14_clip_336.openai/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_VITH14 = dict(\n", " # LAION-2B-en\n", " laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),\n", " # MetaCLIP (quick-gelu)\n", " metaclip_fullcc=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt\",\n", " hf_hub=\"timm/vit_huge_patch14_clip_224.metaclip_2pt5b/\",\n", " quick_gelu=True,\n", " ),\n", " metaclip_altogether=_pcfg(\n", " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt\",\n", " hf_hub=\"timm/vit_huge_patch14_clip_224.metaclip_altogether/\",\n", " # NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay!\n", " ),\n", " # DFN-5B (quick-gelu)\n", " dfn5b=_pcfg(\n", " hf_hub='apple/DFN5B-CLIP-ViT-H-14/',\n", " quick_gelu=True,\n", " interpolation=\"bicubic\",\n", " resize_mode=\"squash\"\n", " ),\n", ")\n", "\n", "_VITH14_378 = dict(\n", " # DFN-5B (quick-gelu)\n", " dfn5b=_pcfg(\n", " hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',\n", " quick_gelu=True,\n", " interpolation=\"bicubic\",\n", " resize_mode=\"squash\"\n", " ),\n", ")\n", "\n", "_VITg14 = dict(\n", " laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),\n", " laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),\n", ")\n", "\n", "_VITbigG14 = dict(\n", " # LAION-2B-en\n", " laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),\n", " # MetaCLIP (quick-gelu)\n", " metaclip_fullcc=_pcfg(\n", " url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',\n", " hf_hub=\"timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/\",\n", " quick_gelu=True,\n", " ),\n", ")\n", "\n", "_robertaViTB32 = dict(\n", " laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),\n", ")\n", "\n", "_xlmRobertaBaseViTB32 = dict(\n", " laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),\n", ")\n", "\n", "_xlmRobertaLargeFrozenViTH14 = dict(\n", " frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),\n", ")\n", "\n", "_convnext_base = dict(\n", " laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),\n", ")\n", 
")\n",
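"\n", "# ------------------------------------------------------------------\n", "# Editorial sketch (not part of the original open_clip source): each dict\n", "# above maps a pretrain tag to a plain config dict built by `_pcfg`\n", "# (url / hf_hub / quick_gelu / optional mean-std overrides), which is why\n", "# the quick-gelu scan further down can call `.get('quick_gelu', False)`.\n", "# A cheap, read-only sanity check of that assumption:\n", "assert 'openai' in _VITL14 and _VITL14['openai'].get('quick_gelu', False)\n", "# e.g. _VITL14['laion2b_s32b_b82k'] carries the Inception mean/std override.\n", "# ------------------------------------------------------------------\n",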
"\n", "_convnext_base_w = dict(\n", " laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),\n", " laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),\n", " laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),\n", ")\n", "\n", "_convnext_base_w_320 = dict(\n", " laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),\n", " laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),\n", ")\n", "\n", "_convnext_large_d = dict(\n", " laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),\n", ")\n", "\n", "_convnext_large_d_320 = dict(\n", " laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),\n", " laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),\n", ")\n", "\n", "_convnext_xxlarge = dict(\n", " laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),\n", " laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),\n", " laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),\n", ")\n", "\n", "_coca_VITB32 = dict(\n", " laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),\n", " mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')\n", ")\n", "\n", "_coca_VITL14 = dict(\n", " laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),\n", " mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')\n", ")\n", "\n", "\n", "_PRETRAINED = {\n", " \"RN50\": _RN50,\n", " \"RN101\": _RN101,\n", " \"RN50x4\": _RN50x4,\n", " \"RN50x16\": _RN50x16,\n", " \"RN50x64\": _RN50x64,\n", "\n", " \"ViT-B-32\": _VITB32,\n", " \"ViT-B-32-256\": _VITB32_256,\n", " \"ViT-B-16\": _VITB16,\n", " \"ViT-B-16-plus-240\": _VITB16_PLUS_240,\n", " \"ViT-L-14\": _VITL14,\n", " \"ViT-L-14-336\": _VITL14_336,\n", " \"ViT-H-14\": _VITH14,\n", " \"ViT-H-14-378\": _VITH14_378,\n", " \"ViT-g-14\": _VITg14,\n", " \"ViT-bigG-14\": _VITbigG14,\n", "\n", " \"roberta-ViT-B-32\": _robertaViTB32,\n", " \"xlm-roberta-base-ViT-B-32\": _xlmRobertaBaseViTB32,\n", " \"xlm-roberta-large-ViT-H-14\": _xlmRobertaLargeFrozenViTH14,\n", "\n", " \"convnext_base\": _convnext_base,\n", " \"convnext_base_w\": _convnext_base_w,\n", " \"convnext_base_w_320\": _convnext_base_w_320,\n", " \"convnext_large_d\": _convnext_large_d,\n", " \"convnext_large_d_320\": _convnext_large_d_320,\n", " \"convnext_xxlarge\": _convnext_xxlarge,\n", "\n", " \"coca_ViT-B-32\": _coca_VITB32,\n", " \"coca_ViT-L-14\": _coca_VITL14,\n", "\n", " \"EVA01-g-14\": dict(\n", " # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt\n", " laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),\n", " ),\n", " \"EVA01-g-14-plus\": dict(\n", " # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt\n", " merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),\n", " ),\n", " \"EVA02-B-16\": dict(\n", " # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt\n", " merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),\n", " ),\n", " \"EVA02-L-14\": dict(\n", " # from 
QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt\n", " merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),\n", " ),\n", " \"EVA02-L-14-336\": dict(\n", " # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt\n", " merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),\n", " ),\n", " \"EVA02-E-14\": dict(\n", " # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt\n", " laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),\n", " ),\n", " \"EVA02-E-14-plus\": dict(\n", " # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt\n", " laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),\n", " ),\n", "\n", " \"ViT-B-16-SigLIP\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),\n", " ),\n", " \"ViT-B-16-SigLIP-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),\n", " ),\n", " \"ViT-B-16-SigLIP-i18n-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),\n", " ),\n", " \"ViT-B-16-SigLIP-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),\n", " ),\n", " \"ViT-B-16-SigLIP-512\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),\n", " ),\n", " \"ViT-L-16-SigLIP-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),\n", " ),\n", " \"ViT-L-16-SigLIP-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),\n", " ),\n", " \"ViT-SO400M-14-SigLIP\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),\n", " ),\n", " \"ViT-SO400M-16-SigLIP-i18n-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),\n", " ),\n", " \"ViT-SO400M-14-SigLIP-378\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used\n", " ),\n", " \"ViT-SO400M-14-SigLIP-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),\n", " ),\n", "\n", " \"ViT-B-32-SigLIP2-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'),\n", " ),\n", " \"ViT-B-16-SigLIP2\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'),\n", " ),\n", " \"ViT-B-16-SigLIP2-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'),\n", " ),\n", " \"ViT-B-16-SigLIP2-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'),\n", " ),\n", " \"ViT-B-16-SigLIP2-512\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'),\n", " ),\n", " \"ViT-L-16-SigLIP2-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'),\n", " ),\n", " \"ViT-L-16-SigLIP2-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'),\n", " ),\n", " \"ViT-L-16-SigLIP2-512\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'),\n", " ),\n", " \"ViT-SO400M-14-SigLIP2\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'),\n", " ),\n", " \"ViT-SO400M-14-SigLIP2-378\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'),\n", " ),\n", " \"ViT-SO400M-16-SigLIP2-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'),\n", " ),\n", " \"ViT-SO400M-16-SigLIP2-384\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'),\n", " ),\n", " \"ViT-SO400M-16-SigLIP2-512\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'),\n", " ),\n", " \"ViT-gopt-16-SigLIP2-256\": dict(\n", " webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'),\n", " ),\n", " \"ViT-gopt-16-SigLIP2-384\": dict(\n", " 
webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'),\n", " ),\n", "\n", " \"ViT-L-14-CLIPA\": dict(\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),\n", " ),\n", " \"ViT-L-14-CLIPA-336\": dict(\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),\n", " ),\n", " \"ViT-H-14-CLIPA\": dict(\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),\n", " ),\n", " \"ViT-H-14-CLIPA-336\": dict(\n", " laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),\n", " ),\n", " \"ViT-bigG-14-CLIPA\": dict(\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),\n", " ),\n", " \"ViT-bigG-14-CLIPA-336\": dict(\n", " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),\n", " ),\n", "\n", " \"nllb-clip-base\": dict(\n", " v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),\n", " ),\n", " \"nllb-clip-large\": dict(\n", " v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),\n", " ),\n", "\n", " \"nllb-clip-base-siglip\": dict(\n", " v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),\n", " mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),\n", " ),\n", " \"nllb-clip-large-siglip\": dict(\n", " v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),\n", " mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),\n", " ),\n", "\n", " \"MobileCLIP-S1\": dict(\n", " datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),\n", " \"MobileCLIP-S2\": dict(\n", " datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),\n", " \"MobileCLIP-B\": dict(\n", " datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),\n", " datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),\n", " ),\n", "\n", " \"ViTamin-S\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-S-LTT\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-B\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-B-LTT\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L-256\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L-336\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L-384\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L2\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L2-256\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L2-336\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-L2-384\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-XL-256\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-XL-336\": dict(\n", " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),\n", " ),\n", " \"ViTamin-XL-384\": dict(\n", " 
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),\n", " ),\n", "}\n", "\n", "_PRETRAINED_quickgelu = {}\n", "for k, v in _PRETRAINED.items():\n", " quick_gelu_tags = {}\n", " for tk, tv in v.items():\n", " if tv.get('quick_gelu', False):\n", " quick_gelu_tags[tk] = copy.deepcopy(tv)\n", " if quick_gelu_tags:\n", " _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags\n", "_PRETRAINED.update(_PRETRAINED_quickgelu)\n", "\n", "def _clean_tag(tag: str):\n", " # normalize pretrained tags\n", " return tag.lower().replace('-', '_')\n", "\n", "\n", "def list_pretrained(as_str: bool = False):\n", " \"\"\" returns list of pretrained models\n", " Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\n", " \"\"\"\n", " return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]\n", "\n", "\n", "def list_pretrained_models_by_tag(tag: str):\n", " \"\"\" return all models having the specified pretrain tag \"\"\"\n", " models = []\n", " tag = _clean_tag(tag)\n", " for k in _PRETRAINED.keys():\n", " if tag in _PRETRAINED[k]:\n", " models.append(k)\n", " return models\n", "\n", "\n", "def list_pretrained_tags_by_model(model: str):\n", " \"\"\" return all pretrain tags for the specified model architecture \"\"\"\n", " tags = []\n", " if model in _PRETRAINED:\n", " tags.extend(_PRETRAINED[model].keys())\n", " return tags\n", "\n", "\n", "def is_pretrained_cfg(model: str, tag: str):\n", " if model not in _PRETRAINED:\n", " return False\n", " return _clean_tag(tag) in _PRETRAINED[model]\n", "\n", "\n", "def get_pretrained_cfg(model: str, tag: str):\n", " if model not in _PRETRAINED:\n", " return {}\n", " model_pretrained = _PRETRAINED[model]\n", " return model_pretrained.get(_clean_tag(tag), {})\n", "\n", "\n", "def get_pretrained_url(model: str, tag: str):\n", " cfg = get_pretrained_cfg(model, _clean_tag(tag))\n", " return cfg.get('url', '')\n", "\n", "\n", "def download_pretrained_from_url(\n", " url: str,\n", " cache_dir: Optional[str] = None,\n", "):\n", " if not cache_dir:\n", " cache_dir = os.path.expanduser(\"~/.cache/clip\")\n", " os.makedirs(cache_dir, exist_ok=True)\n", " filename = os.path.basename(url)\n", "\n", " if 'openaipublic' in url:\n", " expected_sha256 = url.split(\"/\")[-2]\n", " elif 'mlfoundations' in url:\n", " expected_sha256 = os.path.splitext(filename)[0].split(\"-\")[-1]\n", " else:\n", " expected_sha256 = ''\n", "\n", " download_target = os.path.join(cache_dir, filename)\n", "\n", " if os.path.exists(download_target) and not os.path.isfile(download_target):\n", " raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n", "\n", " if os.path.isfile(download_target):\n", " if expected_sha256:\n", " if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n", " return download_target\n", " else:\n", " warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n", " else:\n", " return download_target\n", "\n", " with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n", " with tqdm(total=int(source.headers.get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\n", " while True:\n", " buffer = source.read(8192)\n", " if not buffer:\n", " break\n", "\n", " output.write(buffer)\n", " loop.update(len(buffer))\n", "\n", " if expected_sha256 and not hashlib.sha256(open(download_target, 
\"rb\").read()).hexdigest().startswith(expected_sha256):\n", " raise RuntimeError(\"Model has been downloaded but the SHA256 checksum does not match\")\n", "\n", " return download_target\n", "\n", "\n", "def has_hf_hub(necessary=False):\n", " if not _has_hf_hub and necessary:\n", " # if no HF Hub module installed, and it is necessary to continue, raise error\n", " raise RuntimeError(\n", " 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')\n", " return _has_hf_hub\n", "\n", "\n", "def _get_safe_alternatives(filename: str) -> Iterable[str]:\n", " \"\"\"Returns potential safetensors alternatives for a given filename.\n", "\n", " Use case:\n", " When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.\n", " \"\"\"\n", " if filename == HF_WEIGHTS_NAME:\n", " yield HF_SAFE_WEIGHTS_NAME\n", "\n", " if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(\".bin\") or filename.endswith(\".pth\")):\n", " yield filename[:-4] + \".safetensors\"\n", "\n", "\n", "def download_pretrained_from_hf(\n", " model_id: str,\n", " filename: Optional[str] = None,\n", " revision: Optional[str] = None,\n", " cache_dir: Optional[str] = None,\n", "):\n", " has_hf_hub(True)\n", "\n", " filename = filename or HF_WEIGHTS_NAME\n", "\n", " # Look for .safetensors alternatives and load from it if it exists\n", " if _has_safetensors:\n", " for safe_filename in _get_safe_alternatives(filename):\n", " try:\n", " cached_file = hf_hub_download(\n", " repo_id=model_id,\n", " filename=safe_filename,\n", " revision=revision,\n", " cache_dir=cache_dir,\n", " )\n", " return cached_file\n", " except Exception:\n", " pass\n", "\n", " try:\n", " # Attempt to download the file\n", " cached_file = hf_hub_download(\n", " repo_id=model_id,\n", " filename=filename,\n", " revision=revision,\n", " cache_dir=cache_dir,\n", " )\n", " return cached_file # Return the path to the downloaded file if successful\n", " except Exception as e:\n", " raise FileNotFoundError(f\"Failed to download file ({filename}) for {model_id}. Last error: {e}\")\n", "\n", "\n", "def download_pretrained(\n", " cfg: Dict,\n", " prefer_hf_hub: bool = True,\n", " cache_dir: Optional[str] = None,\n", "):\n", " target = ''\n", " if not cfg:\n", " return target\n", "\n", " if 'file' in cfg:\n", " return cfg['file']\n", "\n", " has_hub = has_hf_hub()\n", " download_url = cfg.get('url', '')\n", " download_hf_hub = cfg.get('hf_hub', '')\n", " if has_hub and prefer_hf_hub and download_hf_hub:\n", " # prefer to use HF hub, remove url info\n", " download_url = ''\n", "\n", " if download_url:\n", " target = download_pretrained_from_url(download_url, cache_dir=cache_dir)\n", " elif download_hf_hub:\n", " has_hf_hub(True)\n", " # we assume the hf_hub entries in pretrained config combine model_id + filename in\n", " # 'org/model_name/filename.pt' form. 
To specify just the model id w/o filename and\n", " # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.\n", " model_id, filename = os.path.split(download_hf_hub)\n", " if filename:\n", " target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)\n", " else:\n", " target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n", "\n", " return target\n", "\n", "# ==================================================================\n", "@dataclass\n", "class PreprocessCfg:\n", " size: Union[int, Tuple[int, int]] = 224\n", " mode: str = 'RGB'\n", " mean: Tuple[float, ...] = OPENAI_DATASET_MEAN\n", " std: Tuple[float, ...] 
= OPENAI_DATASET_STD\n", " interpolation: str = 'bicubic'\n", " resize_mode: str = 'shortest'\n", " fill_color: int = 0\n", "\n", " def __post_init__(self):\n", " assert self.mode in ('RGB',)\n", "\n", " @property\n", " def num_channels(self):\n", " return 3\n", "\n", " @property\n", " def input_size(self):\n", " return (self.num_channels,) + to_2tuple(self.size)\n", "\n", "_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())\n", "\n", "\n", "def merge_preprocess_dict(\n", " base: Union[PreprocessCfg, Dict],\n", " overlay: Dict,\n", "):\n", " \"\"\" Merge overlay key-value pairs on top of base preprocess cfg or dict.\n", " Input dicts are filtered based on PreprocessCfg fields.\n", " \"\"\"\n", " if isinstance(base, PreprocessCfg):\n", " base_clean = asdict(base)\n", " else:\n", " base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}\n", " if overlay:\n", " overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}\n", " base_clean.update(overlay_clean)\n", " return base_clean\n", "\n", "\n", "def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):\n", " return merge_preprocess_dict(base, kwargs)\n", "\n", "\n", "@dataclass\n", "class AugmentationCfg:\n", " scale: Tuple[float, float] = (0.9, 1.0)\n", " ratio: Optional[Tuple[float, float]] = None\n", " color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None\n", " re_prob: Optional[float] = None\n", " re_count: Optional[int] = None\n", " use_timm: bool = False\n", "\n", " # params for simclr_jitter_gray\n", " color_jitter_prob: float = None\n", " gray_scale_prob: float = None\n", "\n", "\n", "def _setup_size(size, error_msg):\n", " if isinstance(size, numbers.Number):\n", " return int(size), int(size)\n", "\n", " if isinstance(size, Sequence) and len(size) == 1:\n", " return size[0], size[0]\n", "\n", " if len(size) != 2:\n", " raise ValueError(error_msg)\n", "\n", " return size\n", "\n", "\n", "class ResizeKeepRatio:\n", " \"\"\" Resize and Keep Ratio\n", "\n", " Copy & paste from `timm`\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " size,\n", " longest=0.,\n", " interpolation=InterpolationMode.BICUBIC,\n", " random_scale_prob=0.,\n", " random_scale_range=(0.85, 1.05),\n", " random_aspect_prob=0.,\n", " random_aspect_range=(0.9, 1.11)\n", " ):\n", " if isinstance(size, (list, tuple)):\n", " self.size = tuple(size)\n", " else:\n", " self.size = (size, size)\n", " self.interpolation = interpolation\n", " self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest\n", " self.random_scale_prob = random_scale_prob\n", " self.random_scale_range = random_scale_range\n", " self.random_aspect_prob = random_aspect_prob\n", " self.random_aspect_range = random_aspect_range\n", "\n", " @staticmethod\n", " def get_params(\n", " img,\n", " target_size,\n", " longest,\n", " random_scale_prob=0.,\n", " random_scale_range=(0.85, 1.05),\n", " random_aspect_prob=0.,\n", " random_aspect_range=(0.9, 1.11)\n", " ):\n", " \"\"\"Get parameters\n", " \"\"\"\n", " source_size = img.size[::-1] # h, w\n", " h, w = source_size\n", " target_h, target_w = target_size\n", " ratio_h = h / target_h\n", " ratio_w = w / target_w\n", " ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. 
- longest)\n", " if random_scale_prob > 0 and random.random() < random_scale_prob:\n", " ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])\n", " ratio_factor = (ratio_factor, ratio_factor)\n", " else:\n", " ratio_factor = (1., 1.)\n", " if random_aspect_prob > 0 and random.random() < random_aspect_prob:\n", " aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])\n", " ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)\n", " size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]\n", " return size\n", "\n", " def __call__(self, img):\n", " \"\"\"\n", " Args:\n", " img (PIL Image): Image to be cropped and resized.\n", "\n", " Returns:\n", " PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size\n", " \"\"\"\n", " size = self.get_params(\n", " img, self.size, self.longest,\n", " self.random_scale_prob, self.random_scale_range,\n", " self.random_aspect_prob, self.random_aspect_range\n", " )\n", " img = F.resize(img, size, self.interpolation)\n", " return img\n", "\n", " def __repr__(self):\n", " format_string = self.__class__.__name__ + '(size={0}'.format(self.size)\n", " format_string += f', interpolation={self.interpolation})'\n", " format_string += f', longest={self.longest:.3f})'\n", " return format_string\n", "\n", "\n", "def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:\n", " \"\"\"Center crops and/or pads the given image.\n", " If the image is torch Tensor, it is expected\n", " to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n", " If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n", "\n", " Args:\n", " img (PIL Image or Tensor): Image to be cropped.\n", " output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,\n", " it is used for both directions.\n", " fill (int, Tuple[int]): Padding color\n", "\n", " Returns:\n", " PIL Image or Tensor: Cropped image.\n", " \"\"\"\n", " if isinstance(output_size, numbers.Number):\n", " output_size = (int(output_size), int(output_size))\n", " elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:\n", " output_size = (output_size[0], output_size[0])\n", "\n", " _, image_height, image_width = F.get_dimensions(img)\n", " crop_height, crop_width = output_size\n", "\n", " if crop_width > image_width or crop_height > image_height:\n", " padding_ltrb = [\n", " (crop_width - image_width) // 2 if crop_width > image_width else 0,\n", " (crop_height - image_height) // 2 if crop_height > image_height else 0,\n", " (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,\n", " (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,\n", " ]\n", " img = F.pad(img, padding_ltrb, fill=fill)\n", " _, image_height, image_width = F.get_dimensions(img)\n", " if crop_width == image_width and crop_height == image_height:\n", " return img\n", "\n", " crop_top = int(round((image_height - crop_height) / 2.0))\n", " crop_left = int(round((image_width - crop_width) / 2.0))\n", " return F.crop(img, crop_top, crop_left, crop_height, crop_width)\n", "\n", "\n", "class CenterCropOrPad(torch.nn.Module):\n", " \"\"\"Crops the given image at the center.\n", " If the image is torch Tensor, it is expected\n", " to have [..., H, W] shape, where ... 
means an arbitrary number of leading dimensions.\n", " If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n", "\n", " Args:\n", " size (sequence or int): Desired output size of the crop. If size is an\n", " int instead of sequence like (h, w), a square crop (size, size) is\n", " made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).\n", " \"\"\"\n", "\n", " def __init__(self, size, fill=0):\n", " super().__init__()\n", " self.size = _setup_size(size, error_msg=\"Please provide only two dimensions (h, w) for size.\")\n", " self.fill = fill\n", "\n", " def forward(self, img):\n", " \"\"\"\n", " Args:\n", " img (PIL Image or Tensor): Image to be cropped.\n", "\n", " Returns:\n", " PIL Image or Tensor: Cropped image.\n", " \"\"\"\n", " return center_crop_or_pad(img, self.size, fill=self.fill)\n", "\n", " def __repr__(self) -> str:\n", " return f\"{self.__class__.__name__}(size={self.size})\"\n", "\n", "\n", "def _convert_to_rgb(image):\n", " return image.convert('RGB')\n", "\n", "\n", "class color_jitter(object):\n", " \"\"\"\n", " Apply Color Jitter to the PIL image with a specified probability.\n", " \"\"\"\n", " def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):\n", " assert 0. <= p <= 1.\n", " self.p = p\n", " self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)\n", "\n", " def __call__(self, img):\n", " if random.random() < self.p:\n", " return self.transf(img)\n", " else:\n", " return img\n", "\n", "\n", "class gray_scale(object):\n", " \"\"\"\n", " Apply Gray Scale to the PIL image with a specified probability.\n", " \"\"\"\n", " def __init__(self, p=0.2):\n", " assert 0. <= p <= 1.\n", " self.p = p\n", " self.transf = Grayscale(num_output_channels=3)\n", "\n", " def __call__(self, img):\n", " if random.random() < self.p:\n", " return self.transf(img)\n", " else:\n", " return img\n", "\n", "\n", "def image_transform(\n", " image_size: Union[int, Tuple[int, int]],\n", " is_train: bool,\n", " mean: Optional[Tuple[float, ...]] = None,\n", " std: Optional[Tuple[float, ...]] = None,\n", " resize_mode: Optional[str] = None,\n", " interpolation: Optional[str] = None,\n", " fill_color: int = 0,\n", " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", "):\n", " mean = mean or OPENAI_DATASET_MEAN\n", " if not isinstance(mean, (list, tuple)):\n", " mean = (mean,) * 3\n", "\n", " std = std or OPENAI_DATASET_STD\n", " if not isinstance(std, (list, tuple)):\n", " std = (std,) * 3\n", "\n", " interpolation = interpolation or 'bicubic'\n", " assert interpolation in ['bicubic', 'bilinear', 'random']\n", " # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set\n", " interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC\n", "\n", " resize_mode = resize_mode or 'shortest'\n", " assert resize_mode in ('shortest', 'longest', 'squash')\n", "\n", " if isinstance(aug_cfg, dict):\n", " aug_cfg = AugmentationCfg(**aug_cfg)\n", " else:\n", " aug_cfg = aug_cfg or AugmentationCfg()\n", "\n", " normalize = Normalize(mean=mean, std=std)\n", "\n", " if is_train:\n", " aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}\n", " use_timm = aug_cfg_dict.pop('use_timm', False)\n", " if use_timm:\n", " from timm.data import create_transform # timm can still be optional\n", " if isinstance(image_size, (tuple, list)):\n", " assert 
len(image_size) >= 2\n", " input_size = (3,) + image_size[-2:]\n", " else:\n", " input_size = (3, image_size, image_size)\n", "\n", " aug_cfg_dict.setdefault('color_jitter', None) # disable by default\n", " # drop extra non-timm items\n", " aug_cfg_dict.pop('color_jitter_prob', None)\n", " aug_cfg_dict.pop('gray_scale_prob', None)\n", "\n", " train_transform = create_transform(\n", " input_size=input_size,\n", " is_training=True,\n", " hflip=0.,\n", " mean=mean,\n", " std=std,\n", " re_mode='pixel',\n", " interpolation=interpolation,\n", " **aug_cfg_dict,\n", " )\n", " else:\n", " train_transform = [\n", " RandomResizedCrop(\n", " image_size,\n", " scale=aug_cfg_dict.pop('scale'),\n", " interpolation=InterpolationMode.BICUBIC,\n", " ),\n", " _convert_to_rgb,\n", " ]\n", " if aug_cfg.color_jitter_prob:\n", " assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4\n", " train_transform.extend([\n", " color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)\n", " ])\n", " if aug_cfg.gray_scale_prob:\n", " train_transform.extend([\n", " gray_scale(aug_cfg.gray_scale_prob)\n", " ])\n", " train_transform.extend([\n", " ToTensor(),\n", " normalize,\n", " ])\n", " train_transform = Compose(train_transform)\n", " if aug_cfg_dict:\n", " warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')\n", " return train_transform\n", " else:\n", " if resize_mode == 'longest':\n", " transforms = [\n", " ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),\n", " CenterCropOrPad(image_size, fill=fill_color)\n", " ]\n", " elif resize_mode == 'squash':\n", " if isinstance(image_size, int):\n", " image_size = (image_size, image_size)\n", " transforms = [\n", " Resize(image_size, interpolation=interpolation_mode),\n", " ]\n", " else:\n", " assert resize_mode == 'shortest'\n", " if not isinstance(image_size, (tuple, list)):\n", " image_size = (image_size, image_size)\n", " if image_size[0] == image_size[1]:\n", " # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)\n", " transforms = [\n", " Resize(image_size[0], interpolation=interpolation_mode)\n", " ]\n", " else:\n", " # resize shortest edge to matching target dim for non-square target\n", " transforms = [ResizeKeepRatio(image_size)]\n", " transforms += [CenterCrop(image_size)]\n", "\n", " transforms.extend([\n", " _convert_to_rgb,\n", " ToTensor(),\n", " normalize,\n", " ])\n", " return Compose(transforms)\n", " \n", " \n", "def image_transform_v2(\n", " cfg: PreprocessCfg,\n", " is_train: bool,\n", " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", "):\n", " return image_transform(\n", " image_size=cfg.size,\n", " is_train=is_train,\n", " mean=cfg.mean,\n", " std=cfg.std,\n", " interpolation=cfg.interpolation,\n", " resize_mode=cfg.resize_mode,\n", " fill_color=cfg.fill_color,\n", " aug_cfg=aug_cfg,\n", " )\n", "\n", "def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):\n", " module = getattr(model, 'visual', model)\n", " 
module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat\n", " module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat\n", " module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict\n", "\n", "\n", "@torch.no_grad()\n", "def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):\n", "\n", " def _convert_timm_img(state_dict):\n", " if fastvit:\n", " from timm.models.fastvit import checkpoint_filter_fn\n", " else:\n", " from timm.models.vision_transformer_hybrid import checkpoint_filter_fn\n", " timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)\n", " timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}\n", " return timm_state_dict\n", "\n", " def _convert_openclip_txt(state_dict, prefix='text_encoder.'):\n", " text_dict = {}\n", " for k, v in state_dict.items():\n", " if not k.startswith(prefix):\n", " continue\n", " k = k.replace(prefix, '')\n", " k = k.replace('projection_layer', 'text_projection')\n", " k = k.replace('embedding_layer', 'token_embedding')\n", " if k.startswith('positional_embedding.pos_embed.pos_embed'):\n", " k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')\n", " v = v.squeeze()\n", " k = k.replace('final_layer_norm', 'ln_final')\n", " k = k.replace('pre_norm_mha.0', 'ln_1')\n", " k = k.replace('pre_norm_mha.1', 'attn')\n", " k = k.replace('pre_norm_ffn.0', 'ln_2')\n", " k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')\n", " k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')\n", " k = k.replace('qkv_proj.weight', 'in_proj_weight')\n", " k = k.replace('qkv_proj.bias', 'in_proj_bias')\n", " k = k.replace('transformer.', 'transformer.resblocks.')\n", " text_dict['text.' 
+ k] = v\n", " return text_dict\n", "\n", " image_dict = _convert_timm_img(state_dict)\n", " text_dict = _convert_openclip_txt(state_dict)\n", " out_dict = {**image_dict, **text_dict}\n", " out_dict['logit_scale'] = state_dict['logit_scale']\n", " return out_dict\n", "\n", "\n", "def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):\n", " if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:\n", " # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)\n", " state_dict = convert_mobile_clip_state_dict(model, state_dict)\n", " if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:\n", " # convert b model\n", " state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)\n", " return state_dict\n", "\n", "def load_state_dict(\n", " checkpoint_path: str,\n", " device='cpu',\n", " weights_only=True,\n", "):\n", " # Check if safetensors or not and load weights accordingly\n", " if str(checkpoint_path).endswith(\".safetensors\"):\n", " from safetensors.torch import load_file\n", " checkpoint = load_file(checkpoint_path, device=device)\n", " else:\n", " try:\n", " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)\n", " except TypeError:\n", " checkpoint = torch.load(checkpoint_path, map_location=device)\n", "\n", " if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n", " state_dict = checkpoint['state_dict']\n", " elif isinstance(checkpoint, torch.jit.ScriptModule):\n", " state_dict = checkpoint.state_dict()\n", " for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n", " state_dict.pop(key, None)\n", " else:\n", " state_dict = checkpoint\n", " if next(iter(state_dict.items()))[0].startswith('module'):\n", " state_dict = {k[7:]: v for k, v in state_dict.items()}\n", " return state_dict\n", "\n", "def load_checkpoint(\n", " model: Union[CLIP, CustomTextCLIP],\n", " checkpoint_path: str,\n", " strict: bool = True,\n", " weights_only: bool = True,\n", " device='cpu',\n", "):\n", " if Path(checkpoint_path).suffix in ('.npz', '.npy'):\n", " # Separate path loading numpy big_vision (SigLIP) weights\n", " from open_clip.convert import load_big_vision_weights\n", " load_big_vision_weights(model, checkpoint_path)\n", " return {}\n", "\n", " state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)\n", "\n", " # Detect & convert 3rd party state_dicts -> open_clip\n", " state_dict = convert_state_dict(model, state_dict)\n", "\n", " # Detect old format and make compatible with new format\n", " if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):\n", " state_dict = convert_to_custom_text_state_dict(state_dict)\n", "\n", " # correct if logit_scale differs in being scaler vs 1d param\n", " if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:\n", " state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)\n", "\n", " # correct if logit_bias differs in being scaler vs 1d param\n", " if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:\n", " state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)\n", "\n", " # If loading a non-SigLIP model for SigLIP training. 
See https://github.com/mlfoundations/open_clip/issues/712\n", " if 'logit_bias' not in state_dict and model.logit_bias is not None:\n", " state_dict[\"logit_bias\"] = torch.zeros_like(state_dict[\"logit_scale\"])\n", "\n", " # Certain text transformers no longer expect position_ids after transformers==4.31\n", " position_id_key = 'text.transformer.embeddings.position_ids'\n", " if position_id_key in state_dict and not hasattr(model, position_id_key):\n", " del state_dict[position_id_key]\n", "\n", " resize_pos_embed(state_dict, model)\n", " resize_text_pos_embed(state_dict, model)\n", "\n", " # Finally, load the massaged state_dict into model\n", " incompatible_keys = model.load_state_dict(state_dict, strict=strict)\n", " return incompatible_keys\n", "\n", "# /home/IITB/ai-at-ieor/23m1521/.conda/envs/openclip2/lib/python3.11/site-packages/open_clip/factory.py\n", "HF_HUB_PREFIX = 'hf-hub:'\n", "# _MODEL_CONFIG_PATHS = [Path(__file__).parent / f\"model_configs/\"]\n", "_MODEL_CONFIG_PATHS = [Path(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs\")]\n", "_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs\n", "\n", "import json\n", "\n", "def _get_hf_config(\n", " model_id: str,\n", " cache_dir: Optional[str] = None,\n", "):\n", " \"\"\" Fetch model config from HuggingFace Hub.\n", " \"\"\"\n", " config_path = download_pretrained_from_hf(\n", " model_id,\n", " filename='open_clip_config.json',\n", " cache_dir=cache_dir,\n", " )\n", " with open(config_path, 'r', encoding='utf-8') as f:\n", " config = json.load(f)\n", " return config\n", "\n", "def get_model_config(model_name):\n", " \"\"\" Fetch model config from builtin (local library) configs.\n", " \"\"\"\n", " if model_name in _MODEL_CONFIGS:\n", " return copy.deepcopy(_MODEL_CONFIGS[model_name])\n", " else:\n", " return None\n", "\n", "def _natural_key(string_):\n", " return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\n", "\n", "\n", "def _rescan_model_configs():\n", " global _MODEL_CONFIGS\n", "\n", " config_ext = ('.json',)\n", " config_files = []\n", " for config_path in _MODEL_CONFIG_PATHS:\n", " if config_path.is_file() and config_path.suffix in config_ext:\n", " config_files.append(config_path)\n", " elif config_path.is_dir():\n", " for ext in config_ext:\n", " config_files.extend(config_path.glob(f'*{ext}'))\n", "\n", " for cf in config_files:\n", " with open(cf, 'r') as f:\n", " model_cfg = json.load(f)\n", " if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):\n", " _MODEL_CONFIGS[cf.stem] = model_cfg\n", "\n", " _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}\n", "\n", "\n", "_rescan_model_configs() # initial populate of model config registry\n", "\n", "def list_models():\n", " \"\"\" enumerate available model architectures based on config files \"\"\"\n", " return list(_MODEL_CONFIGS.keys())\n", "\n", "\n", "def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):\n", " if past:\n", " input_ids = input_ids[:, -1].unsqueeze(-1)\n", "\n", " attention_mask = kwargs.get(\"attention_mask\", None)\n", " position_ids = kwargs.get(\"position_ids\", None)\n", "\n", " if attention_mask is not None and position_ids is None:\n", " # create position_ids on the fly for batch generation\n", " position_ids = attention_mask.long().cumsum(-1) - 1\n", " position_ids.masked_fill_(attention_mask == 0, 1)\n", " else:\n", " position_ids = None\n", " return 
{\n", " \"text\": input_ids,\n", " \"images\": image_inputs,\n", " \"past_key_values\": past,\n", " \"position_ids\": position_ids,\n", " \"attention_mask\": attention_mask,\n", " }\n", "\n", "@dataclass\n", "class MultimodalCfg(CLIPTextCfg):\n", " mlp_ratio: int = 4\n", " dim_head: int = 64\n", " heads: int = 8\n", " n_queries: int = 256\n", " attn_pooler_heads: int = 8\n", "\n", "try:\n", " from transformers import (\n", " BeamSearchScorer,\n", " LogitsProcessorList,\n", " TopPLogitsWarper,\n", " TopKLogitsWarper,\n", " RepetitionPenaltyLogitsProcessor,\n", " MinLengthLogitsProcessor,\n", " MaxLengthCriteria,\n", " StopStringCriteria,\n", " EosTokenCriteria,\n", " StoppingCriteriaList\n", " )\n", "\n", " GENERATION_TYPES = {\n", " \"top_k\": TopKLogitsWarper,\n", " \"top_p\": TopPLogitsWarper,\n", " \"beam_search\": \"beam_search\"\n", " }\n", " _has_transformers = True\n", "except ImportError as e:\n", " GENERATION_TYPES = {\n", " \"top_k\": None,\n", " \"top_p\": None,\n", " \"beam_search\": \"beam_search\"\n", " }\n", " _has_transformers = False\n", "\n", "def _token_to_tensor(token_id, device: str = \"cpu\") -> torch.Tensor:\n", " if not isinstance(token_id, torch.Tensor):\n", " if isinstance(token_id, int):\n", " token_id = [token_id]\n", " token_id = torch.tensor(token_id, device=device)\n", " return token_id\n", "\n", "\n", "def _build_text_decoder_tower(\n", " embed_dim,\n", " multimodal_cfg,\n", " quick_gelu: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None,\n", "):\n", " multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n", " act_layer = QuickGELU if quick_gelu else nn.GELU\n", " norm_layer = (\n", " LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", " )\n", "\n", " decoder = MultimodalTransformer(\n", " context_length=multimodal_cfg.context_length,\n", " width=multimodal_cfg.width,\n", " heads=multimodal_cfg.heads,\n", " layers=multimodal_cfg.layers,\n", " ls_init_value=multimodal_cfg.ls_init_value,\n", " output_dim=embed_dim,\n", " act_layer=act_layer,\n", " norm_layer=norm_layer,\n", " )\n", "\n", " return decoder\n", "\n", "class CoCa(nn.Module):\n", " def __init__(\n", " self,\n", " embed_dim,\n", " multimodal_cfg: MultimodalCfg,\n", " text_cfg: CLIPTextCfg,\n", " vision_cfg: CLIPVisionCfg,\n", " quick_gelu: bool = False,\n", " init_logit_scale: float = np.log(1 / 0.07),\n", " init_logit_bias: Optional[float] = None,\n", " nonscalar_logit_scale: bool = False,\n", " cast_dtype: Optional[torch.dtype] = None,\n", " pad_id: int = 0,\n", " ):\n", " super().__init__()\n", " multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n", " text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg\n", " vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg\n", "\n", " self.text = _build_text_tower(\n", " embed_dim=embed_dim,\n", " text_cfg=text_cfg,\n", " quick_gelu=quick_gelu,\n", " cast_dtype=cast_dtype,\n", " )\n", "\n", " vocab_size = (\n", " text_cfg.vocab_size # for hf models\n", " if hasattr(text_cfg, \"hf_model_name\") and text_cfg.hf_model_name is not None\n", " else text_cfg.vocab_size\n", " )\n", "\n", " self.visual = _build_vision_tower(\n", " embed_dim=embed_dim,\n", " vision_cfg=vision_cfg,\n", " quick_gelu=quick_gelu,\n", " cast_dtype=cast_dtype,\n", " )\n", "\n", " self.text_decoder = _build_text_decoder_tower(\n", " vocab_size,\n", " multimodal_cfg=multimodal_cfg,\n", 
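" # NOTE: `vocab_size` is passed as the decoder tower's positional `embed_dim`\n", " # argument, so MultimodalTransformer.output_dim equals the vocabulary size\n", " # and the decoder emits per-token logits over the vocab.\n",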
" quick_gelu=quick_gelu,\n", " cast_dtype=cast_dtype,\n", " )\n", "\n", " lshape = [1] if nonscalar_logit_scale else []\n", " self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)\n", " if init_logit_bias is not None:\n", " self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n", " else:\n", " self.logit_bias = None\n", " self.pad_id = pad_id\n", "\n", " self.context_length = multimodal_cfg.context_length\n", "\n", " @torch.jit.ignore\n", " def set_grad_checkpointing(self, enable: bool = True):\n", " self.visual.set_grad_checkpointing(enable)\n", " self.text.set_grad_checkpointing(enable)\n", " self.text_decoder.set_grad_checkpointing(enable)\n", "\n", " def _encode_image(self, images, normalize: bool = True):\n", " image_latent, tokens_embs = self.visual(images)\n", " image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent\n", " return image_latent, tokens_embs\n", "\n", " def _encode_text(self, text, normalize: bool = True):\n", " text_latent, token_emb = self.text(text)\n", " text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent\n", " return text_latent, token_emb\n", "\n", " def encode_image(self, images, normalize: bool = True):\n", " image_latent, _ = self._encode_image(images, normalize=normalize)\n", " return image_latent\n", "\n", " def encode_text(self, text, normalize: bool = True):\n", " text_latent, _ = self._encode_text(text, normalize=normalize)\n", " return text_latent\n", "\n", " def forward_intermediates(\n", " self,\n", " image: Optional[torch.Tensor] = None,\n", " text: Optional[torch.Tensor] = None,\n", " image_indices: Optional[Union[int, List[int]]] = None,\n", " text_indices: Optional[Union[int, List[int]]] = None,\n", " stop_early: bool = False,\n", " normalize: bool = True,\n", " normalize_intermediates: bool = False,\n", " intermediates_only: bool = False,\n", " image_output_fmt: str = 'NCHW',\n", " image_output_extra_tokens: bool = False,\n", " text_output_fmt: str = 'NLC',\n", " text_output_extra_tokens: bool = False,\n", " output_logits: bool = False,\n", " output_logit_scale_bias: bool = False,\n", " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", " \"\"\" Forward features that returns intermediates.\n", "\n", " Args:\n", " image: Input image tensor\n", " text: Input text tensor\n", " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", " stop_early: Stop iterating over blocks when last desired intermediate hit\n", " normalize: L2 Normalize final image and text features (if present)\n", " normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)\n", " intermediates_only: Only return intermediate features, do not return final features\n", " image_output_fmt: Shape of intermediate image feature outputs\n", " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " text_output_fmt: Shape of intermediate text feature outputs\n", " text_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", " output_logits: Include logits in output\n", " output_logit_scale_bias: Include the logit scale bias in the output\n", " Returns:\n", "\n", " \"\"\"\n", " output = {}\n", " if intermediates_only:\n", " # intermediates only disables final feature normalization, and include logits\n", " normalize = False\n", " output_logits = False\n", " if 
output_logits:\n", " assert False, 'FIXME, needs implementing'\n", "\n", " if image is not None:\n", " image_output = self.visual.forward_intermediates(\n", " image,\n", " indices=image_indices,\n", " stop_early=stop_early,\n", " normalize_intermediates=normalize_intermediates,\n", " intermediates_only=intermediates_only,\n", " output_fmt=image_output_fmt,\n", " output_extra_tokens=image_output_extra_tokens,\n", " )\n", " if normalize and \"image_features\" in image_output:\n", " image_output[\"image_features\"] = F.normalize(image_output[\"image_features\"], dim=-1)\n", " output.update(image_output)\n", "\n", " if text is not None:\n", " text_output = self.text.forward_intermediates(\n", " text,\n", " indices=text_indices,\n", " stop_early=stop_early,\n", " normalize_intermediates=normalize_intermediates,\n", " intermediates_only=intermediates_only,\n", " output_fmt=text_output_fmt,\n", " output_extra_tokens=text_output_extra_tokens,\n", " )\n", " if normalize and \"text_features\" in text_output:\n", " text_output[\"text_features\"] = F.normalize(text_output[\"text_features\"], dim=-1)\n", " output.update(text_output)\n", "\n", " # FIXME text decoder\n", " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", " if output_logit_scale_bias:\n", " output[\"logit_scale\"] = logit_scale_exp\n", " if self.logit_bias is not None:\n", " output['logit_bias'] = self.logit_bias\n", "\n", " return output\n", "\n", " def forward(\n", " self,\n", " image,\n", " text: Optional[torch.Tensor] = None,\n", " image_latent: Optional[torch.Tensor] = None,\n", " image_embs: Optional[torch.Tensor] = None,\n", " output_labels: bool = True,\n", " ):\n", " if image_latent is None or image_embs is None:\n", " image_latent, image_embs = self._encode_image(image)\n", "\n", " if text is None:\n", " return {\"image_features\": image_latent, \"image_embs\": image_embs}\n", "\n", " text_latent, token_embs = self._encode_text(text)\n", "\n", " # FIXME this isn't an ideal solution, would like to improve -RW\n", " labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None\n", " if output_labels:\n", " # align text_embs and thus logits with labels for teacher-forcing caption loss\n", " token_embs = token_embs[:, :-1]\n", "\n", " logits = self.text_decoder(image_embs, token_embs)\n", " out_dict = {\n", " \"image_features\": image_latent,\n", " \"text_features\": text_latent,\n", " \"logits\": logits,\n", " \"logit_scale\": self.logit_scale.exp()\n", " }\n", " if labels is not None:\n", " out_dict[\"labels\"] = labels\n", " if self.logit_bias is not None:\n", " out_dict[\"logit_bias\"] = self.logit_bias\n", " return out_dict\n", "\n", " def generate(\n", " self,\n", " image,\n", " text=None,\n", " seq_len=30,\n", " max_seq_len=77,\n", " temperature=1.,\n", " generation_type=\"beam_search\",\n", " top_p=0.1, # keep tokens in the 1 - top_p quantile\n", " top_k=1, # keeps the top_k most probable tokens\n", " pad_token_id=None,\n", " eos_token_id=None,\n", " sot_token_id=None,\n", " num_beams=6,\n", " num_beam_groups=3,\n", " min_seq_len=5,\n", " stopping_criteria=None,\n", " repetition_penalty=1.0,\n", " fixed_output_length=False # if True output.shape == (batch_size, seq_len)\n", " ):\n", " # taking many ideas and components from HuggingFace GenerationMixin\n", " # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation\n", " assert _has_transformers, \"Please install transformers for generate functionality. 
`pip install transformers`.\"\n", " assert seq_len > min_seq_len, \"seq_len must be larger than min_seq_len\"\n", " device = image.device\n", "\n", " with torch.no_grad():\n", " sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)\n", " eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)\n", " pad_token_id = self.pad_id if pad_token_id is None else pad_token_id\n", " logit_processor = LogitsProcessorList(\n", " [\n", " MinLengthLogitsProcessor(min_seq_len, eos_token_id),\n", " RepetitionPenaltyLogitsProcessor(repetition_penalty),\n", " ]\n", " )\n", "\n", " if stopping_criteria is None:\n", " stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]\n", " stopping_criteria = StoppingCriteriaList(stopping_criteria)\n", "\n", " if generation_type == \"beam_search\":\n", " output = self._generate_beamsearch(\n", " image_inputs=image,\n", " pad_token_id=pad_token_id,\n", " eos_token_id=eos_token_id,\n", " sot_token_id=sot_token_id,\n", " num_beams=num_beams,\n", " num_beam_groups=num_beam_groups,\n", " min_seq_len=min_seq_len,\n", " stopping_criteria=stopping_criteria,\n", " logit_processor=logit_processor,\n", " )\n", " if fixed_output_length and output.shape[1] < seq_len:\n", " pad_len = seq_len - output.shape[1]\n", " return torch.cat((\n", " output,\n", " torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id\n", " ),\n", " dim=1\n", " )\n", " return output\n", "\n", " elif generation_type == \"top_p\":\n", " logit_warper = GENERATION_TYPES[generation_type](top_p)\n", " elif generation_type == \"top_k\":\n", " logit_warper = GENERATION_TYPES[generation_type](top_k)\n", " else:\n", " raise ValueError(\n", " f\"generation_type has to be one of \"\n", " f\"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}.\"\n", " )\n", "\n", " image_latent, image_embs = self._encode_image(image)\n", "\n", " if text is None:\n", " text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id\n", "\n", " was_training = self.training\n", " num_dims = len(text.shape)\n", "\n", " if num_dims == 1:\n", " text = text[None, :]\n", "\n", " self.eval()\n", " out = text\n", "\n", " while True:\n", " x = out[:, -max_seq_len:]\n", " cur_len = x.shape[1]\n", " logits = self(\n", " image,\n", " x,\n", " image_latent=image_latent,\n", " image_embs=image_embs,\n", " output_labels=False,\n", " )[\"logits\"][:, -1]\n", " mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)\n", " sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id\n", "\n", " if mask.all():\n", " if not fixed_output_length:\n", " break\n", " else:\n", " logits = logits[~mask, :]\n", " filtered_logits = logit_processor(x[~mask, :], logits)\n", " filtered_logits = logit_warper(x[~mask, :], filtered_logits)\n", " probs = F.softmax(filtered_logits / temperature, dim=-1)\n", "\n", " if (cur_len + 1 == seq_len):\n", " sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id\n", " else:\n", " sample[~mask, :] = torch.multinomial(probs, 1)\n", "\n", " out = torch.cat((out, sample), dim=-1)\n", "\n", " cur_len += 1\n", "\n", " if all(stopping_criteria(out, None)):\n", " break\n", "\n", " if num_dims == 1:\n", " out = out.squeeze(0)\n", "\n", " self.train(was_training)\n", " return out\n", "\n", " def _generate_beamsearch(\n", " self,\n", " image_inputs,\n", " pad_token_id=None,\n", " eos_token_id=None,\n", " sot_token_id=None,\n", " 
num_beams=6,\n", " num_beam_groups=3,\n", " min_seq_len=5,\n", " stopping_criteria=None,\n", " logit_processor=None,\n", " logit_warper=None,\n", " ):\n", " device = image_inputs.device\n", " batch_size = image_inputs.shape[0]\n", " image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)\n", " image_latent, image_embs = self._encode_image(image_inputs)\n", "\n", " input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)\n", " input_ids = input_ids * sot_token_id\n", " beam_scorer = BeamSearchScorer(\n", " batch_size=batch_size,\n", " num_beams=num_beams,\n", " device=device,\n", " num_beam_groups=num_beam_groups,\n", " )\n", " # instantiate logits processors\n", " logits_processor = (\n", " LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])\n", " if logit_processor is None\n", " else logit_processor\n", " )\n", "\n", " num_beams = beam_scorer.num_beams\n", " num_beam_groups = beam_scorer.num_beam_groups\n", " num_sub_beams = num_beams // num_beam_groups\n", " batch_size = len(beam_scorer._beam_hyps) // num_beam_groups\n", " batch_beam_size, cur_len = input_ids.shape\n", " beam_indices = None\n", "\n", " if num_beams * batch_size != batch_beam_size:\n", " raise ValueError(\n", " f\"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}.\"\n", " )\n", "\n", " beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)\n", " # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in\n", " # the same group don't produce same tokens everytime.\n", " beam_scores[:, ::num_sub_beams] = 0\n", " beam_scores = beam_scores.view((batch_size * num_beams,))\n", "\n", " while True:\n", "\n", " # predicted tokens in cur_len step\n", " current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)\n", "\n", " # indices which will form the beams in the next time step\n", " reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)\n", "\n", " # do one decoder step on all beams of all sentences in batch\n", " model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)\n", " outputs = self(\n", " model_inputs['images'],\n", " model_inputs['text'],\n", " image_latent=image_latent,\n", " image_embs=image_embs,\n", " output_labels=False,\n", " )\n", "\n", " for beam_group_idx in range(num_beam_groups):\n", " group_start_idx = beam_group_idx * num_sub_beams\n", " group_end_idx = min(group_start_idx + num_sub_beams, num_beams)\n", " group_size = group_end_idx - group_start_idx\n", "\n", " # indices of beams of current group among all sentences in batch\n", " batch_group_indices = []\n", "\n", " for batch_idx in range(batch_size):\n", " batch_group_indices.extend(\n", " [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]\n", " )\n", " group_input_ids = input_ids[batch_group_indices]\n", "\n", " # select outputs of beams of currentg group only\n", " next_token_logits = outputs['logits'][batch_group_indices, -1, :]\n", " vocab_size = next_token_logits.shape[-1]\n", "\n", " next_token_scores_processed = logits_processor(\n", " group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx\n", " )\n", " next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)\n", " next_token_scores = 
next_token_scores.expand_as(next_token_scores_processed)\n", "\n", " # reshape for beam search\n", " next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)\n", "\n", " next_token_scores, next_tokens = torch.topk(\n", " next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True\n", " )\n", "\n", " next_indices = torch.div(next_tokens, vocab_size, rounding_mode=\"floor\")\n", " next_tokens = next_tokens % vocab_size\n", "\n", " # stateless\n", " process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n", " beam_outputs = beam_scorer.process(\n", " group_input_ids,\n", " next_token_scores,\n", " next_tokens,\n", " next_indices,\n", " pad_token_id=pad_token_id,\n", " eos_token_id=eos_token_id,\n", " beam_indices=process_beam_indices,\n", " group_index=beam_group_idx,\n", " )\n", " beam_scores[batch_group_indices] = beam_outputs[\"next_beam_scores\"]\n", " beam_next_tokens = beam_outputs[\"next_beam_tokens\"]\n", " beam_idx = beam_outputs[\"next_beam_indices\"]\n", "\n", " input_ids[batch_group_indices] = group_input_ids[beam_idx]\n", " group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)\n", " current_tokens[batch_group_indices] = group_input_ids[:, -1]\n", "\n", " # (beam_idx // group_size) -> batch_idx\n", " # (beam_idx % group_size) -> offset of idx inside the group\n", " reordering_indices[batch_group_indices] = (\n", " num_beams * torch.div(beam_idx, group_size, rounding_mode=\"floor\") + group_start_idx + (beam_idx % group_size)\n", " )\n", "\n", " input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)\n", "\n", " # increase cur_len\n", " cur_len = cur_len + 1\n", " if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):\n", " break\n", "\n", " final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n", " sequence_outputs = beam_scorer.finalize(\n", " input_ids,\n", " beam_scores,\n", " next_tokens,\n", " next_indices,\n", " pad_token_id=pad_token_id,\n", " eos_token_id=eos_token_id,\n", " max_length=stopping_criteria.max_length,\n", " beam_indices=final_beam_indices,\n", " )\n", " return sequence_outputs['sequences']\n", "\n", "\n", "def create_model(\n", " model_name: str,\n", " pretrained: Optional[str] = None,\n", " precision: str = 'fp32',\n", " device: Union[str, torch.device] = 'cpu',\n", " jit: bool = False,\n", " force_quick_gelu: bool = False,\n", " force_custom_text: bool = False,\n", " force_patch_dropout: Optional[float] = None,\n", " force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n", " force_preprocess_cfg: Optional[Dict[str, Any]] = None,\n", " pretrained_image: bool = False,\n", " pretrained_hf: bool = True,\n", " cache_dir: Optional[str] = None,\n", " output_dict: Optional[bool] = None,\n", " require_pretrained: bool = False,\n", " load_weights_only: bool = True,\n", " **model_kwargs,\n", "):\n", " \"\"\"Creates and configures a contrastive vision-language model.\n", "\n", " Args:\n", " model_name: Name of the model architecture to create. Can be a local model name\n", " or a Hugging Face model ID prefixed with 'hf-hub:'.\n", " pretrained: Tag/path for pretrained model weights. Can be:\n", " - A pretrained tag name (e.g., 'openai')\n", " - A path to local weights\n", " - None to initialize with random weights\n", " precision: Model precision/AMP configuration. 
Options:\n", " - 'fp32': 32-bit floating point\n", " - 'fp16'/'bf16': Mixed precision with FP32 for certain layers\n", " - 'pure_fp16'/'pure_bf16': Pure 16-bit precision\n", " device: Device to load the model on ('cpu', 'cuda', or torch.device object)\n", " jit: If True, JIT compile the model\n", " force_quick_gelu: Force use of QuickGELU activation\n", " force_custom_text: Force use of custom text encoder\n", " force_patch_dropout: Override default patch dropout value\n", " force_image_size: Override default image size for vision encoder\n", " force_preprocess_cfg: Override default preprocessing configuration\n", " pretrained_image: Load pretrained weights for timm vision models\n", " pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights\n", " cache_dir: Override default cache directory for downloaded model files\n", " output_dict: If True and model supports it, return dictionary of features\n", " require_pretrained: Raise error if pretrained weights cannot be loaded\n", " load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)\n", " **model_kwargs: Additional keyword arguments passed to model constructor\n", "\n", " Returns:\n", " Created and configured model instance\n", "\n", " Raises:\n", " RuntimeError: If model config is not found or required pretrained weights\n", " cannot be loaded\n", "\n", " Examples:\n", " # Create basic CLIP model\n", " model = create_model('ViT-B/32')\n", "\n", " # Create CLIP model with mixed precision on GPU\n", " model = create_model('ViT-B/32', precision='fp16', device='cuda')\n", "\n", " # Load pretrained OpenAI weights\n", " model = create_model('ViT-B/32', pretrained='openai')\n", "\n", " # Load Hugging Face model\n", " model = create_model('hf-hub:organization/model-name')\n", " \"\"\"\n", "\n", " force_preprocess_cfg = force_preprocess_cfg or {}\n", " preprocess_cfg = asdict(PreprocessCfg())\n", " has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)\n", " if has_hf_hub_prefix:\n", " model_id = model_name[len(HF_HUB_PREFIX):]\n", " checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n", " config = _get_hf_config(model_id, cache_dir=cache_dir)\n", " preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])\n", " model_cfg = config['model_cfg']\n", " pretrained_hf = False # override, no need to load original HF text weights\n", " else:\n", " model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names\n", " checkpoint_path = None\n", " model_cfg = None\n", "\n", " if isinstance(device, str):\n", " device = torch.device(device)\n", "\n", " model_cfg = model_cfg or get_model_config(model_name)\n", " if model_cfg is not None:\n", " logging.info(f'Loaded {model_name} model config.')\n", " else:\n", " logging.error(f'Model config for {model_name} not found; available models {list_models()}.')\n", " raise RuntimeError(f'Model config for {model_name} not found.')\n", "\n", " if force_quick_gelu:\n", " # override for use of QuickGELU on non-OpenAI transformer models\n", " model_cfg[\"quick_gelu\"] = True\n", "\n", " if force_patch_dropout is not None:\n", " # override the default patch dropout value\n", " model_cfg[\"vision_cfg\"][\"patch_dropout\"] = force_patch_dropout\n", "\n", " if force_image_size is not None:\n", " # override model config's image size\n", " model_cfg[\"vision_cfg\"][\"image_size\"] = force_image_size\n", "\n", " is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})\n", 
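" # NOTE: `is_timm_model` is reused below: it gates whether `pretrained_image` may request\n", " # timm-pretrained vision weights and selects the manual fp16/bf16 casting path.\n",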
" if pretrained_image:\n", " if is_timm_model:\n", " # pretrained weight loading for timm models set via vision_cfg\n", " model_cfg['vision_cfg']['timm_model_pretrained'] = True\n", " else:\n", " assert False, 'pretrained image towers currently only supported for timm models'\n", "\n", " # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes\n", " cast_dtype = get_cast_dtype(precision)\n", " is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})\n", " if is_hf_model:\n", " # load pretrained weights for HF text model IFF no CLIP weights being loaded\n", " model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained\n", " custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model\n", "\n", " model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)\n", " if custom_text:\n", " if \"multimodal_cfg\" in model_cfg:\n", " model = CoCa(**model_cfg, cast_dtype=cast_dtype)\n", " else:\n", " model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)\n", " else:\n", " model = CLIP(**model_cfg, cast_dtype=cast_dtype)\n", "\n", " if precision in (\"fp16\", \"bf16\"):\n", " dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n", " # manual mixed precision that matches original OpenAI behaviour\n", " if is_timm_model:\n", " # FIXME this is a bit janky, create timm based model in low-precision and\n", " # then cast only LayerNormFp32 instances back to float32 so they don't break.\n", " # Why? The convert_weights_to_lp fn only works with native models.\n", " model.to(device=device, dtype=dtype)\n", " # from .transformer import LayerNormFp32\n", "\n", " def _convert_ln(m):\n", " if isinstance(m, LayerNormFp32):\n", " m.weight.data = m.weight.data.to(torch.float32)\n", " m.bias.data = m.bias.data.to(torch.float32)\n", " model.apply(_convert_ln)\n", " else:\n", " model.to(device=device)\n", " convert_weights_to_lp(model, dtype=dtype)\n", " elif precision in (\"pure_fp16\", \"pure_bf16\"):\n", " dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n", " model.to(device=device, dtype=dtype)\n", " else:\n", " model.to(device=device)\n", "\n", " pretrained_loaded = False\n", " if pretrained:\n", " checkpoint_path = ''\n", " pretrained_cfg = get_pretrained_cfg(model_name, pretrained)\n", " if pretrained_cfg:\n", " checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)\n", " preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)\n", " pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)\n", " model_quick_gelu = model_cfg.get('quick_gelu', False)\n", " if pretrained_quick_gelu and not model_quick_gelu:\n", " warnings.warn(\n", " f'These pretrained weights were trained with QuickGELU activation but the model config does '\n", " f'not have that enabled. 
Consider using a model config with a \"-quickgelu\" suffix or enable with a flag.')\n", " elif not pretrained_quick_gelu and model_quick_gelu:\n", " warnings.warn(\n", " f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '\n", " f'model config, consider using a model config without QuickGELU or disable override flags.')\n", " elif os.path.exists(pretrained):\n", " checkpoint_path = pretrained\n", "\n", " if checkpoint_path:\n", " logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\n", " load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)\n", " else:\n", " error_str = (\n", " f'Pretrained weights ({pretrained}) not found for model {model_name}.'\n", " f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')\n", " logging.warning(error_str)\n", " raise RuntimeError(error_str)\n", " pretrained_loaded = True\n", " elif has_hf_hub_prefix:\n", " logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')\n", " load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)\n", " pretrained_loaded = True\n", "\n", " if require_pretrained and not pretrained_loaded:\n", " # callers of create_model_from_pretrained always expect pretrained weights\n", " raise RuntimeError(\n", " f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')\n", "\n", " if output_dict and hasattr(model, \"output_dict\"):\n", " model.output_dict = True\n", "\n", " if jit:\n", " model = torch.jit.script(model)\n", "\n", " # set image preprocessing configuration in model attributes for convenience\n", " if getattr(model.visual, 'image_size', None) is not None:\n", " # use image_size set on model creation (via config or force_image_size arg)\n", " force_preprocess_cfg['size'] = model.visual.image_size\n", " set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))\n", "\n", " return model\n", "\n", "def create_model_and_transforms(\n", " model_name: str,\n", " pretrained: Optional[str] = None,\n", " precision: str = 'fp32',\n", " device: Union[str, torch.device] = 'cpu',\n", " jit: bool = False,\n", " force_quick_gelu: bool = False,\n", " force_custom_text: bool = False,\n", " force_patch_dropout: Optional[float] = None,\n", " force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n", " image_mean: Optional[Tuple[float, ...]] = None,\n", " image_std: Optional[Tuple[float, ...]] = None,\n", " image_interpolation: Optional[str] = None,\n", " image_resize_mode: Optional[str] = None, # only effective for inference\n", " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", " pretrained_image: bool = False,\n", " pretrained_hf: bool = True,\n", " cache_dir: Optional[str] = None,\n", " output_dict: Optional[bool] = None,\n", " load_weights_only: bool = True,\n", " **model_kwargs,\n", "):\n", " force_preprocess_cfg = merge_preprocess_kwargs(\n", " {},\n", " mean=image_mean,\n", " std=image_std,\n", " interpolation=image_interpolation,\n", " resize_mode=image_resize_mode,\n", " )\n", "\n", " model = create_model(\n", " model_name,\n", " pretrained,\n", " precision=precision,\n", " device=device,\n", " jit=jit,\n", " force_quick_gelu=force_quick_gelu,\n", " force_custom_text=force_custom_text,\n", " force_patch_dropout=force_patch_dropout,\n", " force_image_size=force_image_size,\n", " force_preprocess_cfg=force_preprocess_cfg,\n", " pretrained_image=pretrained_image,\n", " 
pretrained_hf=pretrained_hf,\n", " cache_dir=cache_dir,\n", " output_dict=output_dict,\n", " load_weights_only=load_weights_only,\n", " **model_kwargs,\n", " )\n", "\n", " pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)\n", "\n", " preprocess_train = image_transform_v2(\n", " pp_cfg,\n", " is_train=True,\n", " aug_cfg=aug_cfg,\n", " )\n", " preprocess_val = image_transform_v2(\n", " pp_cfg,\n", " is_train=False,\n", " )\n", "\n", " return model, preprocess_train, preprocess_val" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms(\n", " model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Old Dataset Class" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Dataset: 62691\n", "Test Dataset: 11064\n" ] } ], "source": [ "# ==================================================================\n", "# I M A G E - A U D I O - D A T A S E T\n", "# ==================================================================\n", "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n", " def __init__(self, df):\n", " self.image_paths = df.image_path.tolist()\n", " self.audio_paths = df.audio_path.tolist()\n", "\n", " def __len__(self):\n", " return len(self.audio_paths)\n", "\n", " def __getitem__(self, idx):\n", " return {\n", " 'image_path': self.image_paths[idx], \n", " 'audio_path': self.audio_paths[idx]\n", " }\n", "\n", "\n", "# def collate_fn(batch):\n", "# image_tensor = [item['image_path'] for item in batch]\n", "# audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)\n", "# return {'image_tensor': torch.stack(image_tensor), 'audio_tensor': audio_tensor}\n", "\n", "def collate_fn(batch):\n", " image_tensor = [open_clip_imgaug(Image.open(item['image_path'])) for item in batch]\n", " image_paths = [item['image_path'] for item in batch]\n", " \n", " audio_paths = [item['audio_path'] for item in batch]\n", " audio_tensor = CLAPAudioProcessor(audio_paths, resample=True)\n", " \n", " return {\n", " 'image_tensor': torch.stack(image_tensor),\n", " 'image_paths': image_paths,\n", " 'audio_tensor': audio_tensor,\n", " 'audio_paths': audio_paths\n", " }\n", "\n", "\n", "\n", "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv\")\n", "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST3.csv\")\n", "train_dataset = VaaniImageAudioDataset(train_df)\n", "test_dataset = VaaniImageAudioDataset(test_df)\n", "\n", "print('Train Dataset:', len(train_dataset))\n", "print('Test Dataset:', len(test_dataset))\n", "\n", "BATCH_SIZE = int(128)\n", "train_dataloader = torch.utils.data.DataLoader(\n", " train_dataset,\n", " batch_size=BATCH_SIZE, \n", " shuffle=True, \n", " num_workers=48,\n", " collate_fn=collate_fn,\n", " pin_memory=True,\n", " drop_last=False,\n", " persistent_workers=True\n", ")\n", "\n", "test_dataloader = torch.utils.data.DataLoader(\n", " test_dataset,\n", " batch_size=BATCH_SIZE, \n", " shuffle=False, \n", " num_workers=48,\n", " collate_fn=collate_fn,\n", " pin_memory=True,\n", " drop_last=False,\n", " persistent_workers=True\n", ")\n", "\n", "# batch = next(iter(train_dataloader))\n", "# image_tensor_batch = 
batch['image_tensor'].to(device=device)\n", "# audio_tensor_batch = batch['audio_tensor'].to(device=device)\n", "# image_paths_batch = batch['image_paths']\n", "# audio_paths_batch = batch['audio_paths']\n", "# print(\"Image batch shape:\", image_tensor_batch.shape) # [BATCH_SIZE, 3, 224, 224]\n", "# print(\"Audio batch shape:\", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Freeze the CLIP vision tower; it is only used to pre-compute image features\n", "image_encoder = open_clip_model.visual.to(device=device)\n", "for param in image_encoder.parameters():\n", " param.requires_grad = False" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# NOTE: requires `image_tensor_batch` from the commented-out batch inspection above\n", "from torchinfo import summary\n", "summary(model=image_encoder,\n", " input_data=(image_tensor_batch.to(device)),\n", " # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),\n", " dtypes=[torch.long],\n", " col_names = [\"input_size\", \"output_size\", \"num_params\", \"trainable\", \"params_percent\"],\n", " col_width=20,\n", " row_settings=[\"var_names\"],\n", " depth = 2,\n", " # verbose=2,\n", " # device=device\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train Image Features" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda:0\n" ] } ], "source": [ "print(device)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Processing Images]: 100%|\u001b[34m██████████\u001b[0m| 490/490 [24:02<00:00, 2.94s/it]\u001b[0m\n" ] } ], "source": [ "import gc\n", "import torch\n", "def force_gc():\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect() # Optional: cleans up interprocess caches\n", "\n", "\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n", "os.makedirs(savedir, exist_ok=True)\n", "\n", "train_loop = tqdm(train_dataloader, desc=\"[Processing Images]\", colour='blue', ncols=70)\n", "for i, batch in enumerate(train_loop):\n", " # if i == 1: break\n", " \n", " image_tensor_batch = batch['image_tensor'].to(device=device)\n", " image_paths_batch = batch['image_paths']\n", " image_features = image_encoder(image_tensor_batch)\n", "\n", " # save one feature file per image, keyed by the image file name\n", " for j in range(len(image_paths_batch)):\n", " torch.save({\n", " 'image_path': image_paths_batch[j],\n", " 'image_features': image_features[j].detach().cpu(),\n", " }, os.path.join(savedir, f\"{os.path.basename(image_paths_batch[j])}.pt\")\n", " )\n", " \n", " del image_features\n", " force_gc()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Test Image Features" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Processing Images]: 100%|\u001b[34m████████████\u001b[0m| 87/87 [04:22<00:00, 3.02s/it]\u001b[0m\n" ] } ], "source": [ "import gc\n", "import torch\n", "def force_gc():\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect() # Optional: cleans up interprocess caches\n", "\n", "\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n", "os.makedirs(savedir, exist_ok=True)\n", "\n", "test_loop = tqdm(test_dataloader, desc=\"[Processing Images]\", colour='blue', ncols=70)\n", "for i, batch in enumerate(test_loop):\n", " # if i == 1: break\n", " \n", " image_tensor_batch = batch['image_tensor'].to(device=device)\n",
" image_paths_batch = batch['image_paths']\n", " image_features = image_encoder(image_tensor_batch)\n", "\n", " # save one feature file per image, keyed by the image file name\n", " for j in range(len(image_paths_batch)):\n", " torch.save({\n", " 'image_path': image_paths_batch[j],\n", " 'image_features': image_features[j].detach().cpu(),\n", " }, os.path.join(savedir, f\"{os.path.basename(image_paths_batch[j])}.pt\")\n", " )\n", " \n", " del image_features\n", " force_gc()" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "73755" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n", "\n", "len(os.listdir(savedir))" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "584M\t/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/\n" ] } ], "source": [ "!du -sh /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train Audio Tensors" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Processing Audios]: 100%|\u001b[34m██████████\u001b[0m| 490/490 [04:26<00:00, 1.84it/s]\u001b[0m\n" ] } ], "source": [ "import gc\n", "import torch\n", "def force_gc():\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect() # Optional: cleans up interprocess caches\n", "\n", "\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n", "os.makedirs(savedir, exist_ok=True)\n", "\n", "train_loop = tqdm(train_dataloader, desc=\"[Processing Audios]\", colour='blue', ncols=70)\n", "for i, batch in enumerate(train_loop):\n", " # if i == 1: break\n", " \n", " audio_tensor_batch = batch['audio_tensor'].to(device=device)\n", " audio_paths_batch = batch['audio_paths']\n", "\n", " # save one processed-audio tensor per clip, keyed by the audio file name\n", " for j in range(len(audio_paths_batch)):\n", " torch.save({\n", " 'audio_paths': audio_paths_batch[j],\n", " 'audio_tensor': audio_tensor_batch[j].detach().cpu(),\n", " }, os.path.join(savedir, f\"{os.path.basename(audio_paths_batch[j])}.pt\")\n", " )\n", " \n", " del audio_tensor_batch\n", " force_gc()" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "69G\t/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/\n" ] } ], "source": [ "!du -sh /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Test Audio Tensors" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Processing Audios]: 100%|\u001b[34m████████████\u001b[0m| 87/87 [00:55<00:00, 1.58it/s]\u001b[0m\n" ] } ], "source": [ "import gc\n", "import torch\n", "def force_gc():\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect() # Optional: cleans up interprocess caches\n", "\n", "\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n", "os.makedirs(savedir, exist_ok=True)\n", "\n", "test_loop = tqdm(test_dataloader, desc=\"[Processing Audios]\", colour='blue', ncols=70)\n", "for i, batch in enumerate(test_loop):\n", " # if i == 1: break\n", " \n", " audio_tensor_batch = batch['audio_tensor'].to(device=device)\n",
" audio_paths_batch = batch['audio_paths']\n", "\n", " # save one processed-audio tensor per clip, keyed by the audio file name\n", " for j in range(len(audio_paths_batch)):\n", " torch.save({\n", " 'audio_paths': audio_paths_batch[j],\n", " 'audio_tensor': audio_tensor_batch[j].detach().cpu(),\n", " }, os.path.join(savedir, f\"{os.path.basename(audio_paths_batch[j])}.pt\")\n", " )\n", " \n", " del audio_tensor_batch\n", " force_gc()" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "83G\t/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/\n" ] } ], "source": [ "!du -sh /scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "73755" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n", "\n", "len(os.listdir(savedir))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## New Dataset Class" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Dataset: 26810\n", "Test Dataset: 11490\n" ] }, { "data": { "text/plain": [ "{'image_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images/Folder3/IISc_VaaniProject_Lucknow-SPECIFIC_00826.jpg',\n", " 'image_feature': tensor([-0.1034, 0.4547, -0.3613, ..., -0.4897, -0.0025, 0.6462]),\n", " 'audio_path': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/Hindi/UttarPradesh_Lucknow/IISc_VaaniProject_K_UttarPradesh_Lucknow_Lucknow844425030382473_010_Lucknow-SPECIFIC_00826_4706_6462.wav',\n", " 'audio_tensor': tensor([-0.0131, -0.0133, -0.0105, ..., -0.0070, -0.0086, -0.0096])}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ==================================================================\n", "# I M A G E - A U D I O - D A T A S E T\n", "# ==================================================================\n", "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n", " def __init__(self, df, image_features_savedir, audio_tensors_savedir):\n", " self.image_paths = df.image_path.tolist()\n", " self.audio_paths = df.audio_path.tolist()\n", " self.image_features_savedir = image_features_savedir\n", " self.audio_tensors_savedir = audio_tensors_savedir\n", "\n", " def __len__(self):\n", " return len(self.audio_paths)\n", "\n", " def __getitem__(self, idx):\n", " return {\n", " 'image_path': self.image_paths[idx],\n", " 'image_feature': torch.load(os.path.join(\n", " self.image_features_savedir, \n", " f\"{os.path.basename(self.image_paths[idx])}.pt\"))['image_features'],\n", " 'audio_path': self.audio_paths[idx],\n", " 'audio_tensor': torch.load(os.path.join(\n", " self.audio_tensors_savedir, \n", " f\"{os.path.basename(self.audio_paths[idx])}.pt\"))['audio_tensor']\n", " }\n", " \n", "\n", "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN2.csv\")\n", "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv\")\n", "image_features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n", "audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n", "train_dataset = VaaniImageAudioDataset(train_df, image_features_savedir, audio_tensors_savedir)\n",
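"# NOTE: each __getitem__ returns a pre-computed CLIP image feature and a cached raw audio tensor\n", "# loaded from the .pt files written above, so no image encoding happens in the dataloader workers\n",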
"test_dataset = VaaniImageAudioDataset(test_df, image_features_savedir, audio_tensors_savedir)\n", "\n", "print('Train Dataset:', len(train_dataset))\n", "print('Test Dataset:', len(test_dataset))\n", "train_dataset[0]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch shape: torch.Size([64, 1024])\n", "Audio batch shape: torch.Size([64, 308700])\n" ] } ], "source": [ "BATCH_SIZE = int(64)\n", "train_dataloader = torch.utils.data.DataLoader(\n", " train_dataset,\n", " batch_size=BATCH_SIZE, \n", " shuffle=True, \n", " num_workers=48,\n", " pin_memory=True,\n", " drop_last=False,\n", " persistent_workers=True\n", ")\n", "\n", "test_dataloader = torch.utils.data.DataLoader(\n", " test_dataset,\n", " batch_size=BATCH_SIZE, \n", " shuffle=False, \n", " num_workers=48,\n", " pin_memory=True,\n", " drop_last=False,\n", " persistent_workers=True\n", ")\n", "\n", "batch = next(iter(train_dataloader))\n", "image_features_batch = batch['image_feature'].to(device=device)\n", "audio_tensor_batch = batch['audio_tensor'].to(device=device)\n", "image_paths_batch = batch['image_path']\n", "audio_paths_batch = batch['audio_path']\n", "print(\"Image batch shape:\", image_features_batch.shape) # [BATCH_SIZE, 3, 224, 224]\n", "print(\"Audio batch shape:\", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100]\n" ] } ], "metadata": { "kernelspec": { "display_name": "clap", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }