| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from __future__ import annotations |
| | import warnings |
# Suppress ALL warnings process-wide (every category, from every module), including
# deprecation warnings raised by the libraries imported after this line.
warnings.filterwarnings("ignore")
| |
|
| |
|
| | import os |
| | import io |
| | import sys |
| | import math |
| | import random |
| | import collections |
| | import collections.abc |
| | import re |
| | from itertools import repeat |
| | from pathlib import Path |
| | from typing import Optional, Tuple, Union, List, Dict |
| |
|
| | import csv |
| | import numpy as np |
| | import pandas as pd |
| | from PIL import Image |
| | import seaborn as sns |
| | import matplotlib.pyplot as plt |
| | from tqdm import trange, tqdm |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.nn.init import _calculate_fan_in_and_fan_out |
| | import torch.utils.checkpoint as checkpoint |
| |
|
| | import torchvision as tv |
| | from torchvision.transforms import v2 |
| | from torch.utils.tensorboard import SummaryWriter |
| | |
| |
|
# Default to physical GPU 1, but respect a device selection already made through the
# environment (e.g. by a job scheduler or the shell) instead of silently overriding it.
# This must happen before CUDA is first initialized; torch initializes CUDA lazily, so
# having imported torch above is fine as long as nothing has touched the GPU yet.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
# Fall back to CPU when no visible CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
| |
|
| | import torchaudio |
| | import torchaudio.transforms as T |
| | from torchlibrosa.stft import Spectrogram, LogmelFilterBank |
| | from torchlibrosa.augmentation import SpecAugmentation |
| |
|
| | from transformers import AutoModel, AutoTokenizer, logging |
| | from huggingface_hub.file_download import hf_hub_download |
| | from huggingface_hub.file_download import hf_hub_download |
| | from peft import get_peft_config, get_peft_model |
| | from transformers import CLIPVisionModel, AutoProcessor |
| |
|
| | from watermark import watermark |
# Print an environment report using the third-party `watermark` package. Each keyword
# below toggles one section of the report (see the watermark documentation for exact
# semantics). `globals_=globals()` hands it this namespace, presumably so that
# `iversions=True` can list versions of the modules imported here — verify.
print(watermark(
    author='Ashish',
    current_date=True,
    datename=True,
    current_time=True,
    iso8601=True,
    timezone=True,
    updated=True,
    custom_time=None,
    python=True,
    conda=True,
    hostname=True,
    machine=True,
    watermark=False,
    iversions=True,
    gpu=True,
    globals_=globals()
))
| |
|
| | from typing import Any, Dict, Optional, Tuple, Union |
| | import numbers |
| | import random |
| | import warnings |
| | from dataclasses import dataclass, asdict |
| | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
| |
|
| | import torch |
| | import torchvision.transforms.functional as F |
| | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ |
| | CenterCrop, ColorJitter, Grayscale |
| |
|
# Per-channel image normalization statistics as (mean, std) triples, one value per
# channel (presumably RGB order — confirm against the preprocessing pipeline).
# OPENAI_*: presumably the statistics used by OpenAI CLIP preprocessing.
# IMAGENET_*: the usual ImageNet statistics.
# INCEPTION_*: (x - 0.5) / 0.5 maps inputs in [0, 1] to [-1, 1].
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# File names for open_clip weights / config, presumably as stored in Hugging Face Hub
# repositories (the HF_ prefix suggests so — verify against the loading code).
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
HF_CONFIG_NAME = 'open_clip_config.json'
| |
|
| |
|
| | import collections.abc |
| | from itertools import repeat |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn as nn |
| | from torch import _assert |
| | from torchvision.ops.misc import FrozenBatchNorm2d |
| |
|
| |
|
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (Container[str] | None): Full (dotted) module names to freeze; any container
            supporting `in` works (dict, set, ...). None or empty means freeze all BN layers.
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    # No match set (None or empty) means every BatchNorm layer qualifies.
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Running statistics share storage with the source layer (they are not cloned).
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            # Only re-register children that were actually replaced.
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
| |
|
| |
|
| | |
def _ntuple(n):
    """Return a converter that expands a single value into an ``n``-tuple.

    Non-string iterables are returned unchanged (their length is not checked).
    Anything else, including ``str`` (iterable, but meant as one value here),
    is repeated ``n`` times.
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Expand ``x`` into an ``n``-tuple using the same rules as ``_ntuple``."""
    return _ntuple(n)(x)
| |
|
| | |
| | |
def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    """Recursively swap selected ``nn.Linear`` children of ``model`` for another layer type.

    Args:
        model: module modified in place.
        linear_replacement: callable ``(in_features, out_features, has_bias) -> nn.Module``
            used to build each replacement layer.
        include_modules: attribute names (unqualified) of the ``nn.Linear`` children to
            replace, at any depth. Any container supporting ``in`` is accepted.
        copy_weights: copy weight (and bias, when the replacement has one) from the old layer.

    Returns:
        The same ``model`` object.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            new_module = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                new_module.weight.data.copy_(module.weight.data)
                if new_module.bias is not None:
                    new_module.bias.data.copy_(module.bias)
            # Re-assigning an existing key does not resize the dict, so this is safe mid-iteration.
            model._modules[name] = new_module

    return model
| |
|
def convert_int8_model_to_inference_mode(model):
    """Call ``prepare_for_eval()`` on every submodule (including ``model``) that defines it.

    Each converted submodule gets an ``int8_original_dtype`` attribute holding the dtype
    its ``weight`` had before the call. Modules are modified in place; returns None.
    """
    convertible = (mod for mod in model.modules() if hasattr(mod, 'prepare_for_eval'))
    for mod in convertible:
        # Capture the dtype first: prepare_for_eval may replace or re-cast the weight.
        dtype_before = mod.weight.dtype
        mod.prepare_for_eval()
        mod.int8_original_dtype = dtype_before
| |
|
| |
|
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
            None -> select all
            int -> select last n
            list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set (ignored when running under TorchScript)

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index

    Raises:
        Failure from `_assert` (presumably AssertionError) if an int `indices` is outside
        1..num_features, or a listed index falls outside -num_features..num_features-1.
        ValueError if `indices` is an empty sequence (`max()` of an empty list).
    """
    if indices is None:
        # None selects everything, expressed as "the last num_features".
        indices = num_features

    if isinstance(indices, int):
        # Last-n selection: n must be between 1 and num_features inclusive.
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        # Explicit selection; negative entries count back from the end.
        take_indices: List[int] = []
        for i in indices:
            idx = num_features + i if i < 0 else i
            _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
            take_indices.append(idx)

    # as_set is only honored in eager mode, presumably because a set return does not
    # script cleanly with the declared return type.
    if not torch.jit.is_scripting() and as_set:
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
| |
|
| |
|
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
    """Normalize an output-indices spec into a tuple.

    An int ``n`` means "the last n outputs" and becomes ``(-n, ..., -1)``; any other
    iterable is converted to a tuple as-is.
    """
    if not isinstance(x, int):
        return tuple(x)
    return tuple(-k for k in range(x, 0, -1))
| |
|
| |
|
| |
|
| | import copy |
| | import copy |
| | import hashlib |
| | import os |
| | import urllib |
| | import warnings |
| | from functools import partial |
| | from typing import Dict, Iterable, Optional, Union |
| |
|
| | from tqdm import tqdm |
| |
|
| |
|
# safetensors is an optional dependency; `_has_safetensors` records whether it imported.
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

# Version string of the code this was adapted from (presumably open_clip — verify).
__version__ = '2.32.0'
| |
|
| |
|
| | """ CLIP Model |
| | |
| | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. |
| | """ |
| | import copy |
| | import logging |
| | import math |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.utils.checkpoint import checkpoint |
| | from functools import partial |
| |
|
| | |
| | |
| | from collections import OrderedDict |
| | import math |
| | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from torch.utils.checkpoint import checkpoint |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | import numpy as np |
| |
|
| | import torch |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build fixed 2-D sine/cosine positional embeddings for a square grid.

    Args:
        embed_dim: embedding width; must be divisible by 4 (half per axis, and each half
            splits into sin and cos parts).
        grid_size: number of positions along each side of the square grid.
        cls_token: if True, prepend an all-zero row for a class token.

    Returns:
        float64 array of shape [grid_size*grid_size, embed_dim], or
        [1 + grid_size*grid_size, embed_dim] when ``cls_token`` is set.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): plane 0 holds the column (w) coordinate, plane 1 the row (h) one.
    coords = np.stack(np.meshgrid(axis, axis), axis=0).reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Embed each grid position by concatenating 1-D embeddings of its two coordinate planes."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # Plane 0 fills the first half of the channels, plane 1 the second half.
    halves = [get_1d_sincos_pos_embed_from_grid(half, grid[plane]) for plane in (0, 1)]
    return np.concatenate(halves, axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: positions to encode; flattened to shape (M,)
    out: (M, D) array laid out as [sin(pos * w_k) for all k, cos(pos * w_k) for all k],
         with frequencies w_k = 1 / 10000 ** (k / (D / 2))
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place so it matches ``model``'s patch grid.

    Leading "extra" tokens (their count is taken from ``model.pos_embed``) are kept as-is.
    The remaining patch tokens are treated as a square grid and bicubically resized to
    ``sqrt(model.patch_embed.num_patches)`` per side. Nothing happens when the checkpoint has
    no ``pos_embed`` or the grid sizes already agree. Returns None.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_pos = checkpoint_model['pos_embed']
    dim = ckpt_pos.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches
    src_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    dst_side = int(n_patches ** 0.5)
    if src_side == dst_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (src_side, src_side, dst_side, dst_side))
    extra = ckpt_pos[:, :n_extra]
    # (B, N, C) -> (B, C, H, W) for spatial interpolation, then back to (B, N', C).
    grid = ckpt_pos[:, n_extra:].reshape(-1, src_side, src_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(dst_side, dst_side), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra, grid), dim=1)
| |
|
| |
|
| |
|
| | from collections import OrderedDict |
| | from typing import Dict, List, Optional, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | |
| |
|
| |
|
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand) with anti-aliased striding.

    Spatial downsampling is done by an ``AvgPool2d`` after the 3x3 conv, not by a strided
    convolution. When the stride or channel count changes, the shortcut branch likewise
    average-pools before a stride-1 1x1 projection and BatchNorm.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        wide = planes * self.expansion

        # 1x1 channel reduction (all convs have stride 1)
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        # 3x3 spatial conv
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        # the only spatial downsampling on the main branch
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        # 1x1 channel expansion
        self.conv3 = nn.Conv2d(planes, wide, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # The "-1"/"0"/"1" keys fix the state_dict names (downsample.0.weight, ...),
            # presumably to stay compatible with existing checkpoints — verify before renaming.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, wide, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(wide)),
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.act1(self.bn1(self.conv1(x)))
        y = self.act2(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(self.avgpool(y)))

        y += shortcut
        return self.act3(y)
| |
|
| |
|
class AttentionPool2d(nn.Module):
    """Pool a feature map with one multi-head attention query (CLIP ResNet head).

    The spatial positions are flattened into tokens, and a mean token is prepended as the
    query. Learned positional embeddings are added, then the mean token attends over all
    tokens. Its output, projected to ``output_dim`` (default ``embed_dim``), is returned.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One embedding per spatial position plus one for the pooled (mean) token.
        n_positions = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(torch.randn(n_positions, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        tokens = x.reshape(n, c, h * w).permute(2, 0, 1)  # (HW, N, C): sequence-first layout
        tokens = torch.cat([tokens.mean(dim=0, keepdim=True), tokens], dim=0)  # (HW+1, N, C)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # Keep only the output at the prepended mean-token position.
        return attended[0]
| |
|
| |
|
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Args:
        layers: number of Bottleneck blocks in each of the four stages (layers[0..3]).
        output_dim: output width of the attention-pooling head.
        heads: number of attention heads in the pooling head.
        image_size: input resolution. The total stride is 32 (stride-2 conv1, 2x avgpool and three
            stride-2 stages), so the pooling head is sized for an (image_size // 32) square map.
        width: channel count after the stem; stage i (0-based) outputs width * 2**i * 4 channels.
    """

    def __init__(
            self,
            layers: List[int],
            output_dim: int,
            heads: int,
            image_size: int = 224,
            width: int = 64,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual stages; _make_layer advances _inplanes as stages are built
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # channels out of layer4
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage; only its first block may stride/project."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        """Scaled-normal init for the attention head; zero-init each block's last BN scale."""
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    # residual branch starts as zero, so each block begins as (ReLU of) identity
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters (partial unlocking unsupported); optionally freeze BN stats."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Gradient checkpointing is not implemented for this model; intentionally a no-op.
        pass

    def stem(self, x):
        """Three 3x3 conv/BN/ReLU layers followed by a 2x average pool."""
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            output_fmt: str = 'NCHW',
            output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Feature index 0 is the stem output; indices 1-4 are the outputs of layer1-layer4.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit. Only
                honored together with intermediates_only (the pooled output needs layer4)
                and only in eager mode.
            normalize_intermediates: Accepted for interface compatibility; unused here.
            intermediates_only: Only return intermediate features
            output_fmt: Shape of intermediate feature outputs; only 'NCHW' is supported.
            output_extra_tokens: Accepted for interface compatibility; unused here.
        Returns:
            Dict with 'image_intermediates' (list of selected feature maps) and, unless
            intermediates_only, 'image_features' (attention-pooled layer4 output).
        """
        assert output_fmt in ('NCHW',), 'Output format must be == NCHW.'
        take_indices, max_index = feature_take_indices(5, indices)

        output = {}
        intermediates = []
        blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]
        # Truncate only when early stopping was requested and the pooled output is not needed:
        # attnpool expects layer4 features. Truncation is kept eager-only, presumably because
        # slicing this module list does not script.
        if not torch.jit.is_scripting() and stop_early and intermediates_only:
            blocks = blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = blk(x)
            if i in take_indices:
                intermediates.append(x)

        output['image_intermediates'] = intermediates

        if intermediates_only:
            return output

        x = self.attnpool(x)
        output['image_features'] = x

        return output

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
| |
|
| |
|
| | """ huggingface model adapter |
| | |
| | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. |
| | """ |
| | import re |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import TensorType |
| |
|
try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError:
    transformers = None

    # Placeholder types so annotations and the isinstance() checks in the poolers still
    # resolve without transformers. They are defined only on this fallback path; defining
    # them unconditionally would shadow the real transformers classes imported above.
    class BaseModelOutput:
        pass

    class BaseModelOutputWithPooling(BaseModelOutput):
        pass

    class BaseModelOutputWithPoolingAndCrossAttentions(BaseModelOutput):
        pass

    class PretrainedConfig:
        pass
| |
|
| | |
| | |
# Per-architecture settings for HuggingFace text towers, keyed by the HF config's `model_type`.
#   "config_names": generic name -> attribute name. "width" is read from the HF config in
#       HFTextEncoder.__init__. "layer_attr" / "token_embeddings_attr" name attributes on
#       the HF model used by HFTextEncoder.lock(). The other keys are not read in the code
#       visible here.
#   "pooler": pooler registry key used when HFTextEncoder receives pooler_type=None.
# NOTE(review): "bert" and "m2m_100" define no "layer_attr" / "token_embeddings_attr",
# so HFTextEncoder.lock(unlocked_layers > 0) would raise KeyError for those models.
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # Empty: no config attribute names a max sequence length for mt5 (presumably
            # because its relative position scheme has no fixed limit — verify).
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
        },
        "pooler": "cls_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/m2m_100
    "m2m_100": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "encoder_attention_heads",
            "layers": "encoder_layers",
        },
        "pooler": "cls_pooler",
    },
}
| |
|
| |
|
| |
|
| | |
def _camel2snake(s):
    """Convert CamelCase to snake_case, e.g. ``MeanPooler`` -> ``mean_pooler``."""
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# Pooler registry: snake_case class name -> pooler class.
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class in ``_POOLERS`` under its snake_case name."""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling over non-padding tokens (attention_mask == 1)."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        # clamp(min=1) avoids 0/0 = NaN for rows that are entirely padding
        token_count = attention_mask.sum(-1, keepdim=True).clamp(min=1)
        return masked_output.sum(dim=1) / token_count


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over non-padding tokens (attention_mask == 1)."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # Padding positions (mask == 0) become -inf so they never win the max. The
        # comparison also turns integer masks into the bool mask masked_fill requires.
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling.

    Uses the model's ``pooler_output`` when requested and available, otherwise the hidden
    state at position 0.
    """

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
            (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        return x.last_hidden_state[:, self.cls_token_position, :]
| |
|
| |
|
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Wraps a HuggingFace ``AutoModel`` (only its encoder for encoder-decoder configs) as a
    text tower: token ids in, pooled and projected embedding out (optionally with the
    per-token hidden states).
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj_type: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        """
        Args:
            model_name_or_path: HF model id or local path; used when ``config`` is None.
            output_dim: width of the projected output embedding.
            config: optional HF config. When given, the model is built from it via
                ``AutoModel.from_config`` and ``pretrained`` is not consulted.
            pooler_type: pooler registry key; defaults to the ``arch_dict`` entry for the model type.
            proj_type: None (identity; requires hidden width == output_dim), 'linear' or 'mlp'.
            pretrained: load pretrained weights (only when ``config`` is None).
            output_tokens: make ``forward`` also return the per-token hidden states.

        Raises:
            RuntimeError: if transformers is not installed.
            ValueError: if ``proj_type`` is unsupported, or None while hidden width != output_dim.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # Only ask HF to build its pooling layer when the CLS pooler is explicitly requested.
        # NOTE(review): this is decided before the arch_dict default is applied below, so a
        # defaulted "cls_pooler" gets no HF pooling layer — presumably intentional; verify.
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Encoder-decoder model: keep only the encoder stack.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj_type is None):
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        else:
            # Without this, self.proj would stay unset and forward() would later fail
            # with an AttributeError far from the actual misconfiguration.
            raise ValueError(
                f"Unsupported proj_type {proj_type!r} for hidden width {d_model} and "
                f"output_dim {output_dim}; use 'linear' or 'mlp' (or None when they match)."
            )

    def forward(self, x: TensorType):
        """Encode token ids ``x``; padding is detected via ``config.pad_token_id``."""
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        if not self.output_tokens:
            return projected

        hidden = out.last_hidden_state
        seq_len = hidden.shape[1]
        # For CLS pooling, drop the CLS position from the returned tokens.
        tokens = (
            hidden[:, torch.arange(seq_len, device=hidden.device) != self.pooler.cls_token_position, :]
            if isinstance(self.pooler, ClsPooler)
            else hidden
        )
        return projected, tokens

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the transformer, leaving the last ``unlocked_layers`` layers trainable.

        LayerNorm parameters stay trainable when ``freeze_layer_norm`` is False.
        """
        if not unlocked_layers:
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # Embeddings count as the first "layer"; everything except the last N gets frozen.
        modules = [embeddings, *layer_list][:-unlocked_layers]

        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable (or, with ``enable=False``, disable) HF gradient checkpointing."""
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def init_parameters(self):
        # Weights come from the HF model (pretrained or HF-initialized); nothing to do here.
        pass
| |
|
| |
|
| | """ timm model adapter |
| | |
| | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. |
| | """ |
| | import logging |
| | from collections import OrderedDict |
| | from typing import Dict, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | try: |
| | import timm |
| | from timm.layers import RotAttentionPool2d |
| | from timm.layers import AttentionPool2d as AbsAttentionPool2d |
| | from timm.layers import Mlp, to_2tuple |
| | except ImportError: |
| | timm = None |
| |
|
| |
|
| |
|
| | class TimmModel(nn.Module): |
| | """ timm model adapter |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | model_name: str, |
| | embed_dim: int, |
| | image_size: Union[int, Tuple[int, int]] = 224, |
| | pool: str = 'avg', |
| | proj: str = 'linear', |
| | proj_bias: bool = False, |
| | drop: float = 0., |
| | drop_path: Optional[float] = None, |
| | patch_drop: Optional[float] = None, |
| | pretrained: bool = False, |
| | ): |
| | super().__init__() |
| | if timm is None: |
| | raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.") |
| | self.image_size = to_2tuple(image_size) |
| |
|
| | |
| | timm_kwargs = {} |
| | if drop_path is not None: |
| | timm_kwargs['drop_path_rate'] = drop_path |
| | if patch_drop is not None: |
| | timm_kwargs['patch_drop_rate'] = patch_drop |
| |
|
| | custom_pool = pool in ('abs_attn', 'rot_attn') |
| | if proj: |
| | assert proj in ("linear", "mlp", "none") |
| | extra_proj = proj in ("linear", "mlp") |
| | if not extra_proj and not custom_pool: |
| | |
| | |
| | proj_dim = 0 if proj == 'none' else embed_dim |
| | self.trunk = timm.create_model( |
| | model_name, |
| | num_classes=proj_dim, |
| | global_pool=pool, |
| | pretrained=pretrained, |
| | **timm_kwargs, |
| | ) |
| | prev_chs = embed_dim |
| | else: |
| | self.trunk = timm.create_model( |
| | model_name, |
| | pretrained=pretrained, |
| | **timm_kwargs, |
| | ) |
| | feat_size = self.trunk.default_cfg.get('pool_size', None) |
| | feature_ndim = 1 if not feat_size else 2 |
| | if custom_pool: |
| | assert feature_ndim == 2 |
| | |
| | self.trunk.reset_classifier(0, global_pool='') |
| | else: |
| | |
| | reset_kwargs = dict(global_pool=pool) if pool else {} |
| | self.trunk.reset_classifier(0, **reset_kwargs) |
| | prev_chs = self.trunk.num_features |
| |
|
| | head_layers = OrderedDict() |
| |
|
| | |
| | if pool == 'abs_attn': |
| | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) |
| | prev_chs = embed_dim |
| | elif pool == 'rot_attn': |
| | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) |
| | prev_chs = embed_dim |
| |
|
| | |
| | if proj == 'linear': |
| | head_layers['drop'] = nn.Dropout(drop) |
| | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) |
| | elif proj == 'mlp': |
| | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) |
| |
|
| | self.head = nn.Sequential(head_layers) |
| |
|
def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
    """Freeze the timm trunk, optionally keeping its last layer groups trainable.

    Args:
        unlocked_groups (int): number of trailing layer groups (as reported by the
            trunk's ``group_matcher``) that stay trainable; 0 freezes the whole trunk.
        freeze_bn_stats (bool): also freeze BatchNorm statistics of the frozen part
            via ``freeze_batch_norm_2d``.
    """
    if unlocked_groups:
        # Partial locking needs timm's parameter/module grouping helpers.
        try:
            from timm.models.helpers import group_parameters, group_modules
        except ImportError:
            raise RuntimeError(
                'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
        matcher = self.trunk.group_matcher()
        param_groups = group_parameters(self.trunk, matcher)
        # Groups 0..last_frozen are frozen; the remaining `unlocked_groups` stay trainable.
        last_frozen = max(param_groups.keys()) - unlocked_groups
        for group_id in range(last_frozen + 1):
            for param_name in param_groups[group_id]:
                self.trunk.get_parameter(param_name).requires_grad = False
        if freeze_bn_stats:
            module_groups = group_modules(self.trunk, matcher, reverse=True)
            frozen_modules = {name for name, gid in module_groups.items() if gid <= last_frozen}
            freeze_batch_norm_2d(self.trunk, frozen_modules)
        return

    # Full lock: every trunk parameter is frozen.
    for p in self.trunk.parameters():
        p.requires_grad = False
    if freeze_bn_stats:
        freeze_batch_norm_2d(self.trunk)
| |
|
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
    """Enable or disable gradient checkpointing on the timm trunk, if it supports it.

    Best-effort: if the trunk's ``set_grad_checkpointing`` is missing or raises,
    a warning is logged and the model continues without checkpointing.

    Args:
        enable: whether checkpointing should be turned on.
    """
    try:
        self.trunk.set_grad_checkpointing(enable)
    except Exception:
        # Use the stdlib logger explicitly. The file-top import binds the name
        # ``logging`` to ``transformers.logging``, which has no module-level
        # ``warning`` function, so ``logging.warning`` would raise here.
        import logging as std_logging
        std_logging.getLogger(__name__).warning(
            'grad checkpointing not supported for this timm image tower, continuing without...')
| |
|
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        output_fmt: str = 'NCHW',
        output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """Run the timm trunk and return its intermediate features (and optionally the head output).

    Args:
        x: Input image tensor.
        indices: Take last n blocks if int, all if None, select matching indices if sequence.
        stop_early: Stop iterating over blocks when the last desired intermediate is reached.
        normalize_intermediates: Forwarded to the trunk as ``norm``.
        intermediates_only: Only return intermediate features.
        output_fmt: Shape of intermediate feature outputs (forwarded to the trunk).
        output_extra_tokens: Ask the trunk for prefix tokens too (``return_prefix_tokens``).

    Returns:
        Dict with ``'image_intermediates'``. It also has ``'image_intermediates_prefix'``
        when the trunk returned ``(tokens, prefix)`` pairs, and ``'image_features'``
        (trunk head followed by ``self.head``) unless ``intermediates_only``.
    """
    prefix_kwargs = {'return_prefix_tokens': True} if output_extra_tokens else {}
    trunk_out = self.trunk.forward_intermediates(
        x,
        indices=indices,
        intermediates_only=intermediates_only,
        norm=normalize_intermediates,
        stop_early=stop_early,
        output_fmt=output_fmt,
        **prefix_kwargs,
    )

    # With intermediates_only the trunk returns just the list; otherwise (final, list).
    if intermediates_only:
        final_tokens, intermediates = None, trunk_out
    else:
        final_tokens, intermediates = trunk_out[0], trunk_out[1]

    result = {}
    if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):
        # Each entry is (spatial_tokens, prefix_tokens); split them apart.
        result['image_intermediates_prefix'] = [pair[1] for pair in intermediates]
        intermediates = [pair[0] for pair in intermediates]

    result['image_intermediates'] = intermediates
    if intermediates_only:
        return result

    result['image_features'] = self.head(self.trunk.forward_head(final_tokens))
    return result
| |
|
def forward(self, x):
    """Apply the timm trunk, then the pooling/projection head built in ``__init__``."""
    features = self.trunk(x)
    return self.head(features)
| |
|
| |
|
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 and casts the result back.

    The input is upcast to float32 before ``F.layer_norm`` and the output is
    returned in the input's original dtype (e.g. fp16 stays fp16).
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        normed = F.layer_norm(x_fp32, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
| |
|
| |
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in the input's dtype, with the output cast back to that dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
| |
|
| |
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
| |
|
| |
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling ``x * gamma``.

    Args:
        dim: length of ``gamma``; it broadcasts against the last dimension of ``x``.
        init_values: initial value of every ``gamma`` entry.
        inplace: multiply ``x`` in place (``mul_``) instead of allocating a new tensor.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
| |
|
| |
|
class PatchDropout(nn.Module):
    """Randomly keep a subset of tokens per sample during training.

    https://arxiv.org/abs/2212.00794

    Args:
        prob: fraction of tokens to drop, in ``[0, 1)``. At least one token is
            always kept.
        exclude_first_token: never drop the first token; it is split off before
            sampling and re-attached in front of the kept tokens.
    """

    def __init__(
            self,
            prob: float = 0.5,
            exclude_first_token: bool = True
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token

    def forward(self, x):
        # Identity at eval time or when nothing should be dropped.
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        # Create the index/score tensors on x's device. Building them on the CPU
        # (the torch default) would use the CPU RNG and force a host->device copy
        # of the indices on every training step when x lives on a GPU.
        batch_indices = torch.arange(batch, device=x.device)[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Top-k over i.i.d. random scores picks a uniformly random subset per sample.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with optional cosine logits and per-head output scaling.

    Args:
        dim: embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: learn a bias for the fused q/k/v input projection.
        scaled_cosine: use cosine-similarity logits multiplied by a learnable per-head
            temperature, whose log is clamped at ``logit_scale_max`` before ``exp``.
        scale_heads: multiply each head's output by a learnable scalar.
        logit_scale_max: upper clamp for the log temperature.
        batch_first: inputs/outputs are ``(batch, seq, dim)``; otherwise ``(seq, batch, dim)``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            scaled_cosine: bool = False,
            scale_heads: bool = False,
            logit_scale_max: float = math.log(1. / 0.01),
            batch_first: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max
        self.batch_first = batch_first
        # Use the fused kernel when this torch version provides it.
        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')

        # Fused q/k/v projection, laid out like nn.MultiheadAttention's in_proj_weight.
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        if self.batch_first:
            x = x.transpose(0, 1)  # -> (L, N, C)

        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # (L, N, C) -> (N * num_heads, L, head_dim)
        q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)

        if attn_mask is not None and attn_mask.dtype == torch.bool:
            # Turn a boolean "masked out" mask into an additive -inf mask.
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask

        if self.logit_scale is not None:
            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn.view(N, self.num_heads, L, L) * logit_scale
            attn = attn.view(-1, L, L)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.bmm(attn, v)
        else:
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn += attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.bmm(attn, v)

        if self.head_scale is not None:
            # x is (N * num_heads, L, head_dim): expose the head axis so the
            # (num_heads, 1, 1) scale broadcasts per head. The last dim is head_dim,
            # not C; using C here breaks for any num_heads > 1. reshape (not view)
            # because the fused-attention output need not be contiguous.
            x = x.reshape(N, self.num_heads, L, self.head_dim) * self.head_scale
            x = x.reshape(-1, L, self.head_dim)

        # (N * num_heads, L, head_dim) -> (L, N, C)
        x = x.transpose(0, 1).reshape(L, N, C)

        if self.batch_first:
            x = x.transpose(0, 1)

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
| |
|
| |
|
class AttentionalPooler(nn.Module):
    """Cross-attention pooling: a fixed set of learned queries attends over input tokens.

    Args:
        d_model: width of the learned queries and of the pooled output.
        context_dim: width of the input tokens (used as keys and values).
        n_head: number of attention heads.
        n_queries: number of learned queries, i.e. the output sequence length.
        norm_layer: normalization applied to the queries and to the context.
    """

    def __init__(
            self,
            d_model: int,
            context_dim: int,
            n_head: int = 8,
            n_queries: int = 256,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        # assumes x is (batch, tokens, context_dim); the attention is batch_first
        batch_size = x.shape[0]
        context = self.ln_k(x)
        queries = self.ln_q(self.query).unsqueeze(0).expand(batch_size, -1, -1)
        pooled, _ = self.attn(queries, context, context, need_weights=False)
        return pooled
| |
|
| |
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block (attention + MLP) built on ``nn.MultiheadAttention``.

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        ls_init_value: LayerScale init value; ``None`` uses identity instead.
        act_layer: MLP activation class.
        norm_layer: normalization layer class.
        is_cross_attention: add ``ln_1_kv`` so separate keys/values can be attended to.
        batch_first: forwarded to ``nn.MultiheadAttention``.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Attend from ``q_x`` to ``k_x``/``v_x``; both default to ``q_x`` (self-attention)."""
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        # Additive (non-bool) masks are cast to the activation dtype. Boolean masks
        # must stay boolean: nn.MultiheadAttention treats True as "not allowed",
        # whereas casting to float would add +1.0 to those logits instead.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(
            q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
        )[0]

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        # Separate keys/values are only used by cross-attention blocks (which own
        # ln_1_kv); self-attention blocks ignore k_x/v_x.
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
| |
|
| |
|
class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block built on the local ``Attention`` module.

    Compared with ``ResidualAttentionBlock`` it adds:

    * ``scale_cosine_attn`` / ``scale_heads``, which are forwarded to ``Attention``;
    * ``scale_attn``, a norm applied to the attention output;
    * ``scale_fc``, a norm between the MLP activation and its output projection.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            batch_first: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            batch_first=batch_first,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        hidden = int(d_model * mlp_ratio)
        mlp_layers = [
            ("c_fc", nn.Linear(d_model, hidden)),
            ("gelu", act_layer()),
            ('ln', norm_layer(hidden) if scale_fc else nn.Identity()),
            ("c_proj", nn.Linear(hidden, d_model)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def get_reference_weight(self):
        """Weight whose dtype stands in for the block's compute dtype."""
        return self.mlp.c_fc.weight

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        attn_out = self.attn(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.ls_1(self.ln_attn(attn_out))
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.ls_2(mlp_out)
| |
|
| |
|
class CustomTransformer(nn.Module):
    """ A custom transformer that can use different block types.

    Args:
        width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer:
            forwarded to each block.
        batch_first: when False, the tensor is transposed (dims 0/1) before the
            blocks and transposed back afterwards.
        block_types: one block type name for all layers, or one per layer.
            Only ``'CustomResidualAttentionBlock'`` is supported.
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            batch_first: bool = True,
            block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        if isinstance(block_types, str):
            block_types = [block_types] * layers
        assert len(block_types) == layers

        def _create_block(bt: str):
            if bt == 'CustomResidualAttentionBlock':
                return CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                )
            # An explicit error (not `assert False`) so -O cannot turn this into a None block.
            raise ValueError(f'Unsupported block type: {bt!r}')

        self.resblocks = nn.ModuleList([
            _create_block(bt)
            for bt in block_types
        ])

    def get_cast_dtype(self) -> torch.dtype:
        weight = self.resblocks[0].get_reference_weight()
        # presumably set by int8 quantization wrappers to remember the float dtype
        if hasattr(weight, 'int8_original_dtype'):
            return weight.int8_original_dtype
        return weight.dtype

    def _run_block(self, block, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Apply one block, through activation checkpointing when enabled."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # CustomResidualAttentionBlock.forward takes (x, attn_mask), so pass exactly
            # those; extra positional Nones would raise TypeError. The function is called
            # through its module because at file top ``checkpoint`` is bound to the
            # torch.utils.checkpoint module itself, which is not callable.
            return torch.utils.checkpoint.checkpoint(block, x, attn_mask, use_reentrant=False)
        return block(x, attn_mask=attn_mask)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the blocks and also return the outputs of the selected blocks.

        Returns:
            ``(x, intermediates)``; intermediates use the same layout as the input.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = self._run_block(blk, x, attn_mask)
            if i in take_indices:
                intermediates.append(x if self.batch_first else x.transpose(0, 1))

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        if not self.batch_first:
            x = x.transpose(0, 1)

        for r in self.resblocks:
            x = self._run_block(r, x, attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)
        return x
| |
|
| |
|
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers with optional gradient checkpointing.

    When ``batch_first`` is False, the tensor is transposed (dims 0/1) before the
    blocks and transposed back afterwards, so callers keep their own layout.
    """
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            batch_first: bool = True,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        c_fc = self.resblocks[0].mlp.c_fc
        # presumably set on the module by int8 quantization wrappers to remember the float dtype
        if hasattr(c_fc, 'int8_original_dtype'):
            return c_fc.int8_original_dtype
        return c_fc.weight.dtype

    def _run_block(self, block, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Apply one block, through activation checkpointing when enabled."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Positional args map to ResidualAttentionBlock.forward(q_x, k_x, v_x, attn_mask).
            # The function is called through its module because at file top
            # ``checkpoint`` is bound to the torch.utils.checkpoint module itself,
            # which is not callable.
            return torch.utils.checkpoint.checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
        return block(x, attn_mask=attn_mask)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Run the blocks and also return the outputs of the selected blocks.

        Returns:
            ``(x, intermediates)``; intermediates use the same layout as the input.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)

        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:
            blocks = self.resblocks
        else:
            blocks = self.resblocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = self._run_block(blk, x, attn_mask)
            if i in take_indices:
                intermediates.append(x if self.batch_first else x.transpose(0, 1))

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x, intermediates

    def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.resblocks), indices)
        self.resblocks = self.resblocks[:max_index + 1]
        return take_indices

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()

        for r in self.resblocks:
            x = self._run_block(r, x, attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)
        return x
| |
|
| |
|
| | def _expand_token(token, batch_size: int): |
| | return token.view(1, 1, -1).expand(batch_size, -1, -1) |
| |
|
| |
|
class VisionTransformer(nn.Module):
    """Vision Transformer image tower.

    Pipeline: conv patch embedding -> prepend class token -> add positional
    embedding -> patch dropout -> ``ln_pre`` -> ``Transformer`` -> pooling ->
    optional projection by ``proj``.

    Args:
        image_size: input size, int or (H, W).
        patch_size: patch size, int or (h, w); also used as the conv stride.
        width: transformer width.
        layers: number of transformer blocks.
        heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: LayerScale init value (``None`` disables LayerScale).
        attentional_pool: ``False`` for plain pooling; ``True`` for one attentional
            pooler followed by ``pool_type`` pooling; ``'parallel'`` / ``'cascade'``
            for a token pooler plus a one-query contrastive pooler.
        attn_pooler_queries: number of queries of the main attentional pooler.
        attn_pooler_heads: heads of the attentional pooler(s).
        output_dim: output embedding size.
        patch_dropout: PatchDropout probability (0 disables it).
        no_ln_pre: skip the norm before the transformer.
        pos_embed_type: ``'learnable'`` or fixed ``'sin_cos_2d'`` (square grids only).
        pool_type: ``'tok'`` (class token), ``'avg'`` (mean of patch tokens) or ``'none'``.
        final_ln_after_pool: apply ``ln_post`` after pooling instead of before.
        act_layer: MLP activation class.
        norm_layer: normalization layer class.
        output_tokens: ``forward`` also returns the unpooled tokens.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: float = None,
            attentional_pool: Union[bool, str] = False,
            attn_pooler_queries: int = 256,
            attn_pooler_heads: int = 8,
            output_dim: int = 512,
            patch_dropout: float = 0.,
            no_ln_pre: bool = False,
            pos_embed_type: str = 'learnable',
            pool_type: str = 'tok',
            final_ln_after_pool: bool = False,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_tokens: bool = False,
    ):
        super().__init__()
        assert pool_type in ('tok', 'avg', 'none')
        self.output_tokens = output_tokens
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.final_ln_after_pool = final_ln_after_pool
        self.output_dim = output_dim

        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        num_positions = self.grid_size[0] * self.grid_size[1] + 1  # +1 for the class token
        if pos_embed_type == 'learnable':
            self.positional_embedding = nn.Parameter(scale * torch.randn(num_positions, width))
        elif pos_embed_type == 'sin_cos_2d':
            assert self.grid_size[0] == self.grid_size[1], \
                'currently sin cos 2d pos embedding only supports square input'
            self.positional_embedding = nn.Parameter(torch.zeros(num_positions, width), requires_grad=False)
            sincos = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
            self.positional_embedding.data.copy_(torch.from_numpy(sincos).float())
        else:
            raise ValueError(f'Unknown pos_embed_type: {pos_embed_type!r}')

        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        # Defaults so these attributes exist for every configuration.
        self.attn_pool_type = ''
        self.attn_pool_contrastive = None
        if attentional_pool:
            if isinstance(attentional_pool, str):
                if attentional_pool not in ('parallel', 'cascade'):
                    raise ValueError(f'Unknown attentional_pool: {attentional_pool!r}')
                self.attn_pool_type = attentional_pool
                self.pool_type = 'none'
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=1,
                )
            else:
                self.pool_type = pool_type
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
            pool_dim = output_dim
        else:
            self.attn_pool = None
            pool_dim = width
            self.pool_type = pool_type

        self.ln_post = norm_layer(pool_dim)
        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
        """Freeze all parameters, then re-enable the last ``unlocked_groups`` groups.

        Groups, in order: [conv1, class/positional embeddings, ln_pre], each
        transformer block except the last, [last block, ln_post], ``proj``.
        ``freeze_bn_stats`` is accepted for interface compatibility and is unused.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                # collections.abc.Sequence: ``Sequence`` is not among the file-top typing imports.
                if x is None:
                    # proj may have been removed by prune_intermediate_layers
                    return
                if isinstance(x, collections.abc.Sequence):
                    for g in x:
                        _unlock(g)
                elif isinstance(x, torch.nn.Parameter):
                    x.requires_grad = True
                else:
                    for p in x.parameters():
                        p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # Intentionally a no-op: parameters keep the scaled-randn / zeros
        # initialisation done in __init__.
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # Embedding parameters are excluded from weight decay.
        no_wd = {'positional_embedding', 'class_embedding'}
        return no_wd

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled, tokens)`` according to ``self.pool_type``."""
        if self.pool_type == 'avg':
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            pooled = tokens = x

        return pooled, tokens

    def _embeds(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify images and build the transformer input sequence."""
        x = self.conv1(x)  # (B, width, grid_h, grid_w)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # (B, width, grid_h * grid_w)
        x = x.permute(0, 2, 1)  # (B, num_patches, width)

        # Prepend the class token, then add positional embeddings.
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)

        x = self.patch_dropout(x)
        x = self.ln_pre(x)
        return x

    def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attentional pooling and/or ``ln_post`` + global pooling."""
        if self.attn_pool is not None:
            if self.attn_pool_contrastive is not None:
                # 'parallel' / 'cascade': norm first, then the two poolers.
                x = self.ln_post(x)
                tokens = self.attn_pool(x)
                if self.attn_pool_type == 'parallel':
                    pooled = self.attn_pool_contrastive(x)
                else:
                    assert self.attn_pool_type == 'cascade'
                    pooled = self.attn_pool_contrastive(tokens)
            else:
                # Single attentional pooler followed by norm and pool_type pooling.
                x = self.attn_pool(x)
                x = self.ln_post(x)
                pooled, tokens = self._global_pool(x)
        elif self.final_ln_after_pool:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
        else:
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)

        return pooled, tokens

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            output_fmt: str = 'NCHW',
            output_extra_tokens: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            intermediates_only: Only return intermediate features
            normalize_intermediates: Apply final norm layer to all intermediates
            output_fmt: Shape of intermediate feature outputs
            output_extra_tokens: Return both extra prefix class tokens
        Returns:
            Dict with ``'image_intermediates'`` (patch tokens, NCHW or NLC), optionally
            ``'image_intermediates_prefix'`` (class tokens) and, unless
            ``intermediates_only``, ``'image_features'`` (pooled, projected output).
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'

        B, _, height, width = x.shape
        x = self._embeds(x)
        x, intermediates = self.transformer.forward_intermediates(
            x,
            indices=indices,
            stop_early=stop_early,
        )

        if normalize_intermediates:
            intermediates = [self.ln_post(xi) for xi in intermediates]

        # _embeds prepends exactly one class token to every sequence.
        num_prefix_tokens = 1
        prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]
        intermediates = [y[:, num_prefix_tokens:] for y in intermediates]
        if reshape:
            # NOTE(review): assumes all patch tokens are still present (no PatchDropout
            # active in training mode); otherwise this reshape fails.
            H, W = height // self.patch_size[0], width // self.patch_size[1]
            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

        output = {'image_intermediates': intermediates}
        if output_extra_tokens:
            output['image_intermediates_prefix'] = prefix_tokens

        if intermediates_only:
            return output

        pooled, _ = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        output['image_features'] = pooled

        return output

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices = self.transformer.prune_intermediate_layers(indices)
        if prune_norm:
            self.ln_post = nn.Identity()
        if prune_head:
            self.proj = None
        return take_indices

    def forward(self, x: torch.Tensor):
        x = self._embeds(x)
        x = self.transformer(x)
        pooled, tokens = self._pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        if self.output_tokens:
            return pooled, tokens

        return pooled
| |
|
| |
|
def text_global_pool(
        x: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        pool_type: str = 'argmax',
) -> torch.Tensor:
    """Reduce a token sequence to one vector per sample.

    Args:
        x: token features; assumes ``(batch, seq, dim)`` — TODO confirm with callers.
        text: token ids, required for ``'argmax'``.
        pool_type: ``'first'`` / ``'last'`` token, ``'argmax'`` (the position of the
            largest value of ``text`` along its last dim), anything else returns ``x``.
    """
    if pool_type == 'argmax':
        assert text is not None
        rows = torch.arange(x.shape[0])
        return x[rows, text.argmax(dim=-1)]
    if pool_type == 'first':
        return x[:, 0]
    if pool_type == 'last':
        return x[:, -1]
    return x
| |
|
| |
|
| | class TextTransformer(nn.Module): |
| | output_tokens: torch.jit.Final[bool] |
| |
|
def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        output_dim: Optional[int] = 512,
        embed_cls: bool = False,
        no_causal_mask: bool = False,
        pad_id: int = 0,
        pool_type: str = 'argmax',
        proj_type: str = 'linear',
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_tokens: bool = False,
):
    """Build the text tower.

    Args:
        context_length: number of token positions (one more when ``embed_cls``).
        vocab_size: token vocabulary size.
        width, heads, layers, mlp_ratio, ls_init_value, act_layer, norm_layer:
            forwarded to ``Transformer``.
        output_dim: projection size; falsy disables the projection.
        embed_cls: learn an extra class embedding with its own position.
        no_causal_mask: disable the causal attention mask.
        pad_id: padding token id (used by ``build_cls_mask``).
        pool_type: one of ``'first'``, ``'last'``, ``'argmax'``, ``'none'``.
        proj_type: ``'none'`` disables the projection.
        proj_bias: use an ``nn.Linear`` projection (with bias) instead of a bare matrix.
        output_tokens: stored as ``self.output_tokens``.
    """
    super().__init__()
    assert pool_type in ('first', 'last', 'argmax', 'none')
    self.output_tokens = output_tokens
    self.num_pos = self.context_length = context_length
    self.vocab_size = vocab_size
    self.width = width
    self.output_dim = output_dim
    self.heads = heads
    self.pad_id = pad_id
    self.pool_type = pool_type

    self.token_embedding = nn.Embedding(vocab_size, width)
    self.cls_emb = nn.Parameter(torch.empty(width)) if embed_cls else None
    if embed_cls:
        self.num_pos += 1  # one extra position for the class embedding
    self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
    self.transformer = Transformer(
        width=width,
        layers=layers,
        heads=heads,
        mlp_ratio=mlp_ratio,
        ls_init_value=ls_init_value,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
    self.ln_final = norm_layer(width)

    if no_causal_mask:
        self.attn_mask = None
    else:
        # Non-persistent buffer: it is not saved in the state_dict.
        self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)

    if proj_type == 'none' or not output_dim:
        self.text_projection = None
    elif proj_bias:
        self.text_projection = nn.Linear(width, output_dim)
    else:
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    self.init_parameters()
| |
|
def init_parameters(self):
    """(Re)initialise embeddings, transformer weights and the text projection.

    Embeddings use fixed stds (0.02 / 0.01). Transformer and projection weights
    use stds derived from the transformer width and depth.
    """
    nn.init.normal_(self.token_embedding.weight, std=0.02)
    nn.init.normal_(self.positional_embedding, std=0.01)
    if self.cls_emb is not None:
        nn.init.normal_(self.cls_emb, std=0.01)

    width = self.transformer.width
    depth = self.transformer.layers
    attn_std = width ** -0.5
    proj_std = attn_std * ((2 * depth) ** -0.5)
    fc_std = (2 * width) ** -0.5
    for blk in self.transformer.resblocks:
        nn.init.normal_(blk.attn.in_proj_weight, std=attn_std)
        nn.init.normal_(blk.attn.out_proj.weight, std=proj_std)
        nn.init.normal_(blk.mlp.c_fc.weight, std=fc_std)
        nn.init.normal_(blk.mlp.c_proj.weight, std=proj_std)

    projection = self.text_projection
    if projection is None:
        return
    if isinstance(projection, nn.Linear):
        nn.init.normal_(projection.weight, std=attn_std)
        if projection.bias is not None:
            nn.init.zeros_(projection.bias)
    else:
        nn.init.normal_(projection, std=attn_std)
| |
|
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Toggle gradient checkpointing by setting ``grad_checkpointing`` on the inner transformer."""
    inner = self.transformer
    inner.grad_checkpointing = enable
| |
|
@torch.jit.ignore
def no_weight_decay(self):
    """Return the set of parameter names to exclude from weight decay.

    Always contains ``'positional_embedding'``; ``'cls_emb'`` is added when a CLS
    embedding exists.
    """
    extra = {'cls_emb'} if self.cls_emb is not None else set()
    return {'positional_embedding'} | extra
| |
|
def build_causal_mask(self):
    """Build an additive causal mask of shape ``(num_pos, num_pos)``.

    Entries strictly above the diagonal are ``-inf`` (future positions are blocked);
    the diagonal and everything below it are 0.
    """
    size = self.num_pos
    return torch.full((size, size), float("-inf")).triu(1)
| |
|
def build_cls_mask(self, text, cast_dtype: torch.dtype):
    """Build the per-sample additive mask used when a CLS embedding is appended.

    ``text`` is indexed as ``(batch, seq)`` here (``unsqueeze(1)`` then ``shape[2]``).
    The keep-flags ``text != pad_id`` form one row. ``seq`` all-True rows are padded on top,
    and one True column is padded on the left, giving a ``(batch, seq + 1, seq + 1)`` boolean map.
    False entries become ``-inf`` and True entries become 0, in ``cast_dtype``. The result is
    repeated ``heads`` times along dim 0, interleaved per sample.

    NOTE(review): the left column pad shifts the padding flags one column to the right
    relative to token positions (column k reflects text[k-1]). Confirm this offset is intended.
    """
    keep = (text != self.pad_id).unsqueeze(1)
    seq = keep.shape[2]
    keep = F.pad(keep, (1, 0, seq, 0), value=True)
    additive = torch.zeros(keep.shape, dtype=cast_dtype, device=keep.device)
    additive.masked_fill_(keep.logical_not(), float("-inf"))
    return additive.repeat_interleave(self.heads, dim=0)
| |
|
def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Embed token ids and build the matching attention mask.

    Args:
        text: Token ids; ``text.shape[1]`` is taken as the sequence length.
    Returns:
        ``(x, attn_mask)``:
          * ``x``: token embeddings (with the CLS embedding appended as the last position when
            ``cls_emb`` is set) plus positional embeddings, cast to the transformer's cast dtype.
          * ``attn_mask``: ``None`` when the model has no causal mask. Otherwise the causal mask
            cropped to the actual sequence length. With a CLS embedding, it is combined with the
            per-sample mask from ``build_cls_mask``, giving a batched mask.
    """
    cast_dtype = self.transformer.get_cast_dtype()
    seq_len = text.shape[1]
    x = self.token_embedding(text).to(cast_dtype)
    attn_mask = self.attn_mask
    if self.cls_emb is not None:
        seq_len += 1
        x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
        if attn_mask is not None:
            # The CLS padding mask is only used when it can be merged into the causal mask.
            cls_mask = self.build_cls_mask(text, cast_dtype)
            attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
    elif attn_mask is not None:
        # Crop to the real length, just like the positional-embedding slice below. Otherwise
        # inputs shorter than the context length get a mask of the wrong size.
        attn_mask = attn_mask[:seq_len, :seq_len]
    x = x + self.positional_embedding[:seq_len].to(cast_dtype)
    return x, attn_mask
| |
|
def forward_intermediates(
        self,
        text: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        output_fmt: str = 'NLC',
        output_extra_tokens: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Args:
        text: Input text ids
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        stop_early: Stop iterating over blocks when last desired intermediate hit
        normalize_intermediates: Apply norm layer to all intermediates
        intermediates_only: Only return intermediate features
        output_fmt: Shape of intermediate feature outputs. Only 'NLC' is supported, and it is
            the default. (A default of 'NCHW' would always fail the check below.)
        output_extra_tokens: Return both prefix and intermediate tokens
    Returns:
        Dict with:
          'text_intermediates': per-block outputs (CLS position removed when ``cls_emb`` is set),
          'text_intermediates_suffix': CLS-position outputs (only when ``cls_emb`` is set and
              ``output_extra_tokens`` is True),
          'text_features': pooled features, projected by ``text_projection`` if present
              (omitted when ``intermediates_only``).
    """
    assert output_fmt in ('NLC',), 'Output format must be NLC.'

    x, attn_mask = self._embeds(text)
    x, intermediates = self.transformer.forward_intermediates(
        x,
        attn_mask=attn_mask,
        indices=indices,
        stop_early=stop_early,
    )

    if normalize_intermediates:
        intermediates = [self.ln_final(xi) for xi in intermediates]

    output = {}

    if self.cls_emb is not None:
        # The CLS embedding occupies the last position (appended in _embeds).
        seq_intermediates = [xi[:, :-1] for xi in intermediates]
        if output_extra_tokens:
            cls_intermediates = [xi[:, -1:] for xi in intermediates]
            output['text_intermediates_suffix'] = cls_intermediates
        intermediates = seq_intermediates
    output['text_intermediates'] = intermediates

    if intermediates_only:
        return output

    if self.cls_emb is not None:
        # Pool the CLS position first, then normalize just that vector.
        pooled = text_global_pool(x, pool_type='last')
        pooled = self.ln_final(pooled)
    else:
        x = self.ln_final(x)
        pooled = text_global_pool(x, text, pool_type=self.pool_type)

    if self.text_projection is not None:
        if isinstance(self.text_projection, nn.Linear):
            pooled = self.text_projection(pooled)
        else:
            pooled = pooled @ self.text_projection

    output['text_features'] = pooled

    return output
| |
|
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """ Prune layers not required for specified intermediates.

    Delegates block pruning to ``self.transformer.prune_intermediate_layers(indices)``.
    Optionally replaces ``ln_final`` with ``nn.Identity`` (``prune_norm``) and drops
    ``text_projection`` (``prune_head``, the default).

    Returns:
        Whatever the transformer's pruning call returns (the indices kept).
    """
    kept = self.transformer.prune_intermediate_layers(indices)
    if prune_norm:
        self.ln_final = nn.Identity()
    if prune_head:
        self.text_projection = None
    return kept
| |
|
def forward(self, text):
    """Encode token ids into pooled (and optionally projected) text features.

    Without a CLS embedding, the transformer output is normalized by ``ln_final`` and then
    pooled with ``pool_type``. With a CLS embedding, the last (CLS) position is pooled and only
    that vector is normalized. The returned tokens then exclude the CLS position and are left
    un-normalized.

    Returns:
        ``pooled``, or ``(pooled, tokens)`` when ``output_tokens`` is set.
    """
    embedded, attn_mask = self._embeds(text)
    hidden = self.transformer(embedded, attn_mask=attn_mask)

    if self.cls_emb is None:
        hidden = self.ln_final(hidden)
        pooled = text_global_pool(hidden, text, pool_type=self.pool_type)
        tokens = hidden
    else:
        pooled = self.ln_final(text_global_pool(hidden, pool_type='last'))
        tokens = hidden[:, :-1]

    projection = self.text_projection
    if projection is not None:
        pooled = projection(pooled) if isinstance(projection, nn.Linear) else pooled @ projection

    return (pooled, tokens) if self.output_tokens else pooled
| |
|
| |
|
class MultimodalTransformer(Transformer):
    """Text decoder that alternates causal self-attention with cross-attention to image embeddings.

    For each of ``layers`` levels, the inherited block ``self.resblocks[i]`` runs over the text
    with a causal mask. Then ``self.cross_attn[i]`` (a ``ResidualAttentionBlock`` built with
    ``is_cross_attention=True``) attends from the text to the image embeddings. The result is
    normalized by ``ln_final`` and multiplied by ``text_projection``.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
    ):
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])

        # Causal mask over text positions. Non-persistent, so it is not saved in the state dict.
        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        # text_projection is allocated with torch.empty (uninitialized memory), so initialize
        # it and the block weights explicitly, as TextTransformer does in its constructor.
        self.init_parameters()

    def init_parameters(self):
        """Normal-initialize self- and cross-attention block weights and the output projection.

        Uses the same std scheme as ``TextTransformer.init_parameters``. This module *is* the
        transformer, so width, depth and blocks are read from ``self``, not ``self.transformer``.
        """
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in (*self.resblocks, *self.cross_attn):
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_attention_mask(self):
        """Additive causal mask of shape ``(context_length, context_length)``.

        Entries above the diagonal are ``-inf``; the diagonal and below are 0.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward_intermediates(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
    ):
        """Not supported for the cross-attention decoder."""
        raise NotImplementedError("Not currently implemented for MultimodalTransformer w/ xattn")

    def forward(self, image_embs, text_embs):
        """Decode text embeddings conditioned on image embeddings.

        Args:
            image_embs: Image token embeddings, batch-first. They are permuted internally when the
                transformer was built with ``batch_first=False``.
            text_embs: Text token embeddings, batch-first. ``text_embs.shape[1]`` crops the causal
                mask, which is built for ``context_length`` positions.
        Returns:
            ``ln_final`` output for every text position, multiplied by ``text_projection``,
            in batch-first layout.
        """
        seq_len = text_embs.shape[1]
        if not self.batch_first:
            image_embs = image_embs.permute(1, 0, 2)
            text_embs = text_embs.permute(1, 0, 2)
        attn_mask = self.attn_mask[:seq_len, :seq_len]

        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Fully qualified: the file header binds the name `checkpoint` to the
                # torch.utils.checkpoint *module*, which is not callable.
                text_embs = torch.utils.checkpoint.checkpoint(
                    resblock, text_embs, None, None, attn_mask, use_reentrant=False)
                text_embs = torch.utils.checkpoint.checkpoint(
                    cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)
            else:
                text_embs = resblock(text_embs, attn_mask=attn_mask)
                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

        if not self.batch_first:
            text_embs = text_embs.permute(1, 0, 2)

        out = self.ln_final(text_embs)
        if self.text_projection is not None:
            out = out @ self.text_projection

        return out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing in ``forward``."""
        self.grad_checkpointing = enable
| |
|
| |
|
| |
|
@dataclass
class CLIPVisionCfg:
    """Vision-tower configuration consumed by ``_build_vision_tower``.

    The architecture is chosen as follows:
      * non-empty ``timm_model_name`` -> TimmModel (uses the ``timm_*`` fields)
      * ``layers`` is a tuple/list     -> ModifiedResNet
      * otherwise                      -> VisionTransformer
    """
    # Core architecture. For a ViT, heads = width // head_width.
    # For a ModifiedResNet, heads = width * 32 // head_width.
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    # VisionTransformer options. patch_dropout is also passed to TimmModel as patch_drop
    # when it is > 0.
    ls_init_value: Optional[float] = None
    patch_dropout: float = 0.0
    attentional_pool: bool = False
    attn_pooler_queries: int = 256
    attn_pooler_heads: int = 8
    no_ln_pre: bool = False
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # partial() kwargs for the activation layer
    norm_kwargs: Optional[dict] = None  # partial() kwargs for the norm layer

    # timm backbone options (read only when timm_model_name is set).
    timm_model_name: Optional[str] = None
    timm_model_pretrained: bool = False
    timm_pool: str = 'avg'
    timm_proj: str = 'linear'
    timm_proj_bias: bool = False
    timm_drop: float = 0.0
    timm_drop_path: Optional[float] = None
| |
|
| |
|
@dataclass
class CLIPTextCfg:
    """Text-tower configuration consumed by ``_build_text_tower``.

    A non-empty ``hf_model_name`` builds an HFTextEncoder from the ``hf_*`` fields (plus
    ``output_tokens``). Otherwise a TextTransformer is built from the remaining fields.
    """
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None
    embed_cls: bool = False  # append a learnable CLS embedding as the last position
    pad_id: int = 0  # token id treated as padding when building the CLS attention mask
    no_causal_mask: bool = False
    # NOTE(review): not forwarded to TextTransformer by _build_text_tower.
    final_ln_after_pool: bool = False
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'
    output_tokens: bool = False
    # Both default to None, so they are Optional (matching CLIPVisionCfg).
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # HuggingFace text encoder options (read only when hf_model_name is set).
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'
| |
|
| |
|
def get_cast_dtype(precision: str):
    """Return the dtype that model computations should be cast to for ``precision``.

    ``'bf16'`` -> ``torch.bfloat16``, ``'fp16'`` -> ``torch.float16``; any other value
    (including the ``'pure_*'`` variants) -> ``None``.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
| |
|
| |
|
def get_input_dtype(precision: str):
    """Return the dtype that model inputs should be converted to for ``precision``.

    Both the mixed (``'bf16'`` / ``'fp16'``) and pure (``'pure_bf16'`` / ``'pure_fp16'``)
    variants map to the corresponding half-precision dtype; anything else gives ``None``.
    """
    table = (
        (('bf16', 'pure_bf16'), torch.bfloat16),
        (('fp16', 'pure_fp16'), torch.float16),
    )
    for names, dtype in table:
        if precision in names:
            return dtype
    return None
| |
|
| |
|
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the image encoder described by ``vision_cfg``.

    Selection order:
      1. ``timm_model_name`` set      -> TimmModel configured from the ``timm_*`` fields.
      2. ``layers`` is a tuple/list   -> ModifiedResNet with ``width * 32 // head_width`` heads.
      3. otherwise                    -> VisionTransformer with ``width // head_width`` heads.

    Args:
        embed_dim: Output embedding size.
        vision_cfg: A CLIPVisionCfg, or a dict of its fields.
        quick_gelu: Use QuickGELU instead of nn.GELU (affects only the VisionTransformer path).
        cast_dtype: If fp16/bf16, the VisionTransformer uses LayerNormFp32 instead of LayerNorm.
    Returns:
        The constructed vision module.
    """
    cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if cfg.timm_model_name:
        return TimmModel(
            cfg.timm_model_name,
            pretrained=cfg.timm_model_pretrained,
            pool=cfg.timm_pool,
            proj=cfg.timm_proj,
            proj_bias=cfg.timm_proj_bias,
            drop=cfg.timm_drop,
            drop_path=cfg.timm_drop_path,
            patch_drop=cfg.patch_dropout if cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=cfg.image_size,
        )

    if isinstance(cfg.layers, (tuple, list)):
        return ModifiedResNet(
            layers=cfg.layers,
            output_dim=embed_dim,
            heads=cfg.width * 32 // cfg.head_width,
            image_size=cfg.image_size,
            width=cfg.width,
        )

    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **cfg.norm_kwargs)
    if cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **cfg.act_kwargs)

    return VisionTransformer(
        image_size=cfg.image_size,
        patch_size=cfg.patch_size,
        width=cfg.width,
        layers=cfg.layers,
        heads=cfg.width // cfg.head_width,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        patch_dropout=cfg.patch_dropout,
        attentional_pool=cfg.attentional_pool,
        attn_pooler_queries=cfg.attn_pooler_queries,
        attn_pooler_heads=cfg.attn_pooler_heads,
        pos_embed_type=cfg.pos_embed_type,
        no_ln_pre=cfg.no_ln_pre,
        final_ln_after_pool=cfg.final_ln_after_pool,
        pool_type=cfg.pool_type,
        output_tokens=cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
| |
|
| |
|
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    A non-empty ``hf_model_name`` yields an HFTextEncoder. Otherwise a TextTransformer is built.
    That path uses QuickGELU when ``quick_gelu`` is set, and LayerNormFp32 when ``cast_dtype``
    is fp16/bf16. Both layers are optionally wrapped with ``partial`` using
    ``act_kwargs`` / ``norm_kwargs``.

    Args:
        embed_dim: Output embedding size.
        text_cfg: A CLIPTextCfg, or a dict of its fields.
        quick_gelu: Activation selection for the TextTransformer path.
        cast_dtype: Norm-layer selection for the TextTransformer path.
    Returns:
        The constructed text module.
    """
    cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

    if cfg.hf_model_name:
        return HFTextEncoder(
            cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=cfg.hf_proj_type,
            pooler_type=cfg.hf_pooler_type,
            pretrained=cfg.hf_model_pretrained,
            output_tokens=cfg.output_tokens,
        )

    act_layer = QuickGELU if quick_gelu else nn.GELU
    if cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **cfg.act_kwargs)

    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **cfg.norm_kwargs)

    return TextTransformer(
        context_length=cfg.context_length,
        vocab_size=cfg.vocab_size,
        width=cfg.width,
        heads=cfg.heads,
        layers=cfg.layers,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=cfg.embed_cls,
        no_causal_mask=cfg.no_causal_mask,
        pad_id=cfg.pad_id,
        pool_type=cfg.pool_type,
        proj_type=cfg.proj_type,
        proj_bias=cfg.proj_bias,
        output_tokens=cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
| |
|
| |
|
class CLIP(nn.Module):
    """Contrastive image/text model with the text tower flattened onto this module.

    The image encoder is kept as ``self.visual``. The text encoder returned by
    ``_build_text_tower`` is not stored whole. Its transformer, token/positional embeddings,
    final norm, projection, pooling type and causal mask are re-attached directly here, so text
    weights appear under top-level state-dict keys such as ``token_embedding.weight``.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Borrow the text-tower pieces. Registration order is kept as-is because it fixes
        # parameter and state-dict ordering.
        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text_tower.transformer
        self.context_length = text_tower.context_length
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        self.text_pool_type = text_tower.pool_type
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        # Temperature is stored in log space (callers use logit_scale.exp()); the bias is optional.
        logit_shape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(logit_shape) * init_logit_scale)
        self.logit_bias = (
            nn.Parameter(torch.ones(logit_shape) * init_logit_bias)
            if init_logit_bias is not None
            else None
        )

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate image-tower freezing to ``self.visual.lock``."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing in both the image tower and the text transformer."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay.

        Includes ``positional_embedding`` plus the image tower's own list, if it defines one,
        prefixed with ``visual.``.
        """
        names = {'positional_embedding'}
        if hasattr(self.visual, 'no_weight_decay'):
            names.update('visual.' + n for n in self.visual.no_weight_decay())
        return names

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        feats = self.visual(image)
        if normalize:
            return F.normalize(feats, dim=-1)
        return feats

    def _project_text(self, pooled):
        """Apply ``text_projection`` (an nn.Linear or a raw matrix) if one is configured."""
        projection = self.text_projection
        if projection is None:
            return pooled
        if isinstance(projection, nn.Linear):
            return projection(pooled)
        return pooled @ projection

    def encode_text(self, text, normalize: bool = False):
        """Embed, transform, normalize, pool and project token ids.

        Uses the full causal mask and positional embedding, so ``text`` is expected to span
        the whole context length.
        """
        cast_dtype = self.transformer.get_cast_dtype()
        hidden = self.token_embedding(text).to(cast_dtype) + self.positional_embedding.to(cast_dtype)
        hidden = self.ln_final(self.transformer(hidden, attn_mask=self.attn_mask))
        feats = self._project_text(text_global_pool(hidden, text, self.text_pool_type))
        return F.normalize(feats, dim=-1) if normalize else feats

    def _pair_logits(self, image_features, text_features, scale):
        """Return ``(image_logits, text_logits)`` for already-normalized features."""
        image_logits = scale * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        return image_logits, image_logits.T

    def get_logits(self, image, text):
        """Scaled (and optionally biased) image-to-text and text-to-image similarity logits."""
        img = self.encode_image(image, normalize=True)
        txt = self.encode_text(text, normalize=True)
        return self._pair_logits(img, txt, self.logit_scale.exp())

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit (image tower only here)
            normalize: L2 Normalize final features
            normalize_intermediates: Apply final norm layer to all intermediates
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging the image tower's outputs with 'text_intermediates' /
            'text_features', and optionally 'image_logits', 'text_logits', 'logit_scale',
            'logit_bias'.
        """
        if intermediates_only:
            # No final features are produced, so there is nothing to normalize or compare.
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        output = {}

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            cast_dtype = self.transformer.get_cast_dtype()
            hidden = self.token_embedding(text).to(cast_dtype) + self.positional_embedding.to(cast_dtype)
            hidden, text_intermediates = self.transformer.forward_intermediates(
                hidden,
                attn_mask=self.attn_mask,
                indices=text_indices
            )
            if normalize_intermediates:
                text_intermediates = [self.ln_final(t) for t in text_intermediates]
            output["text_intermediates"] = text_intermediates

            if not intermediates_only:
                pooled = text_global_pool(self.ln_final(hidden), text, self.text_pool_type)
                text_features = self._project_text(pooled)
                if normalize:
                    text_features = F.normalize(text_features, dim=-1)
                output["text_features"] = text_features

        needs_scale = output_logits or output_logit_scale_bias
        logit_scale_exp = self.logit_scale.exp() if needs_scale else None

        if output_logits:
            image_logits, text_logits = self._pair_logits(
                output["image_features"], output["text_features"], logit_scale_exp)
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given (L2-normalized) and return them with the logit scale.

        Returns a dict when ``output_dict`` is set. Otherwise returns the tuple
        ``(image_features, text_features, logit_scale.exp()[, logit_bias])``. Missing inputs
        yield ``None`` features.
        """
        image_features = None if image is None else self.encode_image(image, normalize=True)
        text_features = None if text is None else self.encode_text(text, normalize=True)
        scale = self.logit_scale.exp()

        if self.output_dict:
            result = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": scale
            }
            if self.logit_bias is not None:
                result['logit_bias'] = self.logit_bias
            return result

        if self.logit_bias is None:
            return image_features, text_features, scale
        return image_features, text_features, scale, self.logit_bias
| |
|
| |
|
class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps the whole text encoder as the ``self.text`` submodule.

    Text-side work (encoding, locking, checkpointing, intermediates) is delegated to the
    module returned by ``_build_text_tower``, so its weights live under ``text.*`` keys.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        # Temperature is stored in log space (callers use logit_scale.exp()); the bias is optional.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate image-tower freezing to ``self.visual.lock``."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate text-tower freezing to ``self.text.lock`` (the text tower must implement it)."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing in both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay, prefixed by tower.

        Each tower contributes names only if it defines ``no_weight_decay``.
        """
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        if hasattr(self.text, 'no_weight_decay'):
            # Ask the text tower itself. Reusing the image tower's list here would produce
            # 'text.*' names that need not exist in the text tower.
            for n in self.text.no_weight_decay():
                no_wd.add('text.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalize along the last dim."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Scaled (and optionally biased) image-to-text and text-to-image similarity logits."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging both towers' ``forward_intermediates`` outputs, plus optionally
            'image_logits', 'text_logits', 'logit_scale', 'logit_bias'.
        """
        output = {}
        if intermediates_only:
            # No final features are produced, so there is nothing to normalize or compare.
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given (L2-normalized) and return them with the logit scale.

        Returns a dict when ``output_dict`` is set. Otherwise returns the tuple
        ``(image_features, text_features, logit_scale.exp()[, logit_bias])``.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
| |
|
| |
|
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16) in place.

    Visited recursively through ``model.apply``:
      * Conv1d / Conv2d / Linear: weight and (if present) bias.
      * nn.MultiheadAttention / Attention: in/q/k/v projection weights, ``in_proj_bias``,
        ``bias_k`` and ``bias_v``, for whichever of them exist and are not None.
      * CLIP / TextTransformer: ``text_projection`` when it is a raw tensor parameter.
      * VisionTransformer: ``proj`` when it is a raw tensor parameter.

    Args:
        model: Module to convert; modified in place.
        dtype: Target dtype, e.g. ``torch.float16`` or ``torch.bfloat16``.
    """

    def _convert_weights(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                # Attention need not define every nn.MultiheadAttention attribute; skip missing ones.
                tensor = getattr(module, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(module, (CLIP, TextTransformer)):
            # text_projection may be an nn.Linear (proj_bias=True). That submodule is converted
            # by the Linear branch when apply() visits it, and it has no `.data` of its own.
            attr = getattr(module, "text_projection", None)
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.to(dtype)

        if isinstance(module, VisionTransformer):
            attr = getattr(module, "proj", None)
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


# Backwards-compatible alias from when only fp16 was supported.
convert_weights_to_fp16 = convert_weights_to_lp
| |
|
| |
|
| | |
def convert_to_custom_text_state_dict(state_dict: dict):
    """Prefix text-tower keys with ``text.`` for checkpoints in the flat (OpenAI-style) layout.

    If ``text_projection`` is a top-level key, returns a new dict in which every key that
    starts with ``text_projection``, ``positional_embedding``, ``token_embedding``,
    ``transformer`` or ``ln_final`` gets a ``text.`` prefix. All other keys are copied
    unchanged and insertion order is kept. Otherwise ``state_dict`` itself is returned.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + name if name.startswith(text_prefixes) else name): tensor
        for name, tensor in state_dict.items()
    }
| |
|
| |
|
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Instantiate a ``CLIP`` model whose hyper-parameters are inferred from an OpenAI state dict.

    The vision tower type is chosen by the presence of ``visual.proj``:
      * present  -> ViT: width and patch size come from ``visual.conv1.weight``, depth from the
        number of ``visual.*.attn.in_proj_weight`` keys, image size from the positional table;
      * absent   -> ResNet-style: per-stage block counts from ``visual.layer{1..4}.<i>.*`` keys,
        image size = 32 * side of the (square) ``visual.attnpool.positional_embedding`` grid.
    Text-tower sizes come from ``text_projection``, ``positional_embedding``,
    ``token_embedding.weight``, ``ln_final.weight`` and ``transformer.resblocks.<i>.*`` keys.

    Side effect: ``input_resolution``, ``context_length`` and ``vocab_size`` are popped from
    the caller's ``state_dict`` before loading.

    Args:
        state_dict: OpenAI CLIP weights keyed by parameter name.
        quick_gelu: forwarded to ``CLIP``.
        cast_dtype: forwarded to ``CLIP``.

    Returns:
        The model in eval mode with weights loaded. Before loading, applicable weights are
        converted with ``convert_weights_to_fp16`` (fp16 regardless of ``cast_dtype``).
    """
    def _count_distinct_blocks(prefix):
        # Keys look like "<tower>.<stage>.<block_idx>...."; count distinct block indices.
        return len({name.split(".")[2] for name in state_dict if name.startswith(prefix)})

    if "visual.proj" in state_dict:
        conv1_shape = state_dict["visual.conv1.weight"].shape
        vision_width = conv1_shape[0]
        vision_layers = sum(
            1 for name in state_dict.keys()
            if name.startswith("visual.") and name.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv1_shape[-1]
        num_vis_pos = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((num_vis_pos - 1) ** 0.5)  # one extra (class) position is excluded
        image_size = vision_patch_size * grid_size
    else:
        vision_layers = tuple(
            _count_distinct_blocks(f"visual.layer{stage}") for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        attnpool_len = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((attnpool_len - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == attnpool_len
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # head count derived assuming 64 channels per head
    transformer_layers = _count_distinct_blocks("transformer.resblocks")

    model = CLIP(
        embed_dim,
        vision_cfg=CLIPVisionCfg(
            layers=vision_layers,
            width=vision_width,
            patch_size=vision_patch_size,
            image_size=image_size,
        ),
        text_cfg=CLIPTextCfg(
            context_length=context_length,
            vocab_size=vocab_size,
            width=transformer_width,
            heads=transformer_heads,
            layers=transformer_layers,
        ),
        quick_gelu=quick_gelu,
        cast_dtype=cast_dtype,
    )

    # Drop metadata entries that are not model parameters before loading.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)
    convert_weights_to_fp16(model)
    model.load_state_dict(state_dict)
    return model.eval()
| |
|
| |
|
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """TorchScript-trace ``forward``, ``encode_text`` and ``encode_image`` of ``model``.

    The model is switched to eval mode. Dummy inputs are all-ones images of shape
    ``(batch_size, 3, S, S)``, with ``S = model.visual.image_size``, and all-zero int
    token ids of shape ``(batch_size, model.context_length)``, both created on ``device``.

    Returns:
        The traced module, with ``visual.image_size`` set again after tracing.
    """
    model.eval()
    img_res = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, img_res, img_res), device=device)
    dummy_tokens = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)

    methods_to_trace = {
        'forward': (dummy_images, dummy_tokens),
        'encode_text': (dummy_tokens,),
        'encode_image': (dummy_images,),
    }
    traced = torch.jit.trace_module(model, inputs=methods_to_trace)

    # Re-attach the plain attribute so callers can still read it from the traced module
    # (presumably tracing does not carry it over).
    traced.visual.image_size = img_res
    return traced
| |
|
| |
|
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resize ``state_dict['visual.positional_embedding']`` in place to match ``model.visual.grid_size``.

    The first position (one extra token) is kept as-is. The remaining positions are treated
    as a square grid, interpolated to the model's grid size with ``F.interpolate``, and
    concatenated back after the extra token.

    This is a no-op when the state dict has no visual positional embedding, when
    ``model.visual`` has no ``grid_size``, or when the sequence length already matches.

    Args:
        state_dict: checkpoint dict; modified in place.
        model: target model exposing ``model.visual.grid_size``.
        interpolation: ``F.interpolate`` mode.
        antialias: forwarded to ``F.interpolate``.
    """
    # Import the stdlib module locally: at file scope `logging` is bound to `transformers.logging`
    # (see the file's imports), whose API is not the stdlib one.
    import logging

    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # leading non-grid positions kept unchanged
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    # Exact integer square root; the grid part is assumed to be square.
    old_grid_size = to_2tuple(math.isqrt(len(pos_emb_img)))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    # (N, C) -> (1, C, H, W) so spatial interpolation applies over the grid.
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # (1, C, H', W') -> (H'*W', C)
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
| |
|
| |
|
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Resize ``state_dict['positional_embedding']`` in place to the model's text context length.

    The target table is ``model.positional_embedding``, or ``model.text.positional_embedding``
    when the former is absent. Rows are interpolated along the position axis with
    ``F.interpolate``, and the embedding width must be unchanged.

    This is a no-op when the state dict has no ``positional_embedding``, when the model
    exposes no text positional table, or when the lengths already match.

    Args:
        state_dict: checkpoint dict; modified in place.
        model: target model.
        interpolation: ``F.interpolate`` mode (1-D, so e.g. ``'linear'``).
        antialias: forwarded to ``F.interpolate``.

    Raises:
        AssertionError: if the embedding width differs between checkpoint and model.
    """
    # Import the stdlib module locally: at file scope `logging` is bound to `transformers.logging`
    # (see the file's imports), whose API is not the stdlib one.
    import logging

    old_pos_embed = state_dict.get('positional_embedding', None)
    if old_pos_embed is None:
        return

    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(getattr(model, 'text', None), 'positional_embedding', None)
    if model_pos_embed is None:
        # Nothing to match against; leave the checkpoint entry untouched.
        return

    old_num_pos, old_width = old_pos_embed.shape[0], old_pos_embed.shape[1]
    num_pos, width = model_pos_embed.shape[0], model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    # (P, C) -> (1, C, P) so interpolation runs along the position axis.
    resized = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    resized = F.interpolate(
        resized,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # (1, C, P') -> (P', C)
    state_dict['positional_embedding'] = resized.permute(0, 2, 1)[0]
| |
|
| |
|
def get_model_preprocess_cfg(model):
    """Return the image-preprocessing config of ``model.visual`` (or ``model`` if it has no ``visual``).

    Uses the module's ``preprocess_cfg`` dict when it is non-empty. Otherwise builds one from
    the legacy ``image_size`` / ``image_mean`` / ``image_std`` attributes, mapped to the keys
    ``size`` / ``mean`` / ``std``. Attributes that are missing or None are omitted.

    The result is always a fresh (shallow) dict, so callers can modify it without touching
    the module's own ``preprocess_cfg``.
    """
    module = getattr(model, 'visual', model)
    # Copy: the previous code filled in and returned the module's own dict by reference.
    preprocess_cfg = dict(getattr(module, 'preprocess_cfg', None) or {})
    if not preprocess_cfg:
        legacy_attrs = (('size', 'image_size'), ('mean', 'image_mean'), ('std', 'image_std'))
        for key, attr in legacy_attrs:
            # Default None: a module without the attribute simply contributes nothing.
            value = getattr(module, attr, None)
            if value is not None:
                preprocess_cfg[key] = value
    return preprocess_cfg
| |
|
| |
|
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Store ``preprocess_cfg`` on ``model.visual`` (or ``model`` if it has no ``visual``).

    Sets ``image_mean`` and ``image_std`` from the ``'mean'`` and ``'std'`` entries
    (a ``KeyError`` is raised if either is missing). A deep copy of the whole dict is
    stored as ``preprocess_cfg``.
    """
    target = getattr(model, 'visual', model)
    for attr_name, cfg_key in (('image_mean', 'mean'), ('image_std', 'std')):
        setattr(target, attr_name, preprocess_cfg[cfg_key])
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
| |
|
| |
|
def get_model_tokenize_cfg(model):
    """Collect tokenizer settings from ``model.text`` (or ``model`` if it has no ``text``).

    Returns a dict with ``context_length`` and/or ``vocab_size``, in that order. An entry is
    included only when the attribute exists and is not None.
    """
    text_module = getattr(model, 'text', model)
    candidates = (
        ('context_length', getattr(text_module, 'context_length', None)),
        ('vocab_size', getattr(text_module, 'vocab_size', None)),
    )
    return {key: value for key, value in candidates if value is not None}
| |
|
| |
|
| |
|
# Optional Hugging Face Hub support. On success, `hf_hub_download` is the hub function with
# library_name/library_version pre-bound; otherwise it is None and `_has_hf_hub` is False.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False
else:
    # Neither `partial` nor `__version__` appears in this script's visible imports; a missing
    # name would raise NameError, which `except ImportError` does not catch. Bind explicitly.
    import functools
    hf_hub_download = functools.partial(
        hf_hub_download,
        library_name="open_clip",
        library_version=globals().get("__version__"),
    )
    _has_hf_hub = True
| |
|
| |
|
def _pcfg(url='', hf_hub='', **kwargs):
    """Pretrained-weight config with OpenAI CLIP normalization, bicubic, 'shortest' resize.

    Extra keyword arguments are added to the dict and override the defaults.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=OPENAI_DATASET_MEAN,
        std=OPENAI_DATASET_STD,
        interpolation='bicubic',
        resize_mode='shortest',
    )
    cfg.update(kwargs)  # overridden keys keep their original position
    return cfg
| |
|
| |
|
def _slpcfg(url='', hf_hub='', **kwargs):
    """Pretrained-weight config with INCEPTION normalization, bicubic, 'squash' resize (SigLIP entries).

    Extra keyword arguments are added to the dict and override the defaults.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=INCEPTION_MEAN,
        std=INCEPTION_STD,
        interpolation='bicubic',
        resize_mode='squash',
    )
    cfg.update(kwargs)  # overridden keys keep their original position
    return cfg
| |
|
| |
|
def _apcfg(url='', hf_hub='', **kwargs):
    """Pretrained-weight config with IMAGENET normalization, bilinear, 'squash' resize (CLIPA entries).

    Extra keyword arguments are added to the dict and override the defaults.
    """
    cfg = dict(
        url=url,
        hf_hub=hf_hub,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
        interpolation='bilinear',
        resize_mode='squash',
    )
    cfg.update(kwargs)  # overridden keys keep their original position
    return cfg
| |
|
| |
|
| | def _mccfg(url='', hf_hub='', **kwargs): |
| | |
| | return { |
| | 'url': url, |
| | 'hf_hub': hf_hub, |
| | 'mean': (0., 0., 0.), |
| | 'std': (1., 1., 1.), |
| | 'interpolation': 'bilinear', |
| | 'resize_mode': 'shortest', |
| | **kwargs, |
| | } |
| |
|
| |
|
| |
|
# Per-architecture pretrained tables: tag -> config dict built with `_pcfg`.
# `url` (when set) is a direct checkpoint URL; `hf_hub` appears to be a Hugging Face Hub
# location (repo id with trailing '/'). Extra kwargs such as quick_gelu=True or mean/std
# overrides are merged into each config. Tags are lower-case with underscores, matching the
# normalization applied by `_clean_tag` at lookup time.
_RN50 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
        hf_hub="timm/resnet50_clip.openai/",
        quick_gelu=True,
    ),
    yfcc15m=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
        hf_hub="timm/resnet50_clip.yfcc15m/",
        quick_gelu=True,
    ),
    cc12m=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
        hf_hub="timm/resnet50_clip.cc12m/",
        quick_gelu=True,
    ),
)

_RN101 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
        hf_hub="timm/resnet101_clip.openai/",
        quick_gelu=True,
    ),
    yfcc15m=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
        hf_hub="timm/resnet101_clip.yfcc15m/",
        quick_gelu=True,
    ),
)

_RN50x4 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
        hf_hub="timm/resnet50x4_clip.openai/",
        quick_gelu=True,
    ),
)

_RN50x16 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
        hf_hub="timm/resnet50x16_clip.openai/",
        quick_gelu=True,
    ),
)

_RN50x64 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
        hf_hub="timm/resnet50x64_clip.openai/",
        quick_gelu=True,
    ),
)

_VITB32 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        hf_hub="timm/vit_base_patch32_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints
    laion400m_e31=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/",
        quick_gelu=True,
    ),
    laion400m_e32=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/",
        quick_gelu=True,
    ),
    # LAION-2B checkpoints
    laion2b_e16=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
        hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/",
    ),
    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
    # DataComp XL
    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
    # DataComp / CommonPool medium-scale variants
    datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
    commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
    commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
    commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
    commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
    commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
    commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
    # DataComp / CommonPool small-scale variants
    datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
    commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
    commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
    commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
    commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
    commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
    commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
    # MetaCLIP
    metaclip_400m=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt",
        hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    metaclip_fullcc=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt",
        hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
)

_VITB32_256 = dict(
    datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
)

_VITB16 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
        hf_hub="timm/vit_base_patch16_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints
    laion400m_e31=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/",
    ),
    laion400m_e32=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/",
    ),
    # LAION-2B
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
    # DataComp XL
    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
    # DataComp / CommonPool large-scale variants
    datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
    commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
    commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
    commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
    commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
    commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
    commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
    # DFN
    dfn2b=_pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-B-16/',
        quick_gelu=True,
    ),
    # MetaCLIP
    metaclip_400m=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt",
        hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    metaclip_fullcc=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt",
        hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
)

_VITB16_PLUS_240 = dict(
    laion400m_e31=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
    ),
    # NOTE(review): `url` is the e32 checkpoint but `hf_hub` points at the ...laion400m_e31/ repo
    # (same as the entry above) - confirm whether this is intended or should be ...e32/.
    laion400m_e32=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
    ),
)

_VITL14 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
        hf_hub="timm/vit_large_patch14_clip_224.openai/",
        quick_gelu=True,
    ),
    # LAION-400M checkpoints
    laion400m_e31=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt",
        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/",
    ),
    laion400m_e32=_pcfg(
        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt",
        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/",
    ),
    # LAION-2B; overrides the default OpenAI normalization with INCEPTION mean/std
    laion2b_s32b_b82k=_pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=INCEPTION_MEAN, std=INCEPTION_STD),
    # DataComp XL / CommonPool XL
    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
    commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
    commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
    commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
    # MetaCLIP
    metaclip_400m=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt",
        hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/",
        quick_gelu=True,
    ),
    metaclip_fullcc=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt",
        hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
    # DFN
    dfn2b=_pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-L-14/',
        quick_gelu=True,
    ),
    # DFN 39B-sample variant (no quick_gelu flag, unlike dfn2b above)
    dfn2b_s39b=_pcfg(
        hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/',
    ),
)

_VITL14_336 = dict(
    openai=_pcfg(
        url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
        hf_hub="timm/vit_large_patch14_clip_336.openai/",
        quick_gelu=True,
    ),
)

_VITH14 = dict(
    # LAION-2B
    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
    # MetaCLIP
    metaclip_fullcc=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt",
        hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
    metaclip_altogether=_pcfg(
        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt",
        hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/",
        # NOTE: unlike the other MetaCLIP entries, no quick_gelu=True here.
    ),
    # DFN; overrides resize_mode to 'squash' (interpolation matches the default)
    dfn5b=_pcfg(
        hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
        quick_gelu=True,
        interpolation="bicubic",
        resize_mode="squash"
    ),
)

_VITH14_378 = dict(
    # DFN at 378px; same overrides as the 224px dfn5b entry
    dfn5b=_pcfg(
        hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
        quick_gelu=True,
        interpolation="bicubic",
        resize_mode="squash"
    ),
)

_VITg14 = dict(
    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)

_VITbigG14 = dict(
    # LAION-2B
    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
    # MetaCLIP
    metaclip_fullcc=_pcfg(
        url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',
        hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/",
        quick_gelu=True,
    ),
)

_robertaViTB32 = dict(
    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)

_xlmRobertaBaseViTB32 = dict(
    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)

_xlmRobertaLargeFrozenViTH14 = dict(
    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)

_convnext_base = dict(
    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)

_convnext_base_w = dict(
    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)

_convnext_base_w_320 = dict(
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)

_convnext_large_d = dict(
    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)

_convnext_large_d_320 = dict(
    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)

_convnext_xxlarge = dict(
    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)

_coca_VITB32 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)

_coca_VITL14 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
|
| |
|
# Registry of pretrained weights: architecture name -> {pretrained tag -> config dict}.
# Queried by list_pretrained*, is_pretrained_cfg and get_pretrained_cfg; tags are compared
# after `_clean_tag` normalization (lower-case, '-' -> '_').
_PRETRAINED = {
    # ResNet image towers (OpenAI-style)
    "RN50": _RN50,
    "RN101": _RN101,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,

    # ViT image towers
    "ViT-B-32": _VITB32,
    "ViT-B-32-256": _VITB32_256,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-H-14-378": _VITH14_378,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,

    # RoBERTa / XLM-RoBERTa text towers
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,

    # ConvNeXt image towers
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d": _convnext_large_d,
    "convnext_large_d_320": _convnext_large_d_320,
    "convnext_xxlarge": _convnext_xxlarge,

    # CoCa
    "coca_ViT-B-32": _coca_VITB32,
    "coca_ViT-L-14": _coca_VITL14,

    # EVA-CLIP (weights hosted under timm/)
    "EVA01-g-14": dict(
        # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
        laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
    ),
    "EVA01-g-14-plus": dict(
        merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
    ),
    "EVA02-B-16": dict(
        merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
    ),
    "EVA02-L-14": dict(
        merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
    ),
    "EVA02-L-14-336": dict(
        merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
    ),
    "EVA02-E-14": dict(
        laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
    ),
    "EVA02-E-14-plus": dict(
        laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
    ),

    # SigLIP (uses _slpcfg: INCEPTION mean/std, 'squash' resize)
    "ViT-B-16-SigLIP": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
    ),
    "ViT-B-16-SigLIP-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
    ),
    "ViT-B-16-SigLIP-i18n-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
    ),
    "ViT-B-16-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
    ),
    "ViT-B-16-SigLIP-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
    ),
    "ViT-L-16-SigLIP-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
    ),
    "ViT-L-16-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
    ),
    "ViT-SO400M-14-SigLIP": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
    ),
    "ViT-SO400M-16-SigLIP-i18n-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),
    ),
    # NOTE: the -378 architecture reuses the -384 hub weights; presumably the model config
    # (not visible here) sets the different image size - confirm.
    "ViT-SO400M-14-SigLIP-378": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
    ),
    "ViT-SO400M-14-SigLIP-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
    ),

    # SigLIP 2
    "ViT-B-32-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'),
    ),
    "ViT-B-16-SigLIP2": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'),
    ),
    "ViT-B-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'),
    ),
    "ViT-B-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'),
    ),
    "ViT-B-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'),
    ),
    "ViT-L-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'),
    ),
    "ViT-L-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'),
    ),
    "ViT-L-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'),
    ),
    "ViT-SO400M-14-SigLIP2": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'),
    ),
    "ViT-SO400M-14-SigLIP2-378": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'),
    ),
    "ViT-SO400M-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'),
    ),
    "ViT-SO400M-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'),
    ),
    "ViT-SO400M-16-SigLIP2-512": dict(
        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'),
    ),
    "ViT-gopt-16-SigLIP2-256": dict(
        webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'),
    ),
    "ViT-gopt-16-SigLIP2-384": dict(
        webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'),
    ),

    # CLIPA (uses _apcfg: IMAGENET mean/std, bilinear, 'squash' resize)
    "ViT-L-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
    ),
    "ViT-L-14-CLIPA-336": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
    ),
    "ViT-H-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
    ),
    "ViT-H-14-CLIPA-336": dict(
        laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
    ),
    "ViT-bigG-14-CLIPA": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
    ),
    "ViT-bigG-14-CLIPA-336": dict(
        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
    ),

    # NLLB-CLIP
    "nllb-clip-base": dict(
        v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
    ),
    "nllb-clip-large": dict(
        v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
    ),

    # NLLB-CLIP with SigLIP-style preprocessing
    "nllb-clip-base-siglip": dict(
        v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
    ),
    "nllb-clip-large-siglip": dict(
        v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
    ),

    # MobileCLIP (uses _mccfg: mean 0 / std 1, bilinear, 'shortest' resize)
    "MobileCLIP-S1": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
    "MobileCLIP-S2": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
    "MobileCLIP-B": dict(
        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
        datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
    ),

    # ViTamin. Unlike the entries above, hf_hub here ends with a file name
    # ('.../pytorch_model.bin') rather than '/'; presumably the download helper treats the last
    # path segment as the checkpoint filename - confirm.
    "ViTamin-S": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
    ),
    "ViTamin-S-LTT": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
    ),
    "ViTamin-B": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
    ),
    "ViTamin-B-LTT": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
    ),
    "ViTamin-L": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
    ),
    "ViTamin-L-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
    ),
    "ViTamin-L-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
    ),
    "ViTamin-L-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
    ),
    "ViTamin-L2": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
    ),
    "ViTamin-L2-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
    ),
    "ViTamin-L2-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
    ),
    "ViTamin-L2-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
    ),
    "ViTamin-XL-256": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
    ),
    "ViTamin-XL-336": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
    ),
    "ViTamin-XL-384": dict(
        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
    ),
}
| |
|
def _build_quickgelu_pretrained(pretrained: dict) -> dict:
    """Return ``{arch + '-quickgelu': {tag: cfg_copy}}`` for every config flagged ``quick_gelu``.

    Architectures with no quick_gelu-flagged tags are omitted. The configs are deep copies,
    so the aliases do not share dict objects with the source entries.

    Kept in a function so the loop variables do not leak into module globals. The previous
    module-level loop rebound the global ``tv``, which this file imports as the torchvision
    alias, to a config dict.
    """
    import copy  # local import: the block does not rely on a file-level `copy` import

    aliases = {}
    for arch_name, tag_cfgs in pretrained.items():
        qg_tags = {
            tag: copy.deepcopy(cfg)
            for tag, cfg in tag_cfgs.items()
            if cfg.get('quick_gelu', False)
        }
        if qg_tags:
            aliases[arch_name + '-quickgelu'] = qg_tags
    return aliases


_PRETRAINED_quickgelu = _build_quickgelu_pretrained(_PRETRAINED)
_PRETRAINED.update(_PRETRAINED_quickgelu)
| |
|
| | def _clean_tag(tag: str): |
| | |
| | return tag.lower().replace('-', '_') |
| |
|
| |
|
def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    pairs = [(name, tag) for name, tag_cfgs in _PRETRAINED.items() for tag in tag_cfgs]
    if as_str:
        return [f'{name}:{tag}' for name, tag in pairs]
    return pairs
| |
|
| |
|
def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    wanted = _clean_tag(tag)
    return [name for name, tag_cfgs in _PRETRAINED.items() if wanted in tag_cfgs]
| |
|
| |
|
def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    # Unknown architectures yield an empty list; the model name is matched exactly.
    return list(_PRETRAINED.get(model, {}))
| |
|
| |
|
def is_pretrained_cfg(model: str, tag: str):
    """True when ``model`` is registered and has a config for the normalized ``tag``."""
    return model in _PRETRAINED and _clean_tag(tag) in _PRETRAINED[model]
| |
|
| |
|
def get_pretrained_cfg(model: str, tag: str):
    """Return the pretrained cfg dict for (model, normalized tag), or ``{}`` if either is unknown."""
    if model in _PRETRAINED:
        return _PRETRAINED[model].get(_clean_tag(tag), {})
    return {}
| |
|
| |
|
def get_pretrained_url(model: str, tag: str):
    """Return the ``'url'`` entry of the pretrained cfg for (model, tag), or ``''`` when absent."""
    return get_pretrained_cfg(model, _clean_tag(tag)).get('url', '')
| |
|
| |
|
def _sha256_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at ``path``, reading it in chunks."""
    import hashlib

    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained_from_url(
        url: str,
        cache_dir: Optional[str] = None,
):
    """Download ``url`` into ``cache_dir`` and return the local file path.

    The file is stored under its URL basename in ``cache_dir`` (default
    ``~/.cache/clip``). An existing cached file is reused.

    For ``openaipublic`` URLs the expected SHA256 is the second-to-last path segment.
    For ``mlfoundations`` URLs it is the last ``-``-separated part of the file stem.
    In both cases the cached or downloaded file must have a digest starting with that
    value. A cached file that fails the check is re-downloaded.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or if the
            freshly downloaded file fails the checksum.
    """
    import urllib.request

    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            return download_target
        if _sha256_file(download_target).startswith(expected_sha256):
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    # Stream into a side file and move it into place only once complete, so an
    # interrupted download never leaves a truncated file that a later call would reuse.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            content_length = source.headers.get("Content-Length")
            # Servers may omit Content-Length; tqdm accepts total=None (unknown size).
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
        os.replace(partial_target, download_target)
    finally:
        if os.path.exists(partial_target):
            os.remove(partial_target)

    if expected_sha256 and not _sha256_file(download_target).startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
| |
|
| |
|
def has_hf_hub(necessary=False):
    """Report whether ``huggingface_hub`` is available.

    Args:
        necessary: when True and the hub package is unavailable, raise instead of returning False.

    Raises:
        RuntimeError: if ``necessary`` is True and ``_has_hf_hub`` is falsy.
    """
    if necessary and not _has_hf_hub:
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub
| |
|
| |
|
def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Yield safetensors file names to try before ``filename`` on the Hugging Face Hub.

    ``HF_WEIGHTS_NAME`` maps to ``HF_SAFE_WEIGHTS_NAME``. Any other ``*.bin`` / ``*.pth``
    name maps to the same stem with a ``.safetensors`` extension. Other names yield nothing.
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    elif filename.endswith((".bin", ".pth")):
        yield filename[:-4] + ".safetensors"
| |
|
| |
|
def download_pretrained_from_hf(
        model_id: str,
        filename: Optional[str] = None,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
):
    """Download a file from Hugging Face Hub repo ``model_id`` and return its local path.

    ``filename`` defaults to ``HF_WEIGHTS_NAME``. When safetensors support is available,
    the candidates from ``_get_safe_alternatives(filename)`` are tried first. A failure on
    any of them is ignored and the next candidate, or the original file, is tried.

    Raises:
        RuntimeError: from ``has_hf_hub(True)`` when the hub package is unavailable.
        FileNotFoundError: if ``filename`` itself cannot be downloaded. The underlying
            error is chained as ``__cause__``.
    """
    has_hf_hub(True)

    filename = filename or HF_WEIGHTS_NAME

    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                return hf_hub_download(
                    repo_id=model_id,
                    filename=safe_filename,
                    revision=revision,
                    cache_dir=cache_dir,
                )
            except Exception:
                # Best effort: a missing safetensors variant just means falling back.
                continue

    try:
        return hf_hub_download(
            repo_id=model_id,
            filename=filename,
            revision=revision,
            cache_dir=cache_dir,
        )
    except Exception as e:
        raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") from e
| |
|
| |
|
def download_pretrained(
        cfg: Dict,
        prefer_hf_hub: bool = True,
        cache_dir: Optional[str] = None,
):
    """Resolve a pretrained cfg to a local weights path.

    Resolution order:
    - an empty cfg gives ``''``;
    - a ``'file'`` entry is returned as-is;
    - otherwise ``'hf_hub'`` (``'<repo_id>/<filename>'`` or just a repo id) is used
      when the hub is available and ``prefer_hf_hub`` is set, else ``'url'``;
    - with neither entry, ``''`` is returned.
    """
    if not cfg:
        return ''
    if 'file' in cfg:
        return cfg['file']

    hub_available = has_hf_hub()
    url = cfg.get('url', '')
    hf_hub_path = cfg.get('hf_hub', '')
    if hf_hub_path and prefer_hf_hub and hub_available:
        # The hub location wins; ignore the plain URL.
        url = ''

    if url:
        return download_pretrained_from_url(url, cache_dir=cache_dir)
    if not hf_hub_path:
        return ''

    has_hf_hub(True)
    model_id, filename = os.path.split(hf_hub_path)
    extra = {'filename': filename} if filename else {}
    return download_pretrained_from_hf(model_id, cache_dir=cache_dir, **extra)
| |
|
| | |
def merge_preprocess_dict(
        base: Union[PreprocessCfg, Dict],
        overlay: Dict,
):
    """Overlay ``overlay`` onto a preprocess config and return a plain dict.

    ``base`` may be a ``PreprocessCfg``, which is converted with ``asdict``, or a mapping
    filtered to the keys in ``_PREPROCESS_KEYS``. An ``overlay`` entry is applied only when
    its key is in ``_PREPROCESS_KEYS`` and its value is not ``None``.

    NOTE(review): ``merge_preprocess_dict`` is defined again later in this module; the
    later definition is the one bound after import.
    """
    if isinstance(base, PreprocessCfg):
        merged = asdict(base)
    else:
        merged = {key: value for key, value in base.items() if key in _PREPROCESS_KEYS}
    for key, value in (overlay or {}).items():
        if key in _PREPROCESS_KEYS and value is not None:
            merged[key] = value
    return merged
| |
|
| |
|
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
    """Keyword-argument front end for ``merge_preprocess_dict`` (kwargs form the overlay)."""
    return merge_preprocess_dict(base, overlay=kwargs)
| |
|
| |
|
@dataclass
class PreprocessCfg:
    """Image preprocessing settings: target size, colour mode, normalization and resize policy.

    NOTE(review): ``PreprocessCfg`` is declared again later in this module; the later
    declaration is the one bound after import.
    """
    size: Union[int, Tuple[int, int]] = 224
    mode: str = 'RGB'
    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
    std: Tuple[float, ...] = OPENAI_DATASET_STD
    interpolation: str = 'bicubic'
    resize_mode: str = 'shortest'
    fill_color: int = 0

    def __post_init__(self):
        # Only RGB input is supported.
        assert self.mode in ('RGB',)

    @property
    def num_channels(self):
        """Channel count for the (RGB-only) mode."""
        return 3

    @property
    def input_size(self):
        """Channel count prepended to the 2-tuple form of ``size``."""
        spatial = to_2tuple(self.size)
        return (self.num_channels,) + spatial
| |
|
| |
|
| |
|
| |
|
@dataclass
class PreprocessCfg:
    """Image preprocessing configuration.

    Fields: target ``size`` (int or pair), colour ``mode`` (RGB only), normalization
    ``mean``/``std``, ``interpolation`` name, ``resize_mode`` and padding ``fill_color``.
    """
    size: Union[int, Tuple[int, int]] = 224
    mode: str = 'RGB'
    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
    std: Tuple[float, ...] = OPENAI_DATASET_STD
    interpolation: str = 'bicubic'
    resize_mode: str = 'shortest'
    fill_color: int = 0

    def __post_init__(self):
        # Reject anything but RGB; other modes are not handled downstream.
        assert self.mode in ('RGB',)

    @property
    def num_channels(self):
        """Number of image channels (fixed at 3 for RGB)."""
        channels = 3
        return channels

    @property
    def input_size(self):
        """``(num_channels,)`` followed by ``to_2tuple(size)``."""
        return (self.num_channels,) + to_2tuple(self.size)
| |
|
# Field names of PreprocessCfg; merge_preprocess_dict uses them to filter incoming dicts.
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()))
| |
|
| |
|
def merge_preprocess_dict(
        base: Union[PreprocessCfg, Dict],
        overlay: Dict,
):
    """Merge ``overlay`` on top of a preprocess config, returning a new dict.

    Both inputs are restricted to ``PreprocessCfg`` field names (``_PREPROCESS_KEYS``).
    ``None`` values in ``overlay`` are ignored, so they never clear a base value.
    """
    if isinstance(base, PreprocessCfg):
        base_clean = asdict(base)
    else:
        base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
    if not overlay:
        return base_clean
    overlay_clean = {
        k: v for k, v in overlay.items()
        if k in _PREPROCESS_KEYS and v is not None
    }
    return {**base_clean, **overlay_clean}
| |
|
| |
|
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
    """Like ``merge_preprocess_dict`` but the overlay is given as keyword arguments."""
    overlay = dict(kwargs)
    return merge_preprocess_dict(base, overlay)
| |
|
| |
|
@dataclass
class AugmentationCfg:
    """Training-time augmentation options consumed by ``image_transform``.

    ``image_transform`` drops fields that are ``None`` before use. With ``use_timm``, the
    remaining fields (minus the two ``*_prob`` options) are forwarded to timm. Otherwise
    ``scale``, ``color_jitter``, ``color_jitter_prob`` and ``gray_scale_prob`` drive the
    built-in pipeline.

    NOTE(review): ``AugmentationCfg`` is declared again later in this module; the later
    declaration is the one bound after import.
    """
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    use_timm: bool = False

    # Probabilities for the non-timm colour-jitter / grayscale transforms (None = disabled).
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
| |
|
| |
|
def _setup_size(size, error_msg):
    """Normalize ``size`` into an ``(h, w)`` pair.

    - A number becomes ``(int(size), int(size))``.
    - A one-element sequence is duplicated.
    - A two-element value is returned unchanged.
    - Any other length raises ``ValueError(error_msg)``.
    """
    if isinstance(size, numbers.Number):
        side = int(size)
        return side, side
    if isinstance(size, Sequence) and len(size) == 1:
        first = size[0]
        return first, first
    if len(size) == 2:
        return size
    raise ValueError(error_msg)
| |
|
| |
|
class ResizeKeepRatio:
    """ Resize and Keep Ratio

    Copy & paste from `timm`

    Resizes an image while preserving its aspect ratio.
    - ``longest=1`` scales by the larger of the two side ratios, so the image fits inside ``size``.
    - ``longest=0`` scales by the smaller, so the image covers ``size``.
    - Values in between interpolate.
    - Optional random scale and aspect jitter are applied with the given probabilities.
    """

    def __init__(
            self,
            size,
            longest=0.,
            interpolation=InterpolationMode.BICUBIC,
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11)
    ):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation = interpolation
        self.longest = float(longest)
        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range

    @staticmethod
    def get_params(
            img,
            target_size,
            longest,
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11)
    ):
        """Compute the output size ``[new_h, new_w]`` for ``img``.

        Assumes ``img.size`` is ``(width, height)`` as for PIL images; it is reversed here
        to ``(h, w)``. TODO confirm for other inputs.
        """
        source_size = img.size[::-1]
        h, w = source_size
        target_h, target_w = target_size
        ratio_h = h / target_h
        ratio_w = w / target_w
        # Blend between fitting the larger ratio (longest=1) and the smaller one (longest=0).
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1., 1.)
        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
        size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
        return size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be resized.

        Returns:
            PIL Image: Resized image (aspect ratio kept, up to optional random jitter).
        """
        # Use torchvision's functional API explicitly: the module-level `F` is bound to
        # torch.nn.functional, which has no `resize`.
        from torchvision.transforms import functional as TF

        size = self.get_params(
            img, self.size, self.longest,
            self.random_scale_prob, self.random_scale_range,
            self.random_aspect_prob, self.random_aspect_range
        )
        return TF.resize(img, size, self.interpolation)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(size={self.size}, '
            f'interpolation={self.interpolation}, longest={self.longest:.3f})'
        )
| |
|
| |
|
def center_crop_or_pad(
        img: torch.Tensor,
        output_size: Union[int, List[int], Tuple[int, ...]],
        fill=0,
) -> torch.Tensor:
    """Center crops and/or pads the given image.

    If the image is torch Tensor, it is expected to have [..., H, W] shape, where ...
    means an arbitrary number of leading dimensions. If the image is smaller than the
    output size along any edge, it is first padded with ``fill`` (left/top get the floor
    of half the deficit, right/bottom the ceiling) and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence
            with single int, it is used for both directions.
        fill (int, Tuple[int]): Padding color.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    import numbers
    # torchvision's functional API (get_dimensions / pad(fill=...) / crop); the module-level
    # `F` is torch.nn.functional, which provides none of these with this signature.
    from torchvision.transforms import functional as TF

    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = TF.get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = TF.pad(img, padding_ltrb, fill=fill)
        _, image_height, image_width = TF.get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return TF.crop(img, crop_top, crop_left, crop_height, crop_width)
| |
|
| |
|
class CenterCropOrPad(torch.nn.Module):
    """Module wrapper around ``center_crop_or_pad``.

    Crops the image around its center to ``size``. When the image is smaller than
    ``size`` along an edge, it is first padded with ``fill``.

    Args:
        size (sequence or int): Desired output size ``(h, w)``. An int gives a square crop;
            a one-element sequence is interpreted as ``(size[0], size[0])``.
        fill: Padding value forwarded to ``center_crop_or_pad``.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.fill = fill

    def forward(self, img):
        """Return ``img`` center-cropped (padding first if needed) to ``self.size``."""
        return center_crop_or_pad(img, self.size, fill=self.fill)

    def __repr__(self) -> str:
        return "{}(size={})".format(type(self).__name__, self.size)
| |
|
| |
|
def _convert_to_rgb(image):
    """Return ``image.convert('RGB')`` (e.g. a PIL image forced to three channels)."""
    target_mode = 'RGB'
    return image.convert(target_mode)
| |
|
| |
|
class color_jitter:
    """Randomly apply torchvision ``ColorJitter`` to an image.

    With probability ``p`` the wrapped ``ColorJitter`` is applied; otherwise the image
    is returned untouched.
    """
    def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
        assert 0. <= p <= 1.
        self.p = p
        self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

    def __call__(self, img):
        apply_jitter = random.random() < self.p
        return self.transf(img) if apply_jitter else img
| |
|
| |
|
class gray_scale:
    """Randomly convert an image to 3-channel grayscale.

    With probability ``p`` the wrapped ``Grayscale(num_output_channels=3)`` is applied;
    otherwise the image is returned untouched.
    """
    def __init__(self, p=0.2):
        assert 0. <= p <= 1.
        self.p = p
        self.transf = Grayscale(num_output_channels=3)

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return self.transf(img)
| |
|
| |
|
def image_transform(
        image_size: Union[int, Tuple[int, int]],
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_mode: Optional[str] = None,
        interpolation: Optional[str] = None,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    """Build the image preprocessing transform for training or evaluation.

    Args:
        image_size: Target size, an int or a two-element sequence.
        is_train: Build the training pipeline (timm or random-resized-crop based)
            instead of the deterministic evaluation pipeline.
        mean: Normalization mean, default ``OPENAI_DATASET_MEAN``. A scalar is repeated
            for three channels.
        std: Normalization std, default ``OPENAI_DATASET_STD``. A scalar is repeated for
            three channels.
        resize_mode: Evaluation resize policy: 'shortest' (default), 'longest' or 'squash'.
        interpolation: 'bicubic' (default), 'bilinear' or 'random'. 'random' is only
            forwarded to timm; other paths treat it as bicubic. The non-timm training
            crop always uses bicubic.
        fill_color: Padding fill for the 'longest' evaluation path.
        aug_cfg: ``AugmentationCfg`` or a dict of its fields; only used when ``is_train``.

    Returns:
        The timm transform when training with ``use_timm``, otherwise a ``Compose``.
    """
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    interpolation = interpolation or 'bicubic'
    assert interpolation in ['bicubic', 'bilinear', 'random']
    # 'random' only has meaning for timm; the torchvision paths fall back to bicubic.
    interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC

    resize_mode = resize_mode or 'shortest'
    assert resize_mode in ('shortest', 'longest', 'squash')

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        # Only explicitly-set options; keys left here at the end are reported as unused.
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            from timm.data import create_transform
            if isinstance(image_size, (tuple, list)):
                assert len(image_size) >= 2
                # tuple() so a list image_size can be concatenated onto (3,).
                input_size = (3,) + tuple(image_size[-2:])
            else:
                input_size = (3, image_size, image_size)

            aug_cfg_dict.setdefault('color_jitter', None)
            # These options belong to the non-timm pipeline below; keep them away from timm.
            aug_cfg_dict.pop('color_jitter_prob', None)
            aug_cfg_dict.pop('gray_scale_prob', None)

            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                interpolation=interpolation,
                **aug_cfg_dict,
            )
        else:
            train_transform = [
                RandomResizedCrop(
                    image_size,
                    scale=aug_cfg_dict.pop('scale'),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
            ]
            if aug_cfg.color_jitter_prob:
                assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
                train_transform.append(
                    color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
                )
                # Consumed above: keep them out of the "unused" warning.
                aug_cfg_dict.pop('color_jitter', None)
                aug_cfg_dict.pop('color_jitter_prob', None)
            if aug_cfg.gray_scale_prob:
                train_transform.append(gray_scale(aug_cfg.gray_scale_prob))
                aug_cfg_dict.pop('gray_scale_prob', None)
            train_transform.extend([
                ToTensor(),
                normalize,
            ])
            train_transform = Compose(train_transform)
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform

    if resize_mode == 'longest':
        transforms = [
            ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
            CenterCropOrPad(image_size, fill=fill_color)
        ]
    elif resize_mode == 'squash':
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        transforms = [
            Resize(image_size, interpolation=interpolation_mode),
        ]
    else:
        assert resize_mode == 'shortest'
        if not isinstance(image_size, (tuple, list)):
            image_size = (image_size, image_size)
        if image_size[0] == image_size[1]:
            # Square target: Resize with an int target.
            transforms = [
                Resize(image_size[0], interpolation=interpolation_mode)
            ]
        else:
            # Non-square target: ResizeKeepRatio with longest=0 scales by the smaller side
            # ratio so the image covers the target before the center crop.
            transforms = [ResizeKeepRatio(image_size, interpolation=interpolation_mode)]
        transforms += [CenterCrop(image_size)]

    transforms.extend([
        _convert_to_rgb,
        ToTensor(),
        normalize,
    ])
    return Compose(transforms)
| | |
| | |
def image_transform_v2(
        cfg: PreprocessCfg,
        is_train: bool,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    """Build an image transform from a ``PreprocessCfg`` by unpacking it into ``image_transform``."""
    cfg_kwargs = dict(
        image_size=cfg.size,
        mean=cfg.mean,
        std=cfg.std,
        interpolation=cfg.interpolation,
        resize_mode=cfg.resize_mode,
        fill_color=cfg.fill_color,
    )
    return image_transform(is_train=is_train, aug_cfg=aug_cfg, **cfg_kwargs)
| |
|
@dataclass
class AugmentationCfg:
    """Training-time augmentation options consumed by ``image_transform``.

    ``image_transform`` drops fields that are ``None`` before use. With ``use_timm``, the
    remaining fields (minus the two ``*_prob`` options) are forwarded to timm. Otherwise
    ``scale``, ``color_jitter``, ``color_jitter_prob`` and ``gray_scale_prob`` drive the
    built-in pipeline.
    """
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    use_timm: bool = False

    # Probabilities for the non-timm colour-jitter / grayscale transforms (None = disabled).
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
| |
|
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Attach preprocessing metadata to ``model.visual`` (or ``model`` itself if it has no ``visual``).

    Sets ``image_mean`` / ``image_std`` from the cfg's ``'mean'`` / ``'std'`` and stores a
    deep copy of the whole cfg as ``preprocess_cfg``.
    """
    target = getattr(model, 'visual', model)
    for attr_name, cfg_key in (('image_mean', 'mean'), ('image_std', 'std')):
        setattr(target, attr_name, preprocess_cfg[cfg_key])
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
| |
|
| |
|
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
    """Re-key a MobileCLIP-style checkpoint into this model's parameter names.

    The output holds only three kinds of entries; any other input key is dropped:
    - image weights passed through timm's ``checkpoint_filter_fn`` for
      ``model.visual.trunk`` and prefixed ``visual.trunk.``;
    - ``text_encoder.*`` weights renamed to ``text.*`` keys;
    - ``logit_scale``.

    Args:
        model: target model; only ``model.visual.trunk`` is read (by the timm filter).
        state_dict: source checkpoint mapping. It must contain ``'logit_scale'``.
        fastvit: pick timm's FastViT filter (True) or the hybrid ViT filter (False).
    """

    def _convert_timm_img(state_dict):
        # Delegate image-tower re-keying to timm's own checkpoint filter for the trunk type.
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        # Rename text-encoder keys to the open_clip text-tower layout. The replacements
        # are order-dependent; keep the sequence as-is.
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            # NOTE: str.replace removes every occurrence of the prefix, not only the leading one.
            k = k.replace(prefix, '')
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                # Drop size-1 dimensions from the positional embedding tensor.
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    # On key collisions, text entries win over image entries.
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict
| |
|
| |
|
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """Convert foreign checkpoint layouts into this model's key layout.

    A MobileCLIP checkpoint is recognised by a marker key:
    - ``...patch_embed.0.rbr_conv...`` selects the FastViT image trunk;
    - ``...patch_emb.0.block...`` selects the hybrid ViT trunk.
    Any other state dict is returned unchanged.
    """
    layout_markers = (
        ('image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight', True),
        ('image_encoder.model.patch_emb.0.block.conv.weight', False),
    )
    for marker_key, is_fastvit in layout_markers:
        if marker_key in state_dict:
            state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=is_fastvit)
    return state_dict
| |
|
def load_state_dict(
        checkpoint_path: str,
        device='cpu',
        weights_only=True,
):
    """Load a state dict from ``checkpoint_path``.

    - ``.safetensors`` files are read with ``safetensors.torch.load_file``; anything
      else goes through ``torch.load``.
    - A dict checkpoint with a ``'state_dict'`` entry is unwrapped.
    - A TorchScript module contributes its ``state_dict()``, minus the
      ``input_resolution`` / ``context_length`` / ``vocab_size`` entries.
    - A ``'module.'`` prefix (as written by DataParallel/DDP wrappers) is stripped from keys.

    Args:
        checkpoint_path: path to the checkpoint file.
        device: ``map_location`` / device for the loaded tensors.
        weights_only: forwarded to ``torch.load`` when that version supports it.

    Returns:
        The (possibly re-keyed) state dict; an empty checkpoint yields an empty dict.
    """
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        ckpt = load_file(checkpoint_path, device=device)
    else:
        try:
            ckpt = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except TypeError:
            # Presumably a torch.load without the `weights_only` keyword; retry without it.
            ckpt = torch.load(checkpoint_path, map_location=device)

    if isinstance(ckpt, dict) and 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    elif isinstance(ckpt, torch.jit.ScriptModule):
        state_dict = ckpt.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = ckpt

    prefix = 'module.'
    # next(..., None) keeps an empty state dict from raising StopIteration.
    first_key = next(iter(state_dict), None)
    if isinstance(first_key, str) and first_key.startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
| |
|
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
        weights_only: bool = True,
        device='cpu',
):
    """Load weights from ``checkpoint_path`` into ``model``.

    ``.npz`` / ``.npy`` files are delegated to ``load_big_vision_weights``. Other files go
    through ``load_state_dict`` and are then adapted to the model:
    - layout conversion;
    - custom-text re-keying;
    - logit scale/bias shape fixes;
    - position-id removal;
    - positional-embedding resizing.
    The result is loaded with ``model.load_state_dict``.

    Returns:
        The result of ``model.load_state_dict`` (missing / unexpected keys), or ``{}`` for
        big_vision checkpoints.
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Re-key checkpoints stored in a foreign layout.
    state_dict = convert_state_dict(model, state_dict)

    # Checkpoint has a top-level positional_embedding but the model does not: convert the
    # keys with convert_to_custom_text_state_dict.
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # Reconcile scalar vs 1-d logit_scale between checkpoint and model.
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # The model may have no logit_bias (None or absent); only reshape when it does.
    model_logit_bias = getattr(model, 'logit_bias', None)
    if (
            'logit_bias' in state_dict
            and model_logit_bias is not None
            and model_logit_bias.ndim != state_dict['logit_bias'].ndim
    ):
        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model_logit_bias.shape)

    # The model expects a logit_bias the checkpoint lacks: start it at zero.
    if 'logit_bias' not in state_dict and model_logit_bias is not None:
        reference = state_dict['logit_scale'] if 'logit_scale' in state_dict else model_logit_bias.detach()
        state_dict["logit_bias"] = torch.zeros_like(reference)

    # NOTE(review): hasattr() with a dotted name looks for one attribute literally named
    # 'text.transformer.embeddings.position_ids', which a module will not normally have.
    # So this effectively always drops the key when present; confirm that is intended.
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
| |
|
| | |
HF_HUB_PREFIX = 'hf-hub:'
# Builtin model configs live next to this file. `__file__` is undefined when the code
# runs interactively (e.g. pasted into a notebook), so fall back to the working directory.
try:
    _MODULE_DIR = Path(__file__).parent
except NameError:
    _MODULE_DIR = Path.cwd()
_MODEL_CONFIG_PATHS = [_MODULE_DIR / "model_configs"]
_MODEL_CONFIGS = {}
| |
|
| | import json |
| |
|
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Download ``open_clip_config.json`` from Hub repo ``model_id`` and return it parsed."""
    local_path = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    with open(local_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
| |
|
def get_model_config(model_name):
    """Return a deep copy of the builtin config for ``model_name``, or ``None`` if unknown."""
    if model_name not in _MODEL_CONFIGS:
        return None
    return copy.deepcopy(_MODEL_CONFIGS[model_name])
| |
|
def _natural_key(string_):
    """Case-insensitive natural sort key: digit runs compare as ints, so 'm9' < 'm10'."""
    key = []
    for piece in re.split(r'(\d+)', string_.lower()):
        key.append(int(piece) if piece.isdigit() else piece)
    return key
| |
|
| |
|
def _rescan_model_configs():
    """(Re)load builtin model configs from ``_MODEL_CONFIG_PATHS`` into ``_MODEL_CONFIGS``.

    Each path may be a ``.json`` file or a directory, scanned non-recursively for
    ``*.json``. Only configs containing ``embed_dim``, ``vision_cfg`` and ``text_cfg`` are
    registered, keyed by file stem. Existing entries are kept or overwritten, not cleared.
    The mapping is finally re-sorted in natural order.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        # Explicit UTF-8 (as in _get_hf_config) so parsing does not depend on the platform locale.
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
| |
|
| |
|
# Populate _MODEL_CONFIGS from the builtin config paths once, at import time.
_rescan_model_configs()
| |
|
def list_models():
    """Return the registered model architecture names (config file stems), in ``_MODEL_CONFIGS`` order."""
    return [name for name in _MODEL_CONFIGS]
| |
|
| |
|
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble model inputs for one generation step.

    With ``past`` set, only the last token of ``input_ids`` is kept.

    When an ``attention_mask`` is given without ``position_ids``, position ids are
    derived from it: the cumulative count of attended tokens minus one, with masked
    slots set to 1. Caller-supplied ``position_ids`` are passed through unchanged.

    Returns:
        dict with keys ``text``, ``images``, ``past_key_values``, ``position_ids`` and
        ``attention_mask``.
    """
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # Build position ids on the fly for batched generation.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
| |
|
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Config for the CoCa multimodal text decoder: ``CLIPTextCfg`` plus the fields below.

    ``_build_text_decoder_tower`` reads ``context_length``, ``width``, ``heads``, ``layers``
    and ``ls_init_value`` from it. NOTE(review): ``mlp_ratio``, ``dim_head``, ``n_queries``
    and ``attn_pooler_heads`` are not read in the visible code; their consumers are elsewhere.
    """
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8  # NOTE(review): presumably overrides a CLIPTextCfg default; TODO confirm
    n_queries: int = 256
    attn_pooler_heads: int = 8
| |
|
# Optional Hugging Face `transformers` generation utilities. If the import fails (package
# missing, or a version lacking any of these names), the sampling entries of
# GENERATION_TYPES are None and _has_transformers is False.
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )
except ImportError:
    _has_transformers = False
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
else:
    _has_transformers = True
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
| |
|
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    """Coerce a token id into a tensor.

    A tensor is returned as-is. A bare int becomes a 1-element tensor. Other values
    (e.g. lists) are passed to ``torch.tensor`` on ``device``.
    """
    if isinstance(token_id, torch.Tensor):
        return token_id
    values = [token_id] if isinstance(token_id, int) else token_id
    return torch.tensor(values, device=device)
| |
|
| |
|
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the CoCa ``MultimodalTransformer`` text decoder.

    ``multimodal_cfg`` may be a ``MultimodalCfg`` or a dict of its fields. ``embed_dim``
    becomes the decoder's ``output_dim``. ``QuickGELU`` is used when ``quick_gelu`` is
    set, else ``nn.GELU``. ``LayerNormFp32`` is used for float16/bfloat16 ``cast_dtype``,
    else ``LayerNorm``.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)
    act_layer = QuickGELU if quick_gelu else nn.GELU
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    norm_layer = LayerNormFp32 if low_precision else LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
| |
|
| | class CoCa(nn.Module): |
| | def __init__( |
| | self, |
| | embed_dim, |
| | multimodal_cfg: MultimodalCfg, |
| | text_cfg: CLIPTextCfg, |
| | vision_cfg: CLIPVisionCfg, |
| | quick_gelu: bool = False, |
| | init_logit_scale: float = np.log(1 / 0.07), |
| | init_logit_bias: Optional[float] = None, |
| | nonscalar_logit_scale: bool = False, |
| | cast_dtype: Optional[torch.dtype] = None, |
| | pad_id: int = 0, |
| | ): |
| | super().__init__() |
| | multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg |
| | text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg |
| | vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg |
| |
|
| | self.text = _build_text_tower( |
| | embed_dim=embed_dim, |
| | text_cfg=text_cfg, |
| | quick_gelu=quick_gelu, |
| | cast_dtype=cast_dtype, |
| | ) |
| |
|
| | vocab_size = ( |
| | text_cfg.vocab_size |
| | if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None |
| | else text_cfg.vocab_size |
| | ) |
| |
|
| | self.visual = _build_vision_tower( |
| | embed_dim=embed_dim, |
| | vision_cfg=vision_cfg, |
| | quick_gelu=quick_gelu, |
| | cast_dtype=cast_dtype, |
| | ) |
| |
|
| | self.text_decoder = _build_text_decoder_tower( |
| | vocab_size, |
| | multimodal_cfg=multimodal_cfg, |
| | quick_gelu=quick_gelu, |
| | cast_dtype=cast_dtype, |
| | ) |
| |
|
| | lshape = [1] if nonscalar_logit_scale else [] |
| | self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) |
| | if init_logit_bias is not None: |
| | self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) |
| | else: |
| | self.logit_bias = None |
| | self.pad_id = pad_id |
| |
|
| | self.context_length = multimodal_cfg.context_length |
| |
|
| | @torch.jit.ignore |
| | def set_grad_checkpointing(self, enable: bool = True): |
| | self.visual.set_grad_checkpointing(enable) |
| | self.text.set_grad_checkpointing(enable) |
| | self.text_decoder.set_grad_checkpointing(enable) |
| |
|
| | def _encode_image(self, images, normalize: bool = True): |
| | image_latent, tokens_embs = self.visual(images) |
| | image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent |
| | return image_latent, tokens_embs |
| |
|
| | def _encode_text(self, text, normalize: bool = True): |
| | text_latent, token_emb = self.text(text) |
| | text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent |
| | return text_latent, token_emb |
| |
|
def encode_image(self, images, normalize: bool = True):
    """Return only the image latent for ``images``.

    Thin public wrapper over ``_encode_image``; the per-token embeddings it
    also produces are discarded.
    """
    latent, _token_embs = self._encode_image(images, normalize=normalize)
    return latent
| |
|
def encode_text(self, text, normalize: bool = True):
    """Return only the text latent for ``text``.

    Thin public wrapper over ``_encode_text``; the per-token embeddings it
    also produces are discarded.
    """
    latent, _token_embs = self._encode_text(text, normalize=normalize)
    return latent
| |
|
def forward_intermediates(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        image_indices: Optional[Union[int, List[int]]] = None,
        text_indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize: bool = True,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        image_output_fmt: str = 'NCHW',
        image_output_extra_tokens: bool = False,
        text_output_fmt: str = 'NLC',
        text_output_extra_tokens: bool = False,
        output_logits: bool = False,
        output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Args:
        image: Input image tensor
        text: Input text tensor
        image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
        text_indices: Take last n blocks if int, all if None, select matching indices if sequence
        stop_early: Stop iterating over blocks when last desired intermediate hit
        normalize: L2 Normalize final image and text features (if present)
        normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
        intermediates_only: Only return intermediate features, do not return final features
        image_output_fmt: Shape of intermediate image feature outputs
        image_output_extra_tokens: Return both prefix and spatial intermediate tokens
        text_output_fmt: Shape of intermediate text feature outputs
        text_output_extra_tokens: Return both prefix and spatial intermediate tokens
        output_logits: Include logits in output (not implemented yet)
        output_logit_scale_bias: Include the logit scale bias in the output

    Returns:
        A dict merging the result of ``self.visual.forward_intermediates`` (when
        ``image`` is given) and ``self.text.forward_intermediates`` (when ``text``
        is given). If present, ``"image_features"`` / ``"text_features"`` are
        L2-normalized over the last dim when ``normalize`` is True. When
        ``output_logit_scale_bias`` is set, ``"logit_scale"`` (the exponentiated
        learned scale) is added, plus ``"logit_bias"`` if the model has one.

    Raises:
        NotImplementedError: If ``output_logits`` is requested while
            ``intermediates_only`` is False.
    """
    output = {}
    if intermediates_only:
        # No final features are produced in this mode, so there is nothing to
        # normalize and nothing to compute logits from.
        normalize = False
        output_logits = False
    if output_logits:
        # Raise explicitly instead of `assert False`: asserts are stripped under
        # `python -O`, which would silently return a dict without logits.
        raise NotImplementedError('FIXME, needs implementing')

    if image is not None:
        image_output = self.visual.forward_intermediates(
            image,
            indices=image_indices,
            stop_early=stop_early,
            normalize_intermediates=normalize_intermediates,
            intermediates_only=intermediates_only,
            output_fmt=image_output_fmt,
            output_extra_tokens=image_output_extra_tokens,
        )
        if normalize and "image_features" in image_output:
            image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
        output.update(image_output)

    if text is not None:
        text_output = self.text.forward_intermediates(
            text,
            indices=text_indices,
            stop_early=stop_early,
            normalize_intermediates=normalize_intermediates,
            intermediates_only=intermediates_only,
            output_fmt=text_output_fmt,
            output_extra_tokens=text_output_extra_tokens,
        )
        if normalize and "text_features" in text_output:
            text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
        output.update(text_output)

    if output_logit_scale_bias:
        output["logit_scale"] = self.logit_scale.exp()
        if self.logit_bias is not None:
            output['logit_bias'] = self.logit_bias

    return output
| |
|
def forward(
        self,
        image,
        text: Optional[torch.Tensor] = None,
        image_latent: Optional[torch.Tensor] = None,
        image_embs: Optional[torch.Tensor] = None,
        output_labels: bool = True,
):
    """Encode image (and optionally text) and run the text decoder.

    Args:
        image: Images to encode; skipped when both ``image_latent`` and
            ``image_embs`` are supplied (precomputed, e.g. during generation).
        text: Token ids. When None, only the image outputs are returned.
        image_latent: Optional precomputed image latent.
        image_embs: Optional precomputed image token embeddings.
        output_labels: When True, ``text[:, 1:]`` is returned as ``"labels"``
            and the last token embedding is dropped before decoding, so decoder
            positions line up with the shifted labels.

    Returns:
        ``{"image_features", "image_embs"}`` when ``text`` is None; otherwise
        ``image_features``, ``text_features``, ``logits``, ``logit_scale``
        (exponentiated), plus ``labels`` and ``logit_bias`` when applicable.
    """
    if image_latent is None or image_embs is None:
        image_latent, image_embs = self._encode_image(image)

    if text is None:
        return {"image_features": image_latent, "image_embs": image_embs}

    text_latent, token_embs = self._encode_text(text)

    labels = None
    if output_labels:
        labels = text[:, 1:]
        token_embs = token_embs[:, :-1]

    result = dict(
        image_features=image_latent,
        text_features=text_latent,
        logits=self.text_decoder(image_embs, token_embs),
        logit_scale=self.logit_scale.exp(),
    )
    if labels is not None:
        result["labels"] = labels
    if self.logit_bias is not None:
        result["logit_bias"] = self.logit_bias
    return result
| |
|
def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False
):
    """Generate token ids conditioned on ``image``.

    All work runs under ``torch.no_grad()`` with the module switched to eval
    mode for every forward pass (image encoding included). The caller's
    train/eval mode is restored afterwards, also when an exception is raised.

    Args:
        image: Image batch fed to the vision tower; its device is used for all
            tensors created here.
        text: Optional prompt token ids of shape ``(B, L)`` or ``(L,)``; defaults
            to one start-of-text token per image. Only used by the ``"top_p"`` /
            ``"top_k"`` sampling paths.
        seq_len: Passed to ``MaxLengthCriteria`` when ``stopping_criteria`` is
            None. The sampling paths also force EOS as the token that brings the
            (windowed) length to ``seq_len``.
        max_seq_len: Sampling paths feed only the last ``max_seq_len`` tokens to
            the model at each step.
        temperature: Divides the warped logits before the softmax (sampling).
        generation_type: ``"beam_search"``, ``"top_p"`` or ``"top_k"``.
        top_p, top_k: Argument for the matching ``GENERATION_TYPES`` warper.
        pad_token_id: Defaults to ``self.pad_id``.
        eos_token_id: Defaults to 49407.
        sot_token_id: Defaults to 49406.
        num_beams, num_beam_groups: Forwarded to ``_generate_beamsearch``.
        min_seq_len: Passed to ``MinLengthLogitsProcessor``; must be smaller
            than ``seq_len``.
        stopping_criteria: Optional list of criteria, wrapped in
            ``StoppingCriteriaList``.
        repetition_penalty: Passed to ``RepetitionPenaltyLogitsProcessor``.
        fixed_output_length: Beam-search output is right-padded with
            ``pad_token_id`` up to ``seq_len``; the sampling paths keep appending
            pad tokens after every sequence finished until the stopping criteria
            fire.

    Returns:
        Tensor of token ids. On the sampling paths it is 1-D when a 1-D ``text``
        was given, otherwise batched.

    Raises:
        ValueError: If ``generation_type`` is not supported.
    """
    assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
    assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
    device = image.device

    with torch.no_grad():
        sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
        eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
        pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
        logit_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                RepetitionPenaltyLogitsProcessor(repetition_penalty),
            ]
        )

        if stopping_criteria is None:
            stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
        stopping_criteria = StoppingCriteriaList(stopping_criteria)

        # Validate the generation type before touching the module's train/eval state.
        if generation_type == "beam_search":
            logit_warper = None
        elif generation_type == "top_p":
            logit_warper = GENERATION_TYPES[generation_type](top_p)
        elif generation_type == "top_k":
            logit_warper = GENERATION_TYPES[generation_type](top_k)
        else:
            raise ValueError(
                f"generation_type has to be one of "
                f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
            )

        # Switch to eval BEFORE any forward pass (previously the image was encoded
        # while still in training mode, and beam search never switched at all),
        # and restore the caller's mode even if generation raises.
        was_training = self.training
        self.eval()
        try:
            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    output = torch.cat(
                        (
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id,
                        ),
                        dim=1,
                    )
                return output

            # Image features are computed once and reused at every decoding step.
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            num_dims = len(text.shape)
            if num_dims == 1:
                text = text[None, :]

            out = text
            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # Rows whose last token is EOS or pad are finished; they only receive pad.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    active = ~mask
                    logits = logits[active, :]
                    filtered_logits = logit_processor(x[active, :], logits)
                    filtered_logits = logit_warper(x[active, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if cur_len + 1 == seq_len:
                        # Last allowed position: force every unfinished row to end with EOS.
                        num_active = int(active.sum())
                        sample[active, :] = torch.ones((num_active, 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[active, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)
            return out
        finally:
            self.train(was_training)
| |
|
def _generate_beamsearch(
        self,
        image_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,
):
    """Group beam search over the text decoder, conditioned on ``image_inputs``.

    Beams are split into ``num_beam_groups`` groups of ``num_beams // num_beam_groups``
    beams each; every group is scored and advanced separately by
    ``BeamSearchScorer`` (``group_index=beam_group_idx``).

    Args:
        image_inputs: Image batch; each image is repeated ``num_beams`` times.
        pad_token_id, eos_token_id: Forwarded to the beam scorer.
        sot_token_id: Token every beam starts from.
        num_beams, num_beam_groups: Beam-search configuration.
        min_seq_len: Used only when ``logit_processor`` is None, to build a
            default ``MinLengthLogitsProcessor``.
        stopping_criteria: A ``StoppingCriteriaList``; its ``max_length`` is
            passed to ``finalize`` (assumes it exposes one, e.g. via a
            ``MaxLengthCriteria`` — TODO confirm for custom criteria).
        logit_processor: Processor list applied to each group's next-token logits.
        logit_warper: NOTE(review): accepted but never used in this method.

    Returns:
        The ``'sequences'`` tensor produced by ``BeamSearchScorer.finalize``.
    """
    device = image_inputs.device
    batch_size = image_inputs.shape[0]
    # One row per (image, beam): images are repeated num_beams times consecutively.
    image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
    # Image features are computed once and reused for every decoding step.
    image_latent, image_embs = self._encode_image(image_inputs)

    # Every beam starts from a single start-of-text token.
    input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
    input_ids = input_ids * sot_token_id
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=device,
        num_beam_groups=num_beam_groups,
    )
    # Fall back to a min-length-only processor when the caller supplies none.
    logits_processor = (
        LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
        if logit_processor is None
        else logit_processor
    )

    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups
    # Recover batch size from the scorer's hypotheses (assumed one per batch item per group).
    batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
    batch_beam_size, cur_len = input_ids.shape
    # NOTE(review): beam_indices stays None for the whole method, so the
    # `sum(beam_indices, ())` expressions below always evaluate to None.
    beam_indices = None

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
    # All beams start from the same token. Only the first beam of each group
    # starts at score 0; the rest start at -1e9 so duplicates of that identical
    # start are not selected on the first step.
    beam_scores[:, ::num_sub_beams] = 0
    beam_scores = beam_scores.view((batch_size * num_beams,))

    while True:

        # Token chosen for every row at this step, filled group by group.
        current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

        # NOTE(review): filled in below but never read in this method.
        reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

        # One decoder pass over all beams of all batch items, shared by every group.
        model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
        outputs = self(
            model_inputs['images'],
            model_inputs['text'],
            image_latent=image_latent,
            image_embs=image_embs,
            output_labels=False,
        )

        for beam_group_idx in range(num_beam_groups):
            group_start_idx = beam_group_idx * num_sub_beams
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            group_size = group_end_idx - group_start_idx

            # Global row indices of this group's beams, for every batch item.
            batch_group_indices = []

            for batch_idx in range(batch_size):
                batch_group_indices.extend(
                    [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                )
            group_input_ids = input_ids[batch_group_indices]

            # Logits for the next token (last position) of this group's rows.
            next_token_logits = outputs['logits'][batch_group_indices, -1, :]
            vocab_size = next_token_logits.shape[-1]

            # Extra kwargs are presumably consumed by group-aware (diversity) processors.
            next_token_scores_processed = logits_processor(
                group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
            )
            # Accumulate each beam's running score onto its candidate token scores.
            next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
            # NOTE(review): the addition above already broadcasts to this shape; effectively a no-op.
            next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

            # Flatten (beam, token) candidates per batch item so topk ranks across beams.
            next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

            # Keep 2x group_size candidates so enough survive if some end with EOS.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
            )

            # Split flat candidate ids back into (source beam within group, token id).
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # Always None here (see note on beam_indices above).
            process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=process_beam_indices,
                group_index=beam_group_idx,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            # Reorder this group's histories to their selected parent beams, then
            # record the token each surviving beam appends at this step.
            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            current_tokens[batch_group_indices] = group_input_ids[:, -1]

            # Map group-local beam_idx (batch * group_size + j) to its global row:
            # batch offset + group offset + position within the group.
            reordering_indices[batch_group_indices] = (
                num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
            )

        # Append this step's tokens for all groups at once.
        input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

        # NOTE(review): cur_len is incremented but not otherwise read.
        cur_len = cur_len + 1
        if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
            break

    # Always None here (see note on beam_indices above).
    final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=final_beam_indices,
    )
    return sequence_outputs['sequences']
| |
|
| |
|
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        load_weights_only: bool = True,
        **model_kwargs,
):
    """Creates and configures a contrastive vision-language model.

    Args:
        model_name: Name of the model architecture to create. Can be a local model name
                   or a Hugging Face model ID prefixed with 'hf-hub:'.
        pretrained: Tag/path for pretrained model weights. Can be:
                   - A pretrained tag name (e.g., 'openai')
                   - A path to local weights
                   - None to initialize with random weights
        precision: Model precision/AMP configuration. Options:
                  - 'fp32': 32-bit floating point
                  - 'fp16'/'bf16': Mixed precision with FP32 for certain layers
                  - 'pure_fp16'/'pure_bf16': Pure 16-bit precision
        device: Device to load the model on ('cpu', 'cuda', or torch.device object)
        jit: If True, JIT compile the model
        force_quick_gelu: Force use of QuickGELU activation
        force_custom_text: Force use of custom text encoder
        force_patch_dropout: Override default patch dropout value
        force_image_size: Override default image size for vision encoder
        force_preprocess_cfg: Override default preprocessing configuration. The
                  dict is copied, never modified in place.
        pretrained_image: Load pretrained weights for timm vision models
        pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
        cache_dir: Override default cache directory for downloaded model files
        output_dict: If True and model supports it, return dictionary of features
        require_pretrained: Raise error if pretrained weights cannot be loaded
        load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
        **model_kwargs: Additional keyword arguments passed to model constructor

    Returns:
        Created and configured model instance

    Raises:
        RuntimeError: If model config is not found or required pretrained weights
                     cannot be loaded
        ValueError: If ``pretrained_image`` is requested for a non-timm vision tower

    Examples:
        # Create basic CLIP model
        model = create_model('ViT-B/32')

        # Create CLIP model with mixed precision on GPU
        model = create_model('ViT-B/32', precision='fp16', device='cuda')

        # Load pretrained OpenAI weights
        model = create_model('ViT-B/32', pretrained='openai')

        # Load Hugging Face model
        model = create_model('hf-hub:organization/model-name')
    """
    # The file header does `from transformers import ..., logging`, which shadows the
    # stdlib module at file scope; bind the stdlib module locally so the
    # logging.info / logging.warning / logging.error calls below always resolve to it.
    import logging

    # Copy: 'size' is written into this dict below and must not leak into the caller's dict.
    force_preprocess_cfg = dict(force_preprocess_cfg or {})
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir=cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        # The hub checkpoint (loaded further below) supplies the weights, so HF
        # text-tower weights are not fetched separately.
        pretrained_hf = False
    else:
        # Local config names use '-' where user-facing names may use '/'.
        model_name = model_name.replace('/', '-')
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = model_cfg or get_model_config(model_name)
    if model_cfg is not None:
        logging.info(f'Loaded {model_name} model config.')
    else:
        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
        raise RuntimeError(f'Model config for {model_name} not found.')

    if force_quick_gelu:
        model_cfg["quick_gelu"] = True

    if force_patch_dropout is not None:
        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

    if force_image_size is not None:
        model_cfg["vision_cfg"]["image_size"] = force_image_size

    is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
    if pretrained_image:
        if is_timm_model:
            model_cfg['vision_cfg']['timm_model_pretrained'] = True
        else:
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise ValueError('pretrained image towers currently only supported for timm models')

    cast_dtype = get_cast_dtype(precision)
    is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_hf_model:
        # Only pull pretrained HF text weights when no full checkpoint is requested.
        model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

    model_cfg = dict(model_cfg, **model_kwargs)
    if custom_text:
        if "multimodal_cfg" in model_cfg:
            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
    else:
        model = CLIP(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        if is_timm_model:
            # Cast the whole model to low precision, then move LayerNormFp32
            # parameters back to float32.
            model.to(device=device, dtype=dtype)

            def _convert_ln(m):
                if isinstance(m, LayerNormFp32):
                    m.weight.data = m.weight.data.to(torch.float32)
                    m.bias.data = m.bias.data.to(torch.float32)
            model.apply(_convert_ln)
        else:
            model.to(device=device)
            convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    pretrained_loaded = False
    if pretrained:
        checkpoint_path = ''
        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
        if pretrained_cfg:
            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
            model_quick_gelu = model_cfg.get('quick_gelu', False)
            if pretrained_quick_gelu and not model_quick_gelu:
                warnings.warn(
                    f'These pretrained weights were trained with QuickGELU activation but the model config does '
                    f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
            elif not pretrained_quick_gelu and model_quick_gelu:
                warnings.warn(
                    f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
                    f'model config, consider using a model config without QuickGELU or disable override flags.')
        elif os.path.exists(pretrained):
            checkpoint_path = pretrained

        if checkpoint_path:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        else:
            error_str = (
                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
            logging.warning(error_str)
            raise RuntimeError(error_str)
        pretrained_loaded = True
    elif has_hf_hub_prefix:
        logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
        load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        pretrained_loaded = True

    if require_pretrained and not pretrained_loaded:
        raise RuntimeError(
            f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # The image size actually configured on the vision tower (from the config or
    # force_image_size) overrides any preprocess 'size'.
    if getattr(model.visual, 'image_size', None) is not None:
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model
| |
|
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        load_weights_only: bool = True,
        **model_kwargs,
):
    """Build a model together with its training and evaluation image transforms.

    The ``image_*`` arguments are combined via ``merge_preprocess_kwargs`` into
    the ``force_preprocess_cfg`` handed to ``create_model``; every other argument
    is forwarded to ``create_model`` as-is. Both transforms are then built with
    ``image_transform_v2`` from the preprocess config stored on
    ``model.visual.preprocess_cfg``.

    Returns:
        ``(model, preprocess_train, preprocess_val)``; only the training
        transform receives ``aug_cfg``.
    """
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        load_weights_only=load_weights_only,
        **model_kwargs,
    )

    final_pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
    train_transform, eval_transform = (
        image_transform_v2(final_pp_cfg, is_train=True, aug_cfg=aug_cfg),
        image_transform_v2(final_pp_cfg, is_train=False),
    )
    return model, train_transform, eval_transform
| |
|
| |
|
| |
|
# Build ViT-H-14 with the 'laion2b_s32b_b79k' weights on the module-level
# `device`, together with its training-mode (is_train=True) and evaluation
# image transforms.
(
    open_clip_model,
    open_clip_imgaug,  # training-mode transform
    open_clip_preprocess,  # evaluation transform
) = create_model_and_transforms(
    'ViT-H-14',
    pretrained='laion2b_s32b_b79k',
    device=device,
)
print("ashish 1")
| | |
| | |
| | |