"""Image decoder."""

import torch
from torch import nn

# FlashAttention is optional; Attention.forward falls back to PyTorch SDPA.
try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
|
|
class TransposedLayerNorm(nn.LayerNorm):
    """LayerNorm applied over the channel axis of NCHW inputs."""

    def forward(self, input):
        return super().forward(input.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
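
# Usage sketch (not part of the original file): normalizes the channel axis of
# an NCHW tensor by permuting to NHWC, applying LayerNorm, and permuting back.
#
#   norm = TransposedLayerNorm(64)
#   y = norm(torch.zeros(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32)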
|
|
class MLP(nn.Module):
    """Two-layer MLP."""

    def __init__(self, dim, mlp_dim, activation_type="ReLU"):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.activation = getattr(nn, activation_type)()
        self.activation.inplace = True

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))
|
|
class Attention(nn.Module):
    """Multi-head attention."""

    def __init__(self, dim=256, num_heads=8, attn_ratio=1):
        super(Attention, self).__init__()
        self.num_heads = num_heads or dim // 64
        self.head_dim = int(dim * attn_ratio) // self.num_heads
        self.q_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.v_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.proj = nn.Linear(self.num_heads * self.head_dim, dim)

    def forward(self, q, k, v):
        q = self.q_proj(q).view(-1, q.size(1), self.num_heads, self.head_dim)
        k = self.k_proj(k).view(-1, k.size(1), self.num_heads, self.head_dim)
        v = self.v_proj(v).view(-1, v.size(1), self.num_heads, self.head_dim)
        if flash_attn_func is None or q.device.type != "cuda":
            # SDPA fallback expects (batch, heads, tokens, head_dim).
            q, k, v = (_.transpose(1, 2) for _ in (q, k, v))
            o = nn.functional.scaled_dot_product_attention(q, k, v).transpose(1, 2)
        else:
            o = flash_attn_func(q, k, v)
        return self.proj(o.flatten(2))
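
# Usage sketch (not part of the original file): queries, keys and values of
# shape (batch, tokens, dim) map to an output of the same shape; on CPU, or
# when flash-attn is not installed, the SDPA fallback above is used.
#
#   attn = Attention(dim=256, num_heads=8)
#   q = k = v = torch.zeros(2, 7, 256)
#   out = attn(q, k, v)  # out.shape == (2, 7, 256)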
|
|
class Block(nn.Module):
    """Transformer block with two-way cross-attention."""

    def __init__(
        self,
        dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        skip_first_query_pos=False,
    ):
        super(Block, self).__init__()
        self.self_attn = Attention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn_token_to_image = Attention(dim, num_heads, attn_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, activation_type)
        self.norm3 = nn.LayerNorm(dim)
        self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
        self.norm4 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1, inplace=True)
        self.skip_first_query_pos = skip_first_query_pos

    def forward(self, query, key, query_pos, key_pos):
        # Self-attention over the query tokens.
        if self.skip_first_query_pos:
            query = self.norm1(self.self_attn(query, query, query))
        else:
            q = query + query_pos
            query = self.norm1(self.dropout(self.self_attn(q, q, query)).add_(query))
        # Cross-attention: tokens attend to the image embedding.
        q, k = query + query_pos, key + key_pos
        query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
        query = self.norm3(self.dropout(self.mlp(query)).add_(query))
        # Cross-attention: the image embedding attends back to the tokens.
        key = self.norm4(self.cross_attn_image_to_token(k, query + query_pos, query).add_(key))
        return query, key
|
|
class Transformer(nn.Module):
    """Two-way transformer decoder."""

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        depth=2,
    ):
        super(Transformer, self).__init__()
        self.blocks = nn.ModuleList(
            Block(
                embed_dim,
                num_heads,
                attn_ratio=attn_ratio,
                mlp_dim=mlp_dim,
                activation_type=activation_type,
                skip_first_query_pos=(i == 0),
            )
            for i in range(depth)
        )
        self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1, inplace=True)

    def forward(self, query, key, query_pos, key_pos):
        for blk in self.blocks:
            query, key = blk(query, key, query_pos, key_pos)
        # Final attention from tokens to the image embedding.
        q, k = query + query_pos, key + key_pos
        query = self.norm(self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query))
        return query, key
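
# Usage sketch (not part of the original file; the shapes are assumptions):
# prompt tokens and the flattened image embedding attend to each other across
# `depth` two-way blocks, and both keep their shapes.
#
#   t = Transformer(embed_dim=256, num_heads=8, depth=2)
#   query, query_pos = torch.zeros(2, 6, 256), torch.zeros(2, 6, 256)
#   key, key_pos = torch.zeros(2, 4096, 256), torch.zeros(2, 4096, 256)
#   query, key = t(query, key, query_pos, key_pos)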
|
|
class Predictor(nn.Module):
    """MLP predictor."""

    def __init__(self, in_dim, out_dim, mlp_dim=None, depth=3):
        super(Predictor, self).__init__()
        mlp_dims = [mlp_dim or in_dim] * (depth - 1)
        in_dims, out_dims = [in_dim] + mlp_dims, mlp_dims + [out_dim]
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip(in_dims, out_dims))

    def forward(self, x):
        for fc in self.layers[:-1]:
            x = nn.functional.relu(fc(x), inplace=True)
        return self.layers[-1](x)
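
# Usage sketch (not part of the original file): with the defaults (depth=3,
# mlp_dim=None) this builds a 256 -> 256 -> 256 -> 32 MLP with ReLU in between.
#
#   head = Predictor(in_dim=256, out_dim=32)
#   y = head(torch.zeros(4, 256))  # y.shape == (4, 32)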
|
|
class ImageDecoder(nn.Module):
    """Module to decode region tokens and masks."""

    def __init__(self, depth, embed_dim, num_heads, num_mask_tokens=4, sem_embed_dim=1024):
        super(ImageDecoder, self).__init__()
        self.embed_dim = embed_dim
        self.num_mask_tokens = num_mask_tokens
        self.transformer = Transformer(embed_dim, num_heads, depth=depth)
        self.iou_token = nn.Embedding(1, embed_dim)
        self.sem_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        # Upsample the image embedding 4x to produce per-pixel mask features.
        self.output_conv = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
            TransposedLayerNorm(embed_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, 2, 2),
            nn.GELU(),
        )
        self.mask_pred = nn.ModuleList(
            Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
        )
        self.iou_pred = Predictor(embed_dim, num_mask_tokens)
        self.sem_pred = Predictor(embed_dim, sem_embed_dim, sem_embed_dim)
|
    def get_outputs(self, inputs):
        img_embeds = inputs["img_embeds"]
        sparse_embeds = inputs["sparse_embeds"]
        ims_per_batch = img_embeds.size(0)
        prompts_per_batch = sparse_embeds.size(0)
        img_embed_size = img_embeds.shape[2:-1]

        # Concatenate the learned output tokens with the prompt embeddings.
        tokens = [self.sem_tokens.weight, self.iou_token.weight, self.mask_tokens.weight]
        query = torch.cat(tokens).unsqueeze_(0).expand(prompts_per_batch, -1, -1)
        query = torch.cat((query, sparse_embeds), dim=1)
        num_tokens = query.shape[1] - sparse_embeds.shape[1]

        # Tile the image embedding per prompt and flatten the spatial axes.
        key = img_embeds.expand(-1, prompts_per_batch // ims_per_batch, -1, -1, -1)
        key = key.flatten(0, 1).flatten(1, 2)

        query, key = self.transformer(query, key, query, inputs["img_pos"])

        # Upsample the attended image embedding into mask features.
        key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
        mask_embeds = self.output_conv(key).flatten(2)

        # Split the decoded queries back into semantic, IoU and mask tokens.
        sem_tokens = query[:, : self.num_mask_tokens]
        sam_tokens = query[:, self.num_mask_tokens : num_tokens].unbind(1)
        iou_tokens, mask_tokens = sam_tokens[0], sam_tokens[1:]

        # Predict masks as dot products between mask tokens and mask features.
        mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
        mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
        mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
        mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
        outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
        outputs["sem_tokens"] = sem_tokens.unsqueeze_(2)
        outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"].flatten(2))
        return outputs
|
    def forward(self, inputs):
        outputs = self.get_outputs(inputs)
        # Return IoU predictions in full precision.
        outputs["iou_pred"] = outputs["iou_pred"].float()
        return outputs
|
|
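# Minimal smoke test (a sketch, not part of the original module): the input
# shapes below are assumptions inferred from get_outputs rather than a
# documented interface.
if __name__ == "__main__":
    decoder = ImageDecoder(depth=2, embed_dim=256, num_heads=8).eval()
    inputs = {
        # One image embedding of spatial size 64x64 with 256 channels.
        "img_embeds": torch.zeros(1, 1, 64, 64, 256),
        # Positional encoding of the flattened image embedding.
        "img_pos": torch.zeros(1, 64 * 64, 256),
        # Four prompts, each carrying two sparse (e.g. point/box) embeddings.
        "sparse_embeds": torch.zeros(4, 2, 256),
    }
    with torch.no_grad():
        outputs = decoder(inputs)
    print(outputs["mask_pred"].shape)   # torch.Size([4, 4, 256, 256])
    print(outputs["iou_pred"].shape)    # torch.Size([4, 4])
    print(outputs["sem_embeds"].shape)  # torch.Size([4, 4, 1024])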