# SPECTRE-Large: spectre/utils/modeling.py
# Source: commit 8b41845 ("Initial commit") by cclaess.
from __future__ import annotations
import math
from enum import Enum
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def deactivate_requires_grad_and_to_eval(model: nn.Module):
    """Freeze all parameters of a model and switch it to evaluation mode.

    Sets ``requires_grad = False`` on every parameter returned by
    ``model.parameters()``, so no gradients are computed or accumulated for
    them, and then calls ``model.eval()`` (which changes the behavior of
    modules such as Dropout and BatchNorm).

    Note that this is not identical to running under ``torch.no_grad()``:
    gradients can still flow *through* the model to inputs that require grad;
    only the model's own parameters stop receiving gradients.

    Args:
        model: Module to freeze. Modified in place.

    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad_and_to_eval(backbone)
    """
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
def activate_requires_grad_and_to_train(model: nn.Module):
    """Unfreeze all parameters of a model and switch it to training mode.

    Sets ``requires_grad = True`` on every parameter returned by
    ``model.parameters()`` and then calls ``model.train()``. This undoes
    ``deactivate_requires_grad_and_to_eval``.

    Args:
        model: Module to unfreeze. Modified in place.

    Examples:
        >>> backbone = resnet18()
        >>> activate_requires_grad_and_to_train(backbone)
    """
    for param in model.parameters():
        param.requires_grad = True
    model.train()
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
    """Update ``model_ema``'s parameters with an exponential moving average of ``model``.

    For each parameter pair ``(p_ema, p)`` taken in ``parameters()`` order::

        p_ema <- m * p_ema + (1 - m) * p

    Momentum (EMA) encoders are a central component of methods such as MoCo
    or BYOL. Only parameters are averaged; buffers are left untouched.

    Args:
        model: Online model providing the new values.
        model_ema: Momentum model whose parameters are overwritten. Expected to
            have the same parameter structure as ``model``; ``zip`` silently
            stops at the shorter of the two parameter lists.
        m: Momentum; values close to 1 make ``model_ema`` change slowly.

    Examples:
        >>> backbone = resnet18()
        >>> backbone_momentum = copy.deepcopy(backbone)
        >>>
        >>> # update momentum
        >>> update_momentum(backbone, backbone_momentum, m=0.999)
    """
    # Distinct loop names so the function arguments are not shadowed.
    for param_ema, param in zip(model_ema.parameters(), model.parameters()):
        param_ema.data = param_ema.data * m + param.data * (1.0 - m)
def update_drop_path_rate(
    model: "VisionTransformer",
    drop_path_rate: float,
    mode: str = "linear",
) -> None:
    """Update the drop path rate of every block in a VisionTransformer model.

    Each block's ``drop_path1`` and ``drop_path2`` attributes are replaced by a
    ``timm`` ``DropPath`` module (rate > 0) or by ``nn.Identity`` (rate == 0).

    Args:
        model:
            VisionTransformer model; presumably timm-style, i.e. exposing a
            ``blocks`` sequence whose blocks use ``drop_path1``/``drop_path2``.
        drop_path_rate:
            Maximum drop path rate.
        mode:
            Drop path rate update mode, "linear" or "uniform". "linear" spaces
            the rates evenly from 0 (first block) to drop_path_rate (last
            block); with a single block that block gets rate 0. "uniform" uses
            drop_path_rate for all blocks.

    Raises:
        ValueError: If an unknown mode is provided.
    """
    total_depth = len(model.blocks)
    # Determine drop path rates based on the specified mode. Convert to plain
    # Python floats so numpy scalars do not leak into the DropPath modules.
    if mode == "linear":
        drop_probabilities = [
            float(p) for p in np.linspace(0, drop_path_rate, total_depth)
        ]
    elif mode == "uniform":
        drop_probabilities = [float(drop_path_rate)] * total_depth
    else:
        raise ValueError(
            f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'."
        )

    # Imported only after validation so an invalid mode is reported even when
    # timm is not installed.
    from timm.layers import DropPath

    # Update the drop path rate for each block in the model
    for block, drop_prob in zip(model.blocks, drop_probabilities):
        if drop_prob > 0.0:
            block.drop_path1 = DropPath(drop_prob=drop_prob)
            block.drop_path2 = DropPath(drop_prob=drop_prob)
        else:
            block.drop_path1 = nn.Identity()
            block.drop_path2 = nn.Identity()
def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Tile a single token into a (batch, sequence) grid of copies.

    Args:
        token: Token tensor with shape (1, 1, dim).
        size: (batch_size, sequence_length) tuple.

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) filled with copies
        (not views) of the input token.
    """
    n_batch, n_seq = size
    return token.repeat((n_batch, n_seq, 1))
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Broadcast a 2D index over the feature axis of ``tokens``.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length) where each entry
            is an index in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim); only
            its last dimension is used.

    Returns:
        Index tensor with shape (batch_size, idx_length, dim) in which every
        original index is repeated ``dim`` times along the last axis (an
        expanded view, no copy).
    """
    return index[:, :, None].expand(-1, -1, tokens.shape[-1])
def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather tokens along the sequence axis.

    Args:
        tokens:
            Token tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length) where each entry
            is an index in [0, sequence_length).

    Returns:
        Tensor with shape (batch_size, index_length, dim) where
        ``out[b, i] = tokens[b, index[b, i]]``.
    """
    gather_index = expand_index_like(index, tokens)
    return tokens.gather(1, gather_index)
def set_at_index(
    tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Return a copy of ``tokens`` with rows at ``index`` replaced by ``value``.

    The input tensor is not modified (out-of-place scatter).

    Args:
        tokens: Tokens tensor with shape (batch_size, sequence_length, dim).
        index: Index tensor with shape (batch_size, index_length).
        value: Value tensor with shape (batch_size, index_length, dim).

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) where
        ``out[b, index[b, i]] = value[b, i]``.
    """
    scatter_index = expand_index_like(index, tokens)
    return tokens.scatter(1, scatter_index, value)
def mask_at_index(
    tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor
) -> torch.Tensor:
    """Replace the tokens at ``index`` with ``mask_token``.

    Args:
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length).
        mask_token:
            Value tensor with shape (1, 1, dim), broadcast to every masked slot.

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) where the selected
        positions hold ``mask_token`` and all others keep their original value.
    """
    # Indicator tensor: 1 at masked positions, 0 elsewhere.
    is_masked = set_at_index(tokens.new_zeros(tokens.shape), index, 1)
    keep = 1 - is_masked
    return keep * tokens + is_masked * mask_token
def mask_bool(tokens: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor) -> torch.Tensor:
    """Replace tokens by the mask token wherever ``mask`` is True.

    Args:
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).
        mask:
            Mask tensor with shape (batch_size, sequence_length). Any dtype is
            accepted; it is converted to bool (non-zero means masked).
        mask_token:
            Mask token with shape (1, 1, dim).

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) where tokens[i, j]
        is replaced by the mask token if mask[i, j] is True.
    """
    # Select rather than blend: the previous `(1 - m) * tokens + m * mask_token`
    # turned inf/nan in masked-out tokens into nan (inf * 0 = nan).
    mask = mask.to(torch.bool).unsqueeze(-1)
    return torch.where(mask, mask_token, tokens)
def patchify(images: torch.Tensor, patch_size: Tuple[int, int, int]) -> torch.Tensor:
    """Split a batch of 3D volumes into flattened, non-overlapping patches.

    Args:
        images:
            Tensor with shape (batch_size, channels, height, width, depth).
        patch_size:
            Patch size (patch_h, patch_w, patch_d) in voxels. Height, width and
            depth must be multiples of the corresponding patch dimension.

    Returns:
        Tensor with shape (batch_size, num_patches, channels * math.prod(patch_size))
        where num_patches = (height // patch_h) * (width // patch_w) * (depth // patch_d).
        Patches are ordered row-major over the (height, width, depth) patch grid.
        Within a patch, values are ordered (patch_h, patch_w, patch_d, channels),
        i.e. the channel index varies fastest.

    Raises:
        AssertionError: If a spatial dimension is not a multiple of the patch size.
    """
    N, C, H, W, D = images.shape
    assert (
        H % patch_size[0] == 0
        and W % patch_size[1] == 0
        and D % patch_size[2] == 0
    ), "Image height, width, and depth must be multiples of the patch size."

    # Number of patches along each spatial axis.
    patch_h = H // patch_size[0]
    patch_w = W // patch_size[1]
    patch_d = D // patch_size[2]
    num_patches = patch_h * patch_w * patch_d

    # Split each spatial axis into (grid index, offset within patch) ...
    patches = images.reshape(
        N, C,
        patch_h, patch_size[0],
        patch_w, patch_size[1],
        patch_d, patch_size[2],
    )
    # ... then move the grid axes forward and the channel axis to the end.
    patches = torch.einsum("nchpwqdr->nhwdpqrc", patches)
    return patches.reshape(N, num_patches, math.prod(patch_size) * C)
def random_token_mask(
    size: Tuple[int, int],
    mask_ratio: float = 0.6,
    mask_class_token: bool = False,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create random token masks.

    Args:
        size:
            (batch_size, sequence_length) of the token batch to mask.
        mask_ratio:
            Fraction of tokens to mask (e.g. 0.6 masks 60%).
        mask_class_token:
            If False, the token at position 0 (the class token) is always kept.
            If True, it may be masked like any other token.
        device:
            Device on which to create the index tensors.

    Returns:
        A (idx_keep, idx_mask) tuple of index tensors. idx_keep has shape
        (batch_size, num_keep) and holds the indices of unmasked tokens;
        idx_mask holds the remaining (masked) indices. num_keep is
        int(sequence_length * (1 - mask_ratio)), raised to at least 1 when the
        class token is protected (in which case index 0 comes first in every
        row of idx_keep). Concatenated per row, the two tensors form a random
        permutation of range(sequence_length).
    """
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))

    noise = torch.rand(batch_size, sequence_length, device=device)
    if not mask_class_token and sequence_length > 0:
        # torch.rand is in [0, 1), so -1 guarantees the class token sorts first
        # and is therefore always kept.
        noise[:, 0] = -1
        num_keep = max(1, num_keep)

    # Sorting random noise yields an independent random permutation per row.
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]
    return idx_keep, idx_mask
def resample_abs_pos_embed(
    posemb: torch.Tensor,
    new_size: List[int],
    old_size: List[int],
    num_prefix_tokens: int = 1,
    interpolation: str = 'trilinear',
):
    """Resample a flattened absolute position embedding to a new 3D grid.

    Args:
        posemb: Embedding with shape (1, num_prefix_tokens + prod(old_size), embed_dim).
            The spatial part is assumed to be the row-major flattening of an
            ``old_size`` grid.
        new_size: Target grid size (3 entries).
        old_size: Grid size (3 entries) the embedding was created for.
        num_prefix_tokens: Number of leading non-spatial tokens (e.g. the class
            token). They are copied through unchanged.
        interpolation: Mode passed to ``F.interpolate``.

    Returns:
        Tensor with shape (1, num_prefix_tokens + prod(new_size), embed_dim) in
        the input dtype. ``posemb`` itself is returned when the grid is unchanged.
    """
    num_pos_tokens = posemb.shape[1]
    num_new_tokens = new_size[0] * new_size[1] * new_size[2] + num_prefix_tokens
    # Only skip work when the grid really is the same. Comparing token counts
    # alone is not enough in 3D: e.g. (2, 2, 8) and (4, 4, 2) both hold 32 tokens.
    if num_new_tokens == num_pos_tokens and tuple(new_size) == tuple(old_size):
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix = None

    # do the interpolation on a (1, C, *grid) volume
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(1, old_size[0], old_size[1], old_size[2], -1).permute(0, 4, 1, 2, 3)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation)
    posemb = posemb.permute(0, 2, 3, 4, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
def resample_abs_pos_embed_nhwdc(
    posemb: torch.Tensor,
    new_size: List[int],
    interpolation: str = 'trilinear',
):
    """Resample a channels-last 3D position embedding grid to ``new_size``.

    Args:
        posemb: Embedding whose last four axes are (H, W, D, C). It is reshaped
            to a single-sample grid, so it must hold exactly one H*W*D*C volume.
        new_size: Target (H, W, D).
        interpolation: Mode passed to ``F.interpolate``.

    Returns:
        ``posemb`` itself if the grid already has size ``new_size``; otherwise a
        tensor of shape (1, *new_size, C) in the input dtype.
    """
    old_grid = [posemb.shape[axis] for axis in (-4, -3, -2)]
    if all(new_size[i] == old_grid[i] for i in range(3)):
        return posemb

    channels = posemb.shape[-1]
    source_dtype = posemb.dtype
    # Interpolate in float32 on a (1, C, H, W, D) volume, then go back to NHWDC.
    volume = posemb.float().reshape(1, *old_grid, channels).permute(0, 4, 1, 2, 3)
    volume = F.interpolate(volume, size=new_size, mode=interpolation)
    return volume.permute(0, 2, 3, 4, 1).to(source_dtype)
def resample_patch_embed(
    patch_embed,
    new_size: List[int],
    interpolation: str = 'trilinear',
):
    """Resample the weights of a 3D patch embedding kernel to a new kernel size.

    The kernel is resampled by approximately inverting the effect of patch
    resizing (FlexiViT): for an input patch ``x`` of the old size,
    ``<resize(x), new_kernel> ~= <x, old_kernel>``, and the equality is exact
    when upsampling. With this resizing, a kernel trained with small patches
    can, for example, be loaded into a model with 2x larger patches and give
    matching results on a 2x larger input.

    Code based on:
    https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

    Args:
        patch_embed: Kernel tensor with five dimensions; presumably
            (out_channels, in_channels, height, width, depth) as for ``nn.Conv3d``.
        new_size: Target kernel size (height, width, depth).
        interpolation: Interpolation mode passed to ``F.interpolate``.

    Returns:
        Kernel with shape (*patch_embed.shape[:2], *new_size) in the input dtype,
        or ``patch_embed`` itself if the size is unchanged.
    """
    assert len(patch_embed.shape) == 5, "Five dimensions expected"
    assert len(new_size) == 3, "New shape should only be (height, width, depth)"
    old_size = tuple(patch_embed.shape[-3:])
    if old_size == tuple(new_size):
        return patch_embed

    # Linear resize operator, built by interpolating all one-hot basis kernels
    # in a single batched call: row i is the resized version of basis voxel i
    # (row-major order), giving a (num_old, num_new) matrix.
    num_old = math.prod(old_size)
    basis = torch.eye(num_old).reshape(num_old, 1, *old_size)
    resized_basis = F.interpolate(basis, size=list(new_size), mode=interpolation)
    resized_basis = resized_basis.reshape(num_old, -1)

    # Pseudo-inverse of the operator, shape (num_new, num_old).
    resize_mat_pinv = torch.from_numpy(
        np.linalg.pinv(resized_basis.numpy())
    ).to(patch_embed.device)

    orig_dtype = patch_embed.dtype
    out_ch, in_ch = patch_embed.shape[:2]
    flat = patch_embed.float().reshape(out_ch, in_ch, num_old)
    # Apply the pinv to every (out, in) kernel at once: (O, I, old) -> (O, I, new).
    resampled = flat @ resize_mat_pinv.T
    return resampled.reshape(out_ch, in_ch, *new_size).to(orig_dtype)
def feature_take_indices(
    num_features: int,
    indices: Optional[Union[int, List[int]]] = None,
    as_set: bool = False,
) -> Tuple[List[int], int]:
    """Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: Total number of features to select from.
        indices: Indices to select.
            None -> select all
            int -> select the last n
            list/tuple of int -> the given indices; negative values count from the end
        as_set: Return the indices as a set instead of a list. Only honoured in
            eager mode (ignored when running under TorchScript).

    Returns:
        Tuple of (absolute, from-the-beginning indices as a list, or a set when
        ``as_set`` applies; the largest selected index).

    Raises:
        AssertionError: If an int ``indices`` is outside [1, num_features], or a
            listed index is outside [-num_features, num_features).
        ValueError: If ``indices`` is an empty sequence (``max`` gets no values).
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        # convert int -> last n indices
        assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        take_indices: List[int] = []
        for i in indices:
            # Map negative (from-the-end) indices to absolute positions.
            idx = num_features + i if i < 0 else i
            assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
            take_indices.append(idx)

    # The set return is restricted to eager mode, presumably because the
    # scripted signature must stay List[int].
    if not torch.jit.is_scripting() and as_set:
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
def global_pool_nlc(
    x: torch.Tensor,
    pool_type: str = 'token',
    num_prefix_tokens: int = 1,
    reduce_include_prefix: bool = False,
):
    """Pool a (batch, sequence, channels) token tensor over the sequence axis.

    Args:
        x: Token tensor in NLC layout.
        pool_type: '' / None -> no pooling (x returned as is); 'token' -> the
            first token; 'avg', 'max', 'avgmax' -> mean, max, or the average of
            both over the sequence axis.
        num_prefix_tokens: Leading tokens excluded from the reductions.
        reduce_include_prefix: If True, reductions also include the prefix tokens.

    Returns:
        Pooled tensor (sequence axis removed), or ``x`` when pool_type is falsy.

    Raises:
        AssertionError: For an unknown pool_type.
    """
    if not pool_type:
        return x
    if pool_type == 'token':
        return x[:, 0]  # class token

    tokens = x if reduce_include_prefix else x[:, num_prefix_tokens:]
    if pool_type == 'avg':
        return tokens.mean(dim=1)
    if pool_type == 'max':
        return tokens.amax(dim=1)
    if pool_type == 'avgmax':
        return 0.5 * (tokens.amax(dim=1) + tokens.mean(dim=1))
    # pool_type is truthy here, so this always fails (unless asserts are disabled).
    assert not pool_type, f'Unknown pool type {pool_type}'
    return tokens
def cat_keep_shapes(
    x_list: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[Tuple[int, ...]], List[int]]:
    """Flatten and concatenate tensors that share a last (feature) dimension.

    Args:
        x_list: Tensors of shape (..., dim) with the same ``dim``.

    Returns:
        (x_cat, shapes, num_tokens): x_cat has shape (sum(num_tokens), dim);
        shapes are the original shapes and num_tokens the number of rows each
        tensor contributed, as needed by ``uncat_with_shapes``. An empty input
        gives (an empty tensor, [], []).
    """
    if len(x_list) == 0:
        return torch.empty(0), [], []

    shapes = []
    num_tokens = []
    flat_parts = []
    for tensor in x_list:
        shapes.append(tensor.shape)
        num_tokens.append(tensor.select(dim=-1, index=0).numel())
        flat_parts.append(tensor.flatten(0, -2))
    return torch.cat(flat_parts, dim=0), shapes, num_tokens
def uncat_with_shapes(
    x_cat: torch.Tensor,
    shapes: List[Tuple[int, ...]],
    num_tokens: List[int]
) -> List[torch.Tensor]:
    """Invert ``cat_keep_shapes``: split rows and restore the original shapes.

    The last dimension of each output is taken from ``x_cat`` (not from
    ``shapes``), so the feature size may have changed since concatenation.

    Args:
        x_cat: Tensor of shape (sum(num_tokens), dim).
        shapes: Original shapes; only their leading dimensions are used.
        num_tokens: Number of rows belonging to each original tensor.

    Returns:
        List of tensors with shapes (*shape[:-1], x_cat.shape[-1]).
    """
    if len(shapes) == 0:
        return []
    feat_dim = x_cat.shape[-1]
    chunks = torch.split(x_cat, num_tokens, dim=0)
    return [chunk.reshape(*shape[:-1], feat_dim) for chunk, shape in zip(chunks, shapes)]
def last_token_pool(
    last_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor
) -> torch.Tensor:
    """Select, per sequence, the hidden state of its last attended token.

    If every sequence attends to the final position (left padding, or no
    padding), the final position is returned for all rows. Otherwise the
    index ``attention_mask.sum(dim=1) - 1`` is used per row, which assumes the
    attended tokens form a prefix (right padding) — TODO confirm with callers.

    Args:
        last_hidden_states: Tensor of shape (batch, seq_len, hidden).
        attention_mask: Tensor of shape (batch, seq_len), 1 for attended tokens.

    Returns:
        Tensor of shape (batch, hidden).
    """
    num_sequences = attention_mask.shape[0]
    last_column_all_attended = attention_mask[:, -1].sum() == num_sequences
    if last_column_all_attended:
        return last_hidden_states[:, -1]

    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
class Format(str, Enum):
    """Tensor layouts for volumetric feature maps.

    Letters: N = batch, C = channels, H/W/D = the three spatial axes, and
    L = the flattened spatial axis (H * W * D). Because the enum mixes in
    ``str``, members compare equal to their plain string values
    (e.g. ``Format.NLC == 'NLC'``).
    """
    NCHWD = 'NCHWD'  # channels-first volume
    NHWDC = 'NHWDC'  # channels-last volume
    NCL = 'NCL'  # channels-first, spatial axes flattened
    NLC = 'NLC'  # token sequence: spatial axes flattened, channels last
def nchwd_to(x: torch.Tensor, fmt: Format):
    """Convert a channels-first volume (N, C, H, W, D) to the layout ``fmt``.

    ``NHWDC`` -> (N, H, W, D, C); ``NLC`` -> (N, H*W*D, C);
    ``NCL`` -> (N, C, H*W*D). Any other value (including ``NCHWD``) returns
    ``x`` unchanged.
    """
    if fmt == Format.NCL:
        return x.flatten(2)
    if fmt == Format.NLC:
        return x.flatten(2).transpose(1, 2)
    if fmt == Format.NHWDC:
        return x.permute(0, 2, 3, 4, 1)
    return x
def nhwdc_to(x: torch.Tensor, fmt: Format):
    """Convert a channels-last volume (N, H, W, D, C) to the layout ``fmt``.

    Args:
        x: Tensor in NHWDC layout.
        fmt: Target layout. ``Format`` is a ``str`` enum, so its members compare
            equal to the plain strings used below; plain strings work too.

    Returns:
        ``NCHWD`` -> (N, C, H, W, D); ``NLC`` -> (N, H*W*D, C);
        ``NCL`` -> (N, C, H*W*D). Any other value (including ``NHWDC``)
        returns ``x`` unchanged.
    """
    if fmt == 'NCHWD':
        x = x.permute(0, 4, 1, 2, 3)
    elif fmt == 'NLC':
        # Merge all three spatial axes (1..3). Merging only 1..2 (as in the 2D
        # NHWC version) would leave depth as a separate axis and return 4D.
        x = x.flatten(1, 3)
    elif fmt == 'NCL':
        x = x.flatten(1, 3).transpose(1, 2)
    return x