from __future__ import annotations import math from enum import Enum from typing import List, Tuple, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F import numpy as np def deactivate_requires_grad_and_to_eval(model: nn.Module): """Deactivates the requires_grad flag for all parameters of a model. This has the same effect as permanently executing the model within a `torch.no_grad()` context. Use this method to disable gradient computation and therefore training for a model. Examples: >>> backbone = resnet18() >>> deactivate_requires_grad(backbone) """ for param in model.parameters(): param.requires_grad = False model.eval() def activate_requires_grad_and_to_train(model: nn.Module): """Activates the requires_grad flag for all parameters of a model. Use this method to activate gradients for a model (e.g. after deactivating them using `deactivate_requires_grad(...)`). Examples: >>> backbone = resnet18() >>> activate_requires_grad(backbone) """ for param in model.parameters(): param.requires_grad = True model.train() @torch.no_grad() def update_momentum(model: nn.Module, model_ema: nn.Module, m: float): """Updates parameters of `model_ema` with Exponential Moving Average of `model` Momentum encoders are a crucial component fo models such as MoCo or BYOL. Examples: >>> backbone = resnet18() >>> projection_head = MoCoProjectionHead() >>> backbone_momentum = copy.deepcopy(moco) >>> projection_head_momentum = copy.deepcopy(projection_head) >>> >>> # update momentum >>> update_momentum(moco, moco_momentum, m=0.999) >>> update_momentum(projection_head, projection_head_momentum, m=0.999) """ for model_ema, model in zip(model_ema.parameters(), model.parameters()): model_ema.data = model_ema.data * m + model.data * (1.0 - m) def update_drop_path_rate( model: "VisionTransformer", drop_path_rate: float, mode: str = "linear", ) -> None: """Updates the drop path rate in a VisionTransformer model. Args: model: VisionTransformer model. drop_path_rate: Maximum drop path rate. mode: Drop path rate update mode. Can be "linear" or "uniform". Linear increases the drop path rate from 0 to drop_path_rate over the depth of the model. Uniform sets the drop path rate to drop_path_rate for all blocks. Raises: ValueError: If an unknown mode is provided. """ from timm.layers import DropPath total_depth = len(model.blocks) # Determine drop path rates based on the specified mode if mode == "linear": drop_probabilities = np.linspace(0, drop_path_rate, total_depth) elif mode == "uniform": drop_probabilities = [drop_path_rate for _ in range(total_depth)] else: raise ValueError( f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'." ) # Update the drop path rate for each block in the model for block, drop_prob in zip(model.blocks, drop_probabilities): if drop_prob > 0.0: block.drop_path1 = DropPath(drop_prob=drop_prob) block.drop_path2 = DropPath(drop_prob=drop_prob) else: block.drop_path1 = nn.Identity() block.drop_path2 = nn.Identity() def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: """Repeats a token size times. Args: token: Token tensor with shape (1, 1, dim). size: (batch_size, sequence_length) tuple. Returns: Tensor with shape (batch_size, sequence_length, dim) containing copies of the input token. """ batch_size, sequence_length = size return token.repeat(batch_size, sequence_length, 1) def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor: """Expands the index along the last dimension of the input tokens. Args: index: Index tensor with shape (batch_size, idx_length) where each entry is an index in [0, sequence_length). tokens: Tokens tensor with shape (batch_size, sequence_length, dim). Returns: Index tensor with shape (batch_size, idx_length, dim) where the original indices are repeated dim times along the last dimension. """ dim = tokens.shape[-1] index = index.unsqueeze(-1).expand(-1, -1, dim) return index def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor: """Selects tokens at index. Args: tokens: Token tensor with shape (batch_size, sequence_length, dim). index: Index tensor with shape (batch_size, index_length) where each entry is an index in [0, sequence_length). Returns: Token tensor with shape (batch_size, index_length, dim) containing the selected tokens. """ index = expand_index_like(index, tokens) return torch.gather(tokens, 1, index) def set_at_index( tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: """Copies all values into the input tensor at the given indices. Args: tokens: Tokens tensor with shape (batch_size, sequence_length, dim). index: Index tensor with shape (batch_size, index_length). value: Value tensor with shape (batch_size, index_length, dim). Returns: Tokens tensor with shape (batch_size, sequence_length, dim) containing the new values. """ index = expand_index_like(index, tokens) return torch.scatter(tokens, 1, index, value) def mask_at_index( tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor ) -> torch.Tensor: """Copies mask token into the input tensor at the given indices. Args: tokens: Tokens tensor with shape (batch_size, sequence_length, dim). index: Index tensor with shape (batch_size, index_length). mask_token: Value tensor with shape (1, 1, dim). Returns: Tokens tensor with shape (batch_size, sequence_length, dim) containing the new values. """ mask = tokens.new_zeros(tokens.shape) mask = set_at_index(mask, index, 1) return (1 - mask) * tokens + mask * mask_token def mask_bool(tokens: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor) -> torch. Tensor: """Returns a tensor with tokens replaced by the mask tokens in all positions where the mask is True. Args: tokens: Tokens tensor with shape (batch_size, sequence_length, dim). mask: Boolean mask tensor with shape (batch_size, sequence_length). mask_token: Mask token with shape (1, 1, dim). Returns: Tokens tensor with shape (batch_size, sequence_length, dim) where tokens[i, j] is replaced by the mask token if mask[i, j] is True. """ # Convert to int for multiplication. mask = mask.unsqueeze(-1).to(torch.bool).to(torch.int) return (1 - mask) * tokens + mask * mask_token def patchify(images: torch.Tensor, patch_size: Tuple[int, int, int]) -> torch.Tensor: """Converts a batch of input images into patches. Args: images: Images tensor with shape (batch_size, channels, height, width, depth) patch_size: Patch size in pixels. Image width and height must be multiples of the patch size. Returns: Patches tensor with shape (batch_size, num_patches, channels * math.prod(patch_size)) where num_patches = image_width / patch_size * image_height / patch_size. """ N, C, H, W, D = images.shape assert ( H % patch_size[0] == 0 and W % patch_size[1] == 0 and D % patch_size[2] == 0 ), "Image height, width, and depth must be multiples of the patch size." patch_h = H // patch_size[0] patch_w = W // patch_size[1] patch_d = D // patch_size[2] num_patches = patch_h * patch_w * patch_d patches = images.reshape(shape=( N, C, patch_h, patch_size[0], patch_w, patch_size[1], patch_d, patch_size[2], )) patches = torch.einsum("nchpwqdr->nhwdpqrc", patches) patches = patches.reshape(shape=(N, num_patches, math.prod(patch_size) * C)) return patches def random_token_mask( size: Tuple[int, int], mask_ratio: float = 0.6, mask_class_token: bool = False, device: Optional[Union[torch.device, str]] = None, ) -> torch.Tensor: """Creates random token masks. Args: size: Size of the token batch for which to generate masks. Should be (batch_size, sequence_length). mask_ratio: Percentage of tokens to mask. mask_class_token: If False the class token is never masked. If True the class token might be masked. device: Device on which to create the index masks. Returns: A (index_keep, index_mask) tuple where each index is a tensor. index_keep contains the indices of the unmasked tokens and has shape (batch_size, num_keep). index_mask contains the indices of the masked tokens and has shape (batch_size, sequence_length - num_keep). num_keep is equal to sequence_length * (1- mask_ratio). """ batch_size, sequence_length = size num_keep = int(sequence_length * (1 - mask_ratio)) noise = torch.rand(batch_size, sequence_length, device=device) if not mask_class_token and sequence_length > 0: # make sure that class token is not masked noise[:, 0] = -1 num_keep = max(1, num_keep) # get indices of tokens to keep indices = torch.argsort(noise, dim=1) idx_keep = indices[:, :num_keep] idx_mask = indices[:, num_keep:] return idx_keep, idx_mask def resample_abs_pos_embed( posemb: torch.Tensor, new_size: List[int], old_size: List[int], num_prefix_tokens: int = 1, interpolation: str = 'trilinear', ): # sort out sizes, assume square if old size not provided num_pos_tokens = posemb.shape[1] num_new_tokens = new_size[0] * new_size[1] * new_size[2] + num_prefix_tokens if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: return posemb if num_prefix_tokens: posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] else: posemb_prefix, posemb = None, posemb # do the interpolation embed_dim = posemb.shape[-1] orig_dtype = posemb.dtype posemb = posemb.float() # interpolate needs float32 posemb = posemb.reshape(1, old_size[0], old_size[1], old_size[2], -1).permute(0, 4, 1, 2, 3) posemb = F.interpolate(posemb, size=new_size, mode=interpolation) posemb = posemb.permute(0, 2, 3, 4, 1).reshape(1, -1, embed_dim) posemb = posemb.to(orig_dtype) # add back extra (class, etc) prefix tokens if posemb_prefix is not None: posemb = torch.cat([posemb_prefix, posemb], dim=1) return posemb def resample_abs_pos_embed_nhwdc( posemb: torch.Tensor, new_size: List[int], interpolation: str = 'trilinear', ): if new_size[0] == posemb.shape[-4] and new_size[1] == posemb.shape[-3] and new_size[2] == posemb.shape[-2]: return posemb orig_dtype = posemb.dtype posemb = posemb.float() posemb = posemb.reshape(1, posemb.shape[-4], posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 4, 1, 2, 3) posemb = F.interpolate(posemb, size=new_size, mode=interpolation) posemb = posemb.permute(0, 2, 3, 4, 1).to(orig_dtype) return posemb def resample_patch_embed( patch_embed, new_size: List[int], interpolation: str = 'trilinear', ): """Resample the weights of the patch embedding kernel to target resolution. We resample the patch embedding kernel by approximately inverting the effect of patch resizing. Code based on: https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py With this resizing, we can for example load a B/8 filter into a B/16 model and, on 2x larger input image, the result will match. Args: patch_embed: original parameter to be resized. new_size (tuple[int, int, int]): target shape (depth, height, width). interpolation (str): interpolation for resize Returns: Resized patch embedding kernel. """ import numpy as np try: from torch import vmap except ImportError: from functorch import vmap assert len(patch_embed.shape) == 5, "Five dimensions expected" assert len(new_size) == 3, "New shape should only be (height, width, depth)" old_size = patch_embed.shape[-3:] if tuple(old_size) == tuple(new_size): return patch_embed def resize(x_np, _new_size): x_tf = torch.Tensor(x_np)[None, None, ...] x_upsampled = F.interpolate( x_tf, size=_new_size, mode=interpolation)[0, 0, ...].numpy() return x_upsampled def get_resize_mat(_old_size, _new_size): mat = [] for i in range(np.prod(_old_size)): basis_vec = np.zeros(_old_size) basis_vec[np.unravel_index(i, _old_size)] = 1. mat.append(resize(basis_vec, _new_size).reshape(-1)) return np.stack(mat).T resize_mat = get_resize_mat(old_size, new_size) resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device) def resample_kernel(kernel): resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) return resampled_kernel.reshape(new_size) v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1) orig_dtype = patch_embed.dtype patch_embed = patch_embed.float() patch_embed = v_resample_kernel(patch_embed) patch_embed = patch_embed.to(orig_dtype) return patch_embed def feature_take_indices( num_features: int, indices: Optional[Union[int, List[int]]] = None, as_set: bool = False, ) -> Tuple[List[int], int]: """ Determine the absolute feature indices to 'take' from. Note: This function can be called in forward() so must be torchscript compatible, which requires some incomplete typing and workaround hacks. Args: num_features: total number of features to select from indices: indices to select, None -> select all int -> select last n list/tuple of int -> return specified (-ve indices specify from end) as_set: return as a set Returns: List (or set) of absolute (from beginning) indices, Maximum index """ if indices is None: indices = num_features # all features if None if isinstance(indices, int): # convert int -> last n indices assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})' take_indices = [num_features - indices + i for i in range(indices)] else: take_indices: List[int] = [] for i in indices: idx = num_features + i if i < 0 else i assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})' take_indices.append(idx) if not torch.jit.is_scripting() and as_set: return set(take_indices), max(take_indices) return take_indices, max(take_indices) def global_pool_nlc( x: torch.Tensor, pool_type: str = 'token', num_prefix_tokens: int = 1, reduce_include_prefix: bool = False, ): if not pool_type: return x if pool_type == 'token': x = x[:, 0] # class token else: x = x if reduce_include_prefix else x[:, num_prefix_tokens:] if pool_type == 'avg': x = x.mean(dim=1) elif pool_type == 'avgmax': x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) elif pool_type == 'max': x = x.amax(dim=1) else: assert not pool_type, f'Unknown pool type {pool_type}' return x def cat_keep_shapes( x_list: List[torch.Tensor] ) -> Tuple[torch.Tensor, List[Tuple[int, ...]], List[int]]: if not x_list: return torch.empty(0), [], [] shapes = [x.shape for x in x_list] num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list] x_cat = torch.cat([x.flatten(0, -2) for x in x_list], dim=0) return x_cat, shapes, num_tokens def uncat_with_shapes( x_cat: torch.Tensor, shapes: List[Tuple[int, ...]], num_tokens: List[int] ) -> List[torch.Tensor]: if not shapes: return [] x_splitted = torch.split_with_sizes(x_cat, num_tokens, dim=0) shapes_adjusted = [shape[:-1] + torch.Size([x_cat.shape[-1]]) for shape in shapes] outputs_reshape = [x.reshape(shape) for x, shape in zip(x_splitted, shapes_adjusted)] return outputs_reshape def last_token_pool( last_hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if left_padding: return last_hidden_states[:, -1] else: sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = last_hidden_states.shape[0] return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] class Format(str, Enum): NCHWD = 'NCHWD' NHWDC = 'NHWDC' NCL = 'NCL' NLC = 'NLC' def nchwd_to(x: torch.Tensor, fmt: Format): if fmt == Format.NHWDC: x = x.permute(0, 2, 3, 4, 1) elif fmt == Format.NLC: x = x.flatten(2).transpose(1, 2) elif fmt == Format.NCL: x = x.flatten(2) return x def nhwdc_to(x: torch.Tensor, fmt: Format): if fmt == Format.NCHWD: x = x.permute(0, 4, 1, 2, 3) elif fmt == Format.NLC: x = x.flatten(1, 2) elif fmt == Format.NCL: x = x.flatten(1, 2).transpose(1, 2) return x