"""Image encoder.""" |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from third_parts.tokenize_anything import layers |
|
|
|
|
|
|
|
|
def space_to_depth(input, block_size):
    """Rearrange blocks of spatial data into depth."""
    if input.dim() == 3:
        # Flattened token sequence (N, H*W, C): assume a square spatial layout.
        hXw, c = input.size()[1:]
        h = w = int(hXw**0.5)
    else:
        h, w, c = input.size()[1:]
    h1, w1 = h // block_size, w // block_size
    c1 = (block_size**2) * c
    input = input.reshape((-1, h1, block_size, w1, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h1, w1, c1))


def depth_to_space(input, block_size):
    """Rearrange blocks of depth data into spatial."""
    h1, w1, c1 = input.size()[1:]
    h, w = h1 * block_size, w1 * block_size
    c = c1 // (block_size**2)
    input = input.reshape((-1, h1, w1, block_size, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h, w, c))


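# Shape sketch for the two helpers above (illustrative sizes only, assuming a
# 64x64 token grid and block_size=16):
#   space_to_depth: (N, 64, 64, C)     -> (N, 4, 4, 256 * C)
#   depth_to_space: (N, 4, 4, 256 * C) -> (N, 64, 64, C)

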
class MLP(nn.Module):
    """Two-layer MLP."""

    def __init__(self, dim, mlp_ratio=4):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multihead attention."""

    def __init__(self, dim, num_heads, qkv_bias=True):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # Attached externally (see ImageEncoderViT.__init__).
        self.rel_pos_embed: RelPosEmbed = None

    def forward(self, x):
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        q, k, v = self.qkv(x).view(qkv_shape).permute(2, 0, 3, 1, 4).unbind(dim=0)
        # The relative position bias enters SDPA as an additive attention mask.
        o = nn.functional.scaled_dot_product_attention(q, k, v, self.rel_pos_embed.get_bias())
        return self.proj(o.transpose(1, 2).flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super(Block, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
        self.drop_path = layers.DropPath(0.1, inplace=True)

    def forward(self, x):
        x = self.drop_path(self.attn(self.norm1(x))).add_(x)
        return self.drop_path(self.mlp(self.norm2(x))).add_(x)


class Bottleneck(nn.Module):
    """The bottleneck block."""

    def __init__(self, dim, expansion=2, width=None):
        super(Bottleneck, self).__init__()
        width = width or dim // expansion
        self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
        self.norm1 = nn.SyncBatchNorm(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
        self.norm2 = nn.SyncBatchNorm(width)
        self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
        self.norm3 = nn.SyncBatchNorm(dim)
        self.activation = nn.GELU()

    def forward(self, x):
        shortcut = x
        x = self.activation(self.norm1(self.conv1(x)))
        x = self.activation(self.norm2(self.conv2(x)))
        return self.norm3(self.conv3(x)).add_(shortcut)


class PatchEmbed(nn.Module):
    """Patch embedding layer."""

    def __init__(self, dim=768, patch_size=16, bias=True):
        super(PatchEmbed, self).__init__()
        self.proj = nn.Conv2d(3, dim, patch_size, patch_size, bias=bias)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)


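# Note: the strided convolution above tokenizes an image in one shot; with
# patch_size=16, e.g. a 1024x1024 input becomes a 64x64 grid, flattened to
# 4096 tokens of dimension ``dim`` per image.

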
class PosEmbed(nn.Module):
    """Position embedding layer."""

    def __init__(self, dim, num_patches):
        super(PosEmbed, self).__init__()
        self.dim = dim
        self.num_patches = num_patches
        self.weight = nn.Parameter(torch.zeros(num_patches, dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        return x.add_(self.weight)


class RelPosEmbed(nn.Module):
    """Relative position embedding layer."""

    def __init__(self, num_heads, size):
        super(RelPosEmbed, self).__init__()
        self.register_buffer("index", self.get_index(size))
        self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))

    @staticmethod
    def get_index(size):
        """Return the relative index."""
        grid = torch.arange(size)
        grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
        coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
        coords[0] *= 2 * size - 1
        return coords.sum(0)

    def get_bias(self):
        return self.weight[:, self.index]

    def forward(self, x):
        return x.add_(self.get_bias())


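# Shape note for RelPosEmbed (derived from the code above): for a window of
# ``size`` tokens per side, ``index`` is (size**2, size**2), ``weight`` is
# (num_heads, (2 * size - 1) ** 2), and ``get_bias()`` returns a
# (num_heads, size**2, size**2) additive bias consumed by Attention.forward.

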
class SimpleFeaturePyramid(nn.Module):
    """Module to create pyramid features."""

    def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
        super(SimpleFeaturePyramid, self).__init__()
        self.min_lvl, self.max_lvl = min_lvl, max_lvl
        self.input_conv = nn.ModuleList()
        self.lateral_conv = nn.ModuleList()
        self.output_conv = nn.ModuleList()
        patch_lvl = dict((2**i, i) for i in range(6))[patch_size]
        for lvl in [min(i + 2, self.max_lvl) for i in range(4)]:
            if lvl == patch_lvl or lvl < self.min_lvl:
                self.input_conv += [nn.Identity()]
            elif lvl < patch_lvl:
                # Upsample with transposed convolutions to reach finer levels.
                stride, layers = 2 ** (patch_lvl - lvl), []
                while stride > 1:
                    layers += [nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)]
                    layers += [nn.SyncBatchNorm(embed_dim), nn.GELU()] if stride > 2 else []
                    stride /= 2
                self.input_conv.append(nn.Sequential(*layers))
            elif lvl > patch_lvl:
                # Downsample with max pooling to reach coarser levels.
                stride = 2 ** (lvl - patch_lvl)
                self.input_conv += [nn.MaxPool2d(stride, stride)]
        for _ in range(min_lvl, max_lvl + 1):
            self.lateral_conv.append(
                nn.Sequential(
                    nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )
            self.output_conv.append(
                nn.Sequential(
                    nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )

    def forward(self, inputs):
        # Pad the input list to four entries by repeating the last feature.
        inputs = inputs + [inputs[-1]] * (4 - len(inputs))
        inputs = [conv(x) for conv, x in zip(self.input_conv, inputs)]
        features = inputs[self.min_lvl - 1 : self.max_lvl]
        laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
        return [conv(x) for conv, x in zip(self.output_conv, laterals)]


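# With the defaults (patch_size=16, min_lvl=max_lvl=4) every input_conv above
# is an identity, so the "pyramid" degenerates to a single stride-16 level:
# one 1x1 lateral conv followed by one 3x3 output conv, each with SyncBatchNorm.

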
class ImageEncoderViT(nn.Module):
    """ViT image encoder."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=16,
        window_size=16,
        image_size=1024,
        out_dim=256,
    ):
        super(ImageEncoderViT, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.window_size = window_size or image_size // patch_size
        self.patch_embed = PatchEmbed(embed_dim, patch_size)
        self.pos_embed = PosEmbed(embed_dim, (image_size // patch_size) ** 2)
        self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
        for blk in self.blocks:
            blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
        self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
        # Indices of the blocks that are followed by a cross-window bottleneck conv.
        self.cross_indices = list(range(depth // 4 - 1, depth, depth // 4))

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_embed(x)
        # Partition the token grid into non-overlapping windows for attention.
        x = space_to_depth(x, self.window_size)
        wmsa_shape = (-1,) + x.shape[1:]
        msa_shape = (-1, self.window_size**2, self.embed_dim)
        x = x.reshape(msa_shape)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.cross_indices or i == len(self.blocks) - 1:
                x = self.norm(x) if i == len(self.blocks) - 1 else x
                # Merge windows back into the full spatial grid (NCHW).
                x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
                x = x.permute(0, 3, 1, 2).contiguous()
            if i in self.cross_indices:
                # Propagate information across windows with a bottleneck conv.
                x = self.cross_conv[self.cross_indices.index(i)](x)
            if i in self.cross_indices and i < len(self.blocks) - 1:
                # Re-partition into windows for the remaining blocks.
                x = x.permute(0, 2, 3, 1)
                x = space_to_depth(x, self.window_size).reshape(msa_shape)
        return self.neck([x])


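# A minimal smoke-test sketch (hypothetical hyper-parameters, not a canonical
# config from this repository). ``eval()`` keeps SyncBatchNorm in inference
# mode so no distributed process group is assumed to be initialized.
if __name__ == "__main__":
    encoder = ImageEncoderViT(depth=12, embed_dim=768, num_heads=12).eval()
    images = torch.randn(1, 3, 1024, 1024)
    with torch.no_grad():
        feats = encoder(images)
    # Expect a single stride-16 feature map: [(1, 256, 64, 64)].
    print([tuple(f.shape) for f in feats])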