File size: 7,425 Bytes
226675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import collections
import itertools
import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
# adapted from timm (timm/models/layers/helpers.py)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
assert len(x) == n
return x
return tuple(itertools.repeat(x, n))
return parse
# adapted from timm (timm/models/layers/helpers.py)
def to_ntuple(x, n):
    """Broadcast ``x`` to an n-tuple (iterables of length n pass through unchanged)."""
    parser = _ntuple(n=n)
    return parser(x)
# from kappamodules.functional.pos_embed import interpolate_sincos
def interpolate_sincos(embed, seqlens, mode="bicubic"):
    """Resample a channels-last positional-embedding grid to new spatial lengths.

    ``embed`` is expected to have shape (1, *spatial, dim); the result has
    shape (1, *seqlens, dim). Interpolation runs in channels-first layout
    because that is what ``F.interpolate`` consumes.
    """
    # One spatial axis per requested target length.
    assert embed.ndim - 2 == len(seqlens)
    # (1, ..., dim) -> (1, dim, ...) for F.interpolate.
    grid = einops.rearrange(embed, "1 ... dim -> 1 dim ...")
    grid = F.interpolate(grid, size=seqlens, mode=mode)
    # Restore the channels-last layout expected by callers.
    return einops.rearrange(grid, "1 dim ... -> 1 ... dim")
# from kappamodules.vit import VitPatchEmbed
class VitPatchEmbed(nn.Module):
    """ViT-style patch embedding via a (possibly strided/overlapping) convolution.

    Supports 1d/2d/3d inputs (``ndim`` is derived from ``len(resolution)``).
    ``forward`` returns the projected patch grid in channels-last layout
    together with the spatial sizes of that grid.
    """

    def __init__(self, dim, num_channels, resolution, patch_size, stride=None, init_weights="xavier_uniform"):
        super().__init__()
        self.resolution = resolution
        self.init_weights = init_weights
        self.ndim = len(resolution)
        # Broadcast scalar patch_size/stride to one entry per spatial dim.
        self.patch_size = to_ntuple(patch_size, n=self.ndim)
        if stride is None:
            self.stride = self.patch_size
        else:
            self.stride = to_ntuple(stride, n=self.ndim)
        for i in range(self.ndim):
            assert resolution[i] % self.patch_size[i] == 0, \
                f"resolution[{i}] % patch_size[{i}] != 0 (resolution={resolution} patch_size={patch_size})"
        self.seqlens = [resolution[i] // self.patch_size[i] for i in range(self.ndim)]
        if self.patch_size == self.stride:
            # use primitive type as np.prod gives np.int which is not compatible with all serialization/logging
            self.num_patches = int(np.prod(self.seqlens))
        else:
            # Overlapping patches: count output positions of a dummy convolution
            # instead of deriving the formula by hand.
            if self.ndim == 1:
                conv_func = F.conv1d
            elif self.ndim == 2:
                conv_func = F.conv2d
            elif self.ndim == 3:
                conv_func = F.conv3d
            else:
                raise NotImplementedError
            self.num_patches = conv_func(
                input=torch.zeros(1, 1, *resolution),
                weight=torch.zeros(1, 1, *self.patch_size),
                stride=self.stride,
            ).numel()
        if self.ndim == 1:
            conv_ctor = nn.Conv1d
        elif self.ndim == 2:
            conv_ctor = nn.Conv2d
        elif self.ndim == 3:
            conv_ctor = nn.Conv3d
        else:
            raise NotImplementedError
        self.proj = conv_ctor(num_channels, dim, kernel_size=self.patch_size, stride=self.stride)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the projection either with torch defaults or like an nn.Linear."""
        if self.init_weights == "torch":
            pass
        elif self.init_weights == "xavier_uniform":
            # initialize as nn.Linear
            w = self.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.zeros_(self.proj.bias)
        else:
            raise NotImplementedError

    def forward(self, x):
        """Project ``x`` to patches.

        Returns ``(tokens, *spatial)`` where ``tokens`` has shape
        (batch, *spatial, dim). For 3d inputs this is the 4-tuple
        ``(tokens, H, W, L)`` the original callers unpack.
        """
        assert all(x.size(i + 2) % self.patch_size[i] == 0 for i in range(self.ndim)), \
            f"x.shape={x.shape} incompatible with patch_size={self.patch_size}"
        x = self.proj(x)
        # FIX: the original hard-coded `_, _, H, W, L = x.shape`, which raised a
        # ValueError for 1d/2d inputs although __init__ supports ndim in {1, 2, 3}.
        # Capturing the spatial sizes generically keeps the 3d return shape
        # identical while making 1d/2d work.
        seqlens = x.shape[2:]
        x = einops.rearrange(x, "b c ... -> b ... c")
        return (x, *seqlens)
# from kappamodules.vit import VitPosEmbed2d
class VitPosEmbed2d(nn.Module):
    """Learnable positional embedding added to a channels-last token grid.

    The embedding is stored with shape (1, *seqlens, dim). When the input's
    spatial grid differs from ``seqlens`` and ``allow_interpolation`` is set,
    the embedding is resampled to the input's grid before being added.
    """

    def __init__(self, seqlens, dim: int, allow_interpolation: bool = True):
        super().__init__()
        self.seqlens = seqlens
        self.dim = dim
        self.allow_interpolation = allow_interpolation
        self.embed = nn.Parameter(torch.zeros(1, *seqlens, dim))
        self.reset_parameters()

    @property
    def _expected_x_ndim(self):
        # batch axis + one axis per spatial seqlen + feature axis
        return len(self.seqlens) + 2

    def reset_parameters(self):
        nn.init.trunc_normal_(self.embed, std=0.02)

    def forward(self, x):
        assert x.ndim == self._expected_x_ndim
        # Fast path: input grid matches the stored embedding exactly.
        if x.shape[1:] == self.embed.shape[1:]:
            return x + self.embed
        # Grid mismatch: resample the embedding to the input's spatial sizes.
        assert self.allow_interpolation
        resized = interpolate_sincos(embed=self.embed, seqlens=x.shape[1:-1])
        return x + resized
# from kappamodules.layers import DropPath
class DropPath(nn.Sequential):
    """
    Efficiently drop paths (Stochastic Depth) per sample such that dropped samples are not processed.
    This is a subclass of nn.Sequential and can be used either as standalone Module or like nn.Sequential.
    Examples::
        >>> # use as nn.Sequential module
        >>> sequential_droppath = DropPath(nn.Linear(4, 4), drop_prob=0.2)
        >>> y = sequential_droppath(torch.randn(10, 4))
        >>> # use as standalone module
        >>> standalone_layer = nn.Linear(4, 4)
        >>> standalone_droppath = DropPath(drop_prob=0.2)
        >>> y = standalone_droppath(torch.randn(10, 4), standalone_layer)
    """

    def __init__(self, *args, drop_prob: float = 0., scale_by_keep: bool = True, stochastic_drop_prob: bool = False):
        # args: optional submodules forming the residual path (nn.Sequential usage).
        # drop_prob: probability of dropping a sample's residual branch; must be in [0, 1).
        # scale_by_keep: if True, surviving residuals are rescaled so the expected
        #   magnitude matches the no-drop case.
        # stochastic_drop_prob: if True, each sample is kept via an independent
        #   Bernoulli draw (variable keep count); otherwise a fixed-size random
        #   subset of the batch is kept.
        super().__init__(*args)
        assert 0. <= drop_prob < 1.
        self._drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
        self.stochastic_drop_prob = stochastic_drop_prob

    @property
    def drop_prob(self):
        # Exposed as a property so schedules can update it with validation.
        return self._drop_prob

    @drop_prob.setter
    def drop_prob(self, value):
        # Re-validate on every update (e.g. when a drop-rate schedule adjusts it).
        assert 0. <= value < 1.
        self._drop_prob = value

    @property
    def keep_prob(self):
        # Complement of drop_prob: probability a sample's residual is processed.
        return 1. - self.drop_prob

    def forward(self, x, residual_path=None, residual_path_kwargs=None):
        # The residual path is either the contained submodules (nn.Sequential
        # usage, len(self) > 0) or an explicitly passed callable — exactly one.
        assert (len(self) == 0) ^ (residual_path is None)
        residual_path_kwargs = residual_path_kwargs or {}
        # No dropping during eval or when drop_prob == 0: plain residual add.
        if self.drop_prob == 0. or not self.training:
            if residual_path is None:
                return x + super().forward(x, **residual_path_kwargs)
            else:
                return x + residual_path(x, **residual_path_kwargs)
        # generate indices to keep (propagated through transform path)
        bs = len(x)
        if self.stochastic_drop_prob:
            # Independent Bernoulli keep-mask per sample; keep count varies.
            perm = torch.empty(bs, device=x.device).bernoulli_(self.keep_prob).nonzero().squeeze(1)
            scale = 1 / self.keep_prob
        else:
            # Fixed-size subset: always keep at least one sample so the
            # residual path receives a non-empty batch.
            keep_count = max(int(bs * self.keep_prob), 1)
            scale = bs / keep_count
            perm = torch.randperm(bs, device=x.device)[:keep_count]
        # propagate
        if self.scale_by_keep:
            alpha = scale
        else:
            alpha = 1.
        # reduce kwargs (e.g. used for DiT block where scale/shift/gate is passed and also has to be reduced)
        residual_path_kwargs = {
            key: value[perm] if torch.is_tensor(value) else value
            for key, value in residual_path_kwargs.items()
        }
        # Only the kept samples are pushed through the residual path.
        if residual_path is None:
            residual = super().forward(x[perm], **residual_path_kwargs)
        else:
            residual = residual_path(x[perm], **residual_path_kwargs)
        # Scatter-add the (scaled) residuals back onto their source rows;
        # dropped samples pass through unchanged.
        return torch.index_add(
            x.flatten(start_dim=1),
            dim=0,
            index=perm,
            source=residual.to(x.dtype).flatten(start_dim=1),
            alpha=alpha,
        ).view_as(x)

    def extra_repr(self):
        # Shown in the module's repr, e.g. "DropPath(drop_prob=0.100)".
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|