import collections
import itertools
import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
# adapted from timm (timm/models/layers/helpers.py)
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            assert len(x) == n
            return x
        return tuple(itertools.repeat(x, n))

    return parse
# adapted from timm (timm/models/layers/helpers.py)
def to_ntuple(x, n):
    return _ntuple(n=n)(x)
# from kappamodules.functional.pos_embed import interpolate_sincos
def interpolate_sincos(embed, seqlens, mode="bicubic"):
    # embed has shape (1, *old_seqlens, dim) -> interpolate its spatial dims to the new seqlens
    assert embed.ndim - 2 == len(seqlens)
    embed = F.interpolate(
        einops.rearrange(embed, "1 ... dim -> 1 dim ..."),
        size=seqlens,
        mode=mode,
    )
    embed = einops.rearrange(embed, "1 dim ... -> 1 ... dim")
    return embed
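# Illustrative usage sketch (not part of the original module): resize a (1, 8, 8, dim)
# positional embedding to a 12x12 token grid, e.g. when running at a different resolution.
#   >>> pos = torch.randn(1, 8, 8, 64)
#   >>> interpolate_sincos(embed=pos, seqlens=(12, 12)).shape
#   torch.Size([1, 12, 12, 64])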
# from kappamodules.vit import VitPatchEmbed
class VitPatchEmbed(nn.Module):
    def __init__(self, dim, num_channels, resolution, patch_size, stride=None, init_weights="xavier_uniform"):
        super().__init__()
        self.resolution = resolution
        self.init_weights = init_weights
        self.ndim = len(resolution)
        self.patch_size = to_ntuple(patch_size, n=self.ndim)
        if stride is None:
            self.stride = self.patch_size
        else:
            self.stride = to_ntuple(stride, n=self.ndim)
        for i in range(self.ndim):
            assert resolution[i] % self.patch_size[i] == 0, \
                f"resolution[{i}] % patch_size[{i}] != 0 (resolution={resolution} patch_size={patch_size})"
        self.seqlens = [resolution[i] // self.patch_size[i] for i in range(self.ndim)]
        if self.patch_size == self.stride:
            # use a primitive type, as np.prod returns a numpy integer which is not compatible with all serialization/logging
            self.num_patches = int(np.prod(self.seqlens))
        else:
            # overlapping patches: infer the number of patches by running a dummy convolution
            if self.ndim == 1:
                conv_func = F.conv1d
            elif self.ndim == 2:
                conv_func = F.conv2d
            elif self.ndim == 3:
                conv_func = F.conv3d
            else:
                raise NotImplementedError
            self.num_patches = conv_func(
                input=torch.zeros(1, 1, *resolution),
                weight=torch.zeros(1, 1, *self.patch_size),
                stride=self.stride,
            ).numel()

        if self.ndim == 1:
            conv_ctor = nn.Conv1d
        elif self.ndim == 2:
            conv_ctor = nn.Conv2d
        elif self.ndim == 3:
            conv_ctor = nn.Conv3d
        else:
            raise NotImplementedError

        self.proj = conv_ctor(num_channels, dim, kernel_size=self.patch_size, stride=self.stride)
        self.reset_parameters()

    def reset_parameters(self):
        if self.init_weights == "torch":
            pass
        elif self.init_weights == "xavier_uniform":
            # initialize like an nn.Linear (flatten the kernel dims for xavier_uniform)
            w = self.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.zeros_(self.proj.bias)
        else:
            raise NotImplementedError

    def forward(self, x):
        assert all(x.size(i + 2) % self.patch_size[i] == 0 for i in range(self.ndim)), \
            f"x.shape={x.shape} incompatible with patch_size={self.patch_size}"
        x = self.proj(x)
        # note: this unpacking assumes a 3d token grid, i.e. ndim == 3 and x has shape (b, c, H, W, L)
        _, _, H, W, L = x.shape
        x = einops.rearrange(x, "b c ... -> b ... c")
        return x, H, W, L
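# Illustrative usage sketch (not part of the original module), assuming a 3d input as
# expected by the (b, c, H, W, L) unpacking in forward; the sizes are arbitrary example values.
#   >>> patch_embed = VitPatchEmbed(dim=192, num_channels=1, resolution=(32, 32, 32), patch_size=8)
#   >>> tokens, H, W, L = patch_embed(torch.randn(2, 1, 32, 32, 32))
#   >>> tokens.shape, (H, W, L), patch_embed.num_patches
#   (torch.Size([2, 4, 4, 4, 192]), (4, 4, 4), 64)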
# from kappamodules.vit import VitPosEmbed2d
class VitPosEmbed2d(nn.Module):
    def __init__(self, seqlens, dim: int, allow_interpolation: bool = True):
        super().__init__()
        self.seqlens = seqlens
        self.dim = dim
        self.allow_interpolation = allow_interpolation
        # learnable positional embedding of shape (1, *seqlens, dim)
        self.embed = nn.Parameter(torch.zeros(1, *seqlens, dim))
        self.reset_parameters()

    @property
    def _expected_x_ndim(self):
        return len(self.seqlens) + 2

    def reset_parameters(self):
        nn.init.trunc_normal_(self.embed, std=.02)

    def forward(self, x):
        assert x.ndim == self._expected_x_ndim
        if x.shape[1:] != self.embed.shape[1:]:
            # resize the positional embedding to the token grid of x (e.g. for a different input resolution)
            assert self.allow_interpolation
            embed = interpolate_sincos(embed=self.embed, seqlens=x.shape[1:-1])
        else:
            embed = self.embed
        return x + embed
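# Illustrative usage sketch (not part of the original module): the learned embedding is added
# to a token grid and interpolated when the grid size differs from seqlens.
#   >>> pos_embed = VitPosEmbed2d(seqlens=(8, 8), dim=64)
#   >>> pos_embed(torch.randn(2, 8, 8, 64)).shape     # exact match, no interpolation
#   torch.Size([2, 8, 8, 64])
#   >>> pos_embed(torch.randn(2, 12, 12, 64)).shape   # resized via interpolate_sincos
#   torch.Size([2, 12, 12, 64])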
# from kappamodules.layers import DropPath
class DropPath(nn.Sequential):
    """
    Efficiently drop paths (Stochastic Depth) per sample such that dropped samples are not processed.
    This is a subclass of nn.Sequential and can be used either as a standalone module or like nn.Sequential.

    Examples::

        >>> # use as nn.Sequential module
        >>> sequential_droppath = DropPath(nn.Linear(4, 4), drop_prob=0.2)
        >>> y = sequential_droppath(torch.randn(10, 4))

        >>> # use as standalone module
        >>> standalone_layer = nn.Linear(4, 4)
        >>> standalone_droppath = DropPath(drop_prob=0.2)
        >>> y = standalone_droppath(torch.randn(10, 4), standalone_layer)
    """

    def __init__(self, *args, drop_prob: float = 0., scale_by_keep: bool = True, stochastic_drop_prob: bool = False):
        super().__init__(*args)
        assert 0. <= drop_prob < 1.
        self._drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
        self.stochastic_drop_prob = stochastic_drop_prob

    @property
    def drop_prob(self):
        return self._drop_prob

    @drop_prob.setter
    def drop_prob(self, value):
        assert 0. <= value < 1.
        self._drop_prob = value

    @property
    def keep_prob(self):
        return 1. - self.drop_prob

    def forward(self, x, residual_path=None, residual_path_kwargs=None):
        # exactly one usage mode is allowed: as nn.Sequential (submodules form the residual path,
        # residual_path is None) or standalone (no submodules, residual_path is passed to forward)
        assert (len(self) == 0) ^ (residual_path is None)
        residual_path_kwargs = residual_path_kwargs or {}
        if self.drop_prob == 0. or not self.training:
            if residual_path is None:
                return x + super().forward(x, **residual_path_kwargs)
            else:
                return x + residual_path(x, **residual_path_kwargs)
        # generate indices to keep (only these are propagated through the transform path)
        bs = len(x)
        if self.stochastic_drop_prob:
            perm = torch.empty(bs, device=x.device).bernoulli_(self.keep_prob).nonzero().squeeze(1)
            scale = 1 / self.keep_prob
        else:
            keep_count = max(int(bs * self.keep_prob), 1)
            scale = bs / keep_count
            perm = torch.randperm(bs, device=x.device)[:keep_count]
        # propagate
        if self.scale_by_keep:
            alpha = scale
        else:
            alpha = 1.
        # reduce kwargs (e.g. used for DiT blocks where scale/shift/gate is passed and also has to be reduced)
        residual_path_kwargs = {
            key: value[perm] if torch.is_tensor(value) else value
            for key, value in residual_path_kwargs.items()
        }
        if residual_path is None:
            residual = super().forward(x[perm], **residual_path_kwargs)
        else:
            residual = residual_path(x[perm], **residual_path_kwargs)
        # add the (scaled) residual of the kept samples back onto the full batch
        return torch.index_add(
            x.flatten(start_dim=1),
            dim=0,
            index=perm,
            source=residual.to(x.dtype).flatten(start_dim=1),
            alpha=alpha,
        ).view_as(x)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
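# Minimal smoke test (illustrative, not part of the original module). It assumes a 3d token
# grid because VitPatchEmbed.forward unpacks (b, c, H, W, L); resolution 16^3, patch size 4 and
# dim 32 are arbitrary example values.
if __name__ == "__main__":
    patch_embed = VitPatchEmbed(dim=32, num_channels=1, resolution=(16, 16, 16), patch_size=4)
    pos_embed = VitPosEmbed2d(seqlens=patch_embed.seqlens, dim=32)
    drop_path = DropPath(drop_prob=0.2)
    block = nn.Linear(32, 32)

    x = torch.randn(2, 1, 16, 16, 16)
    tokens, H, W, L = patch_embed(x)                                # (2, 4, 4, 4, 32) token grid
    tokens = pos_embed(tokens)                                      # add positional embedding
    tokens = einops.rearrange(tokens, "b h w l d -> b (h w l) d")   # flatten to (2, 64, 32)
    tokens = drop_path(tokens, residual_path=block)                 # stochastic-depth residual
    print(tokens.shape, patch_embed.num_patches, (H, W, L))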