"""PyTorch MVANet model for semantic segmentation."""
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from huggingface_hub import hf_hub_download
from timm.layers import DropPath, to_2tuple, trunc_normal_
from timm.models import load_checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from mvanet.transformers.configuration_mvanet import MVANetConfig
# ============================================================================
# Helper Functions
# ============================================================================
def get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def make_cbr(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(out_dim),
nn.PReLU(),
)
def make_cbg(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(out_dim),
nn.GELU(),
)
def rescale_to(x, scale_factor: float = 2, interpolation="nearest"):
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
def resize_as(x, y, interpolation="bilinear"):
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
def image2patches(x):
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
b, c, h, w = x.shape
if h % 2 != 0 or w % 2 != 0:
x = F.interpolate(
x, size=(h + h % 2, w + w % 2), mode="bilinear", align_corners=False
)
x = rearrange(x, "b c (hg h) (wg w) -> (hg wg b) c h w", hg=2, wg=2)
return x
def patches2image(x):
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
patches_b, c, h, w = x.shape
actual_b = patches_b // 4
x = rearrange(x, "(hg wg b) c h w -> b c (hg h) (wg w)", hg=2, wg=2, b=actual_b)
return x
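# Shape sketch for the two helpers above (hypothetical 1x3x64x64 input; the
# 2x2 grid of patches is stacked on the batch dimension):
#
#     >>> x = torch.randn(1, 3, 64, 64)
#     >>> image2patches(x).shape
#     torch.Size([4, 3, 32, 32])
#     >>> patches2image(image2patches(x)).shape
#     torch.Size([1, 3, 64, 64])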
# ============================================================================
# Position Embedding
# ============================================================================
class PositionEmbeddingSine(nn.Module):
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
        # Non-persistent buffer so dim_t follows the module's device instead of
        # being pinned to whichever device happened to be available at
        # construction time.
        self.register_buffer(
            "dim_t",
            torch.arange(0, self.num_pos_feats, dtype=torch.float32),
            persistent=False,
        )
    def forward(self, b, h, w):
        mask = torch.zeros([b, h, w], dtype=torch.bool, device=self.dim_t.device)
        not_mask = ~mask
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale).to(
mask.device
)
x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale).to(
mask.device
)
dim_t = self.temperature ** (2 * (self.dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
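# Shape sketch (hypothetical sizes): the sine embedding returns
# 2 * num_pos_feats channels, the first half encoding y and the second x:
#
#     >>> pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
#     >>> pe(2, 8, 8).shape
#     torch.Size([2, 128, 8, 8])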
# ============================================================================
# Swin Transformer Components
# ============================================================================
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(
B, H // window_size, W // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
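# Round-trip sketch for the two window helpers (hypothetical 14x14 feature map,
# window_size=7, i.e. a 2x2 grid of windows):
#
#     >>> x = torch.randn(1, 14, 14, 96)
#     >>> windows = window_partition(x, 7)
#     >>> windows.shape
#     torch.Size([4, 7, 7, 96])
#     >>> window_reverse(windows, 7, 14, 14).shape
#     torch.Size([1, 14, 14, 96])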
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(
torch.meshgrid([coords_h, coords_w], indexing="ij")
) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_index = self.relative_position_index
assert isinstance(relative_position_index, torch.Tensor)
relative_position_bias = self.relative_position_bias_table[
relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
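# Shape sketch (hypothetical 7x7 windows, i.e. N = 49 tokens per window, and a
# batch of 4 windows):
#
#     >>> attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#     >>> attn(torch.randn(4, 49, 96)).shape
#     torch.Size([4, 49, 96])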
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, (
            "shift_size must be in [0, window_size)"
        )
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
        self.H: Optional[int] = None
        self.W: Optional[int] = None
def forward(self, x, mask_matrix):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert H is not None and W is not None, "H and W must be set before forward"
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(
x_windows, mask=attn_mask
) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
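# Shape sketch (hypothetical 56x56 feature map): merging halves the spatial
# resolution and doubles the channel count:
#
#     >>> pm = PatchMerging(dim=96)
#     >>> pm(torch.randn(1, 56 * 56, 96), 56, 56).shape
#     torch.Size([1, 784, 192])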
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
        depth (int): Number of blocks in this stage.
        num_heads (int): Number of attention heads.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
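# Shape sketch (hypothetical 56x56 stage with PatchMerging): x_out keeps the
# input resolution while x_down is the downsampled input to the next stage:
#
#     >>> layer = BasicLayer(dim=96, depth=2, num_heads=3, downsample=PatchMerging)
#     >>> x_out, H, W, x_down, Wh, Ww = layer(torch.randn(1, 56 * 56, 96), 56, 56)
#     >>> x_out.shape, x_down.shape, (Wh, Ww)
#     (torch.Size([1, 3136, 96]), torch.Size([1, 784, 192]), (28, 28))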
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
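# Shape sketch (hypothetical 224x224 image, patch_size=4, no norm layer):
#
#     >>> embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
#     >>> embed(torch.randn(1, 3, 224, 224)).shape
#     torch.Size([1, 96, 56, 56])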
class SwinTransformer(nn.Module):
"""Swin Transformer backbone.
    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False,
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1],
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
if isinstance(pretrained, str):
self.apply(_init_weights)
load_checkpoint(self, pretrained, strict=False)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError("pretrained must be a str or None")
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = x + absolute_pos_embed # B Wh*Ww C
outs = [x.contiguous()]
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = (
x_out.view(-1, H, W, self.num_features[i])
.permute(0, 3, 1, 2)
.contiguous()
)
outs.append(out)
return tuple(outs)
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def SwinB(pretrained=True):
model = SwinTransformer(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
)
if pretrained is True:
state_dict_path = hf_hub_download(
repo_id="creative-graphic-design/MVANet-checkpoints",
filename="swin_base_patch4_window12_384_22kto1k.pth",
)
state_dict = torch.load(state_dict_path, map_location="cpu")
model.load_state_dict(state_dict["model"], strict=False)
return model
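# Feature-pyramid sketch (hypothetical 1024x1024 input; pretrained=False skips
# the checkpoint download). SwinB returns five maps: the raw patch embedding
# followed by the four stage outputs:
#
#     >>> net = SwinB(pretrained=False)
#     >>> feats = net(torch.randn(1, 3, 1024, 1024))
#     >>> [tuple(f.shape) for f in feats]
#     [(1, 128, 256, 256), (1, 128, 256, 256), (1, 256, 128, 128),
#      (1, 512, 64, 64), (1, 1024, 32, 32)]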
# ============================================================================
# Multi-field Cross Localization Module (MCLM)
# ============================================================================
class inf_MCLM(nn.Module):
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
super(inf_MCLM, self).__init__()
self.attention = nn.ModuleList(
[
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
]
)
self.linear1 = nn.Linear(d_model, d_model * 2)
self.linear2 = nn.Linear(d_model * 2, d_model)
self.linear3 = nn.Linear(d_model, d_model * 2)
self.linear4 = nn.Linear(d_model * 2, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
self.dropout1 = nn.Dropout(0.1)
self.dropout2 = nn.Dropout(0.1)
self.activation = get_activation_fn("relu")
self.pool_ratios = pool_ratios
self.positional_encoding = PositionEmbeddingSine(
num_pos_feats=d_model // 2, normalize=True
)
def forward(self, l, g):
"""
l: 4,c,h,w
g: 1,c,h,w
"""
b, c, h, w = l.size()
# 4,c,h,w -> 1,c,2h,2w
concated_locs = rearrange(l, "(hg wg b) c h w -> b c (hg h) (wg w)", hg=2, wg=2)
pools = []
p_poses_list = []
for pool_ratio in self.pool_ratios:
# b,c,h,w
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
pools.append(rearrange(pool, "b c h w -> (h w) b c"))
pos_emb = self.positional_encoding(
pool.shape[0], pool.shape[2], pool.shape[3]
)
pos_emb = rearrange(pos_emb, "b c h w -> (h w) b c")
p_poses_list.append(pos_emb)
pools = torch.cat(pools, 0)
p_poses = torch.cat(p_poses_list, dim=0)
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
g_pos = rearrange(pos_emb, "b c h w -> (h w) b c")
        # attention between glb (q) & multi-scale pooled local features (k, v)
g_hw_b_c = rearrange(g, "b c h w -> (h w) b c")
g_hw_b_c = g_hw_b_c + self.dropout1(
self.attention[0](g_hw_b_c + g_pos, pools + p_poses, pools)[0]
)
g_hw_b_c = self.norm1(g_hw_b_c)
g_hw_b_c = g_hw_b_c + self.dropout2(
self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone()))
)
g_hw_b_c = self.norm2(g_hw_b_c)
        # attention between original locs (q) & refreshed glb (k, v)
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
_g_hw_b_c = rearrange(g_hw_b_c, "(h w) b c -> h w b c", h=h, w=w)
_g_hw_b_c = rearrange(
_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2
)
outputs_re = []
for i, (_l, _g) in enumerate(
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))
):
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
l_hw_b_c = self.norm1(l_hw_b_c)
l_hw_b_c = l_hw_b_c + self.dropout2(
self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone()))
)
l_hw_b_c = self.norm2(l_hw_b_c)
        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # (h*w, 5, c)
        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
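# Shape sketch (hypothetical d_model=128, num_heads=8, 16x16 deepest features):
# four local views plus one refreshed global view come back stacked on batch:
#
#     >>> mclm = inf_MCLM(d_model=128, num_heads=8)
#     >>> mclm(torch.randn(4, 128, 16, 16), torch.randn(1, 128, 16, 16)).shape
#     torch.Size([5, 128, 16, 16])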
# ============================================================================
# Multi-crop Refinement Module (MCRM)
# ============================================================================
class inf_MCRM(nn.Module):
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
super(inf_MCRM, self).__init__()
self.attention = nn.ModuleList(
[
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
]
)
self.linear3 = nn.Linear(d_model, d_model * 2)
self.linear4 = nn.Linear(d_model * 2, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
self.dropout1 = nn.Dropout(0.1)
self.dropout2 = nn.Dropout(0.1)
self.sigmoid = nn.Sigmoid()
self.activation = get_activation_fn("relu")
self.sal_conv = nn.Conv2d(d_model, 1, 1)
self.pool_ratios = pool_ratios
self.positional_encoding = PositionEmbeddingSine(
num_pos_feats=d_model // 2, normalize=True
)
def forward(self, x):
total_b, c, h, w = x.size()
# Total batch is 5*batch_size (4 local + 1 global)
batch_size = total_b // 5
# Split into local (4*batch_size) and global (batch_size)
loc, glb = x.split([4 * batch_size, batch_size], dim=0)
# loc: (4*batch_size, c, h, w), glb: (batch_size, c, h, w)
patched_glb = rearrange(glb, "b c (hg h) (wg w) -> (hg wg b) c h w", hg=2, wg=2)
# generate token attention map
token_attention_map = self.sigmoid(self.sal_conv(glb))
token_attention_map = F.interpolate(
token_attention_map, size=patches2image(loc).shape[-2:], mode="nearest"
)
loc = loc * rearrange(
token_attention_map, "b c (hg h) (wg w) -> (hg wg b) c h w", hg=2, wg=2
)
pools = []
for pool_ratio in self.pool_ratios:
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
pools.append(rearrange(pool, "nl c h w -> nl c (h w)"))
# pools: (4*batch_size, c, nphw) -> (4*batch_size, nphw, 1, c)
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
# Reshape to separate batch and patch dimensions: (4, batch_size, nphw, 1, c)
# Note: image2patches outputs in order (hg wg b) where b changes fastest
# So the order is: [p0_b0, p0_b1, ..., p1_b0, p1_b1, ..., p3_b0, p3_b1]
pools = rearrange(pools, "(p b) nphw 1 c -> p b nphw 1 c", p=4, b=batch_size)
# loc_: (4*batch_size, hw, 1, c) -> (4, batch_size, hw, 1, c)
loc_ = rearrange(loc, "nl c h w -> nl (h w) 1 c")
loc_ = rearrange(loc_, "(p b) hw 1 c -> p b hw 1 c", p=4, b=batch_size)
        # Apply attention once per patch position (4 iterations, independent of
        # batch size); each iteration processes all batch items in parallel.
        outputs = []
        for i in range(4):
# Extract patch i across all batch items: (batch_size, hw, 1, c)
q = loc_[i, :, :, :, :] # (b, hw, 1, c)
v = pools[i, :, :, :, :] # (b, nphw, 1, c)
k = v
# Reshape for MultiheadAttention: (seq, batch, dim)
q = rearrange(q, "b hw 1 c -> hw b c")
k = rearrange(k, "b nphw 1 c -> nphw b c")
v = rearrange(v, "b nphw 1 c -> nphw b c")
# Apply attention (processes all batch_size items in parallel)
attn_out = self.attention[i](q, k, v)[0] # (hw, b, c)
outputs.append(attn_out)
        # Stack outputs: list of 4 x (hw, b, c) -> (hw, b, 4, c), then flatten to
        # (hw, 4*b, c) in p-major order [p0_b0, p0_b1, ..., p1_b0, p1_b1, ...]
        # so it matches the (p b) layout of loc.
        outputs = torch.stack(outputs, dim=2)  # (hw, b, 4, c)
        outputs = rearrange(outputs, "hw b p c -> hw (p b) c")  # (hw, 4*b, c)
# Continue with existing operations using batch_size
src = loc.view(4 * batch_size, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
src = self.norm1(src)
src = src + self.dropout2(
self.linear4(self.dropout(self.activation(self.linear3(src)).clone()))
)
src = self.norm2(src)
        src = src.permute(1, 2, 0).reshape(4 * batch_size, c, h, w)  # refreshed loc
        glb = glb + F.interpolate(
            patches2image(src), size=glb.shape[-2:], mode="nearest"
        )  # refreshed glb
return torch.cat((src, glb), 0)
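# Shape sketch (hypothetical d_model=128, num_heads=8, batch_size=1): input and
# output both stack 4 local views and 1 global view on the batch dimension:
#
#     >>> mcrm = inf_MCRM(d_model=128, num_heads=8)
#     >>> mcrm(torch.randn(5, 128, 32, 32)).shape
#     torch.Size([5, 128, 32, 32])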
# ============================================================================
# MVANet Model for Image Segmentation
# ============================================================================
class MVANetForImageSegmentation(PreTrainedModel):
"""
MVANet Model for image segmentation.
    This model is a direct reimplementation of inf_MVANet with a
    transformers-compatible interface for semantic segmentation tasks.
Args:
config (:class:`~mvanet.transformers.MVANetConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
    Example::
>>> from transformers import AutoModel, AutoImageProcessor
>>> from PIL import Image
>>> # Load model and processor
        >>> model = AutoModel.from_pretrained("creative-graphic-design/mvanet", trust_remote_code=True)
        >>> processor = AutoImageProcessor.from_pretrained("creative-graphic-design/mvanet", trust_remote_code=True)
>>> # Load image
>>> image = Image.open("image.png")
>>> # Preprocess
>>> inputs = processor(image, return_tensors="pt")
>>> # Forward pass
>>> outputs = model(**inputs)
>>> # Post-process
>>> masks = processor.post_process_semantic_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )
"""
config_class = MVANetConfig
base_model_prefix = "mvanet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = []
def __init__(self, config: MVANetConfig):
super().__init__(config)
self.config = config
emb_dim = config.embedding_dim
# Backbone: Swin Transformer
self.backbone = SwinB(pretrained=config.backbone_pretrained)
# Feature projection layers - use config values
out_channels = config.backbone_out_channels
self.output5 = make_cbr(out_channels[4], emb_dim) # 1024 -> 128
self.output4 = make_cbr(out_channels[3], emb_dim) # 512 -> 128
self.output3 = make_cbr(out_channels[2], emb_dim) # 256 -> 128
self.output2 = make_cbr(out_channels[1], emb_dim) # 128 -> 128
self.output1 = make_cbr(out_channels[0], emb_dim) # 128 -> 128
# Multi-field Cross Localization Module
self.multifieldcrossatt = inf_MCLM(
emb_dim, config.mclm_num_heads, config.mclm_pool_ratios
)
# Convolution blocks for decoder
self.conv1 = make_cbr(emb_dim, emb_dim)
self.conv2 = make_cbr(emb_dim, emb_dim)
self.conv3 = make_cbr(emb_dim, emb_dim)
self.conv4 = make_cbr(emb_dim, emb_dim)
# Multi-crop Refinement Module decoder blocks
self.dec_blk1 = inf_MCRM(
emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
)
self.dec_blk2 = inf_MCRM(
emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
)
self.dec_blk3 = inf_MCRM(
emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
)
self.dec_blk4 = inf_MCRM(
emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
)
# Instance mask head - use config value
hidden_dim = config.insmask_hidden_dim
self.insmask_head = nn.Sequential(
nn.Conv2d(emb_dim, hidden_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_dim),
nn.PReLU(),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_dim),
nn.PReLU(),
nn.Conv2d(hidden_dim, emb_dim, kernel_size=3, padding=1),
)
# Shallow feature extraction - use config value
self.shallow = nn.Sequential(
nn.Conv2d(config.num_channels, emb_dim, kernel_size=3, padding=1)
)
# Upsampling layers
self.upsample1 = make_cbg(emb_dim, emb_dim)
self.upsample2 = make_cbg(emb_dim, emb_dim)
# Final output layer - use config value
self.output = nn.Sequential(
nn.Conv2d(emb_dim, config.num_labels, kernel_size=3, padding=1)
)
# Set inplace operations for ReLU and Dropout
        for m in self.modules():
            if isinstance(m, (nn.ReLU, nn.Dropout)):
                m.inplace = True
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, SemanticSegmenterOutput]:
"""
Forward pass of the model.
Args:
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using :class:`~mvanet.transformers.MVANetImageProcessor`.
See :meth:`~mvanet.transformers.MVANetImageProcessor.preprocess` for details.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
Ground truth semantic segmentation maps for computing the loss.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. Currently not supported.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.modeling_outputs.SemanticSegmenterOutput` instead of
a plain tuple.
Returns:
:class:`~transformers.modeling_outputs.SemanticSegmenterOutput` or :obj:`tuple`:
A :class:`~transformers.modeling_outputs.SemanticSegmenterOutput` (if ``return_dict=True`` is passed or
when ``config.use_return_dict=True``) or a tuple of :obj:`torch.FloatTensor`.
        Example::
>>> from mvanet.transformers import MVANetForImageSegmentation, MVANetImageProcessor
>>> import torch
>>> from PIL import Image
>>> processor = MVANetImageProcessor()
>>> model = MVANetForImageSegmentation.from_pretrained("creative-graphic-design/mvanet")
>>> image = Image.open("image.png")
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits # (batch_size, num_labels, height, width)
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
batch_size = pixel_values.shape[0]
# Extract shallow features
shallow = self.shallow(pixel_values)
# Create multi-view input: 4 local patches + 1 global view
# Use config value for global view scale
glb = rescale_to(
pixel_values,
scale_factor=self.config.global_view_scale,
interpolation="bilinear",
)
loc = image2patches(pixel_values)
input_views = torch.cat((loc, glb), dim=0)
# Extract features through backbone
feature = self.backbone(input_views)
# Project features to embedding dimension
e5 = self.output5(feature[4]) # (batch*5, 128, 16, 16)
e4 = self.output4(feature[3]) # (batch*5, 128, 32, 32)
e3 = self.output3(feature[2]) # (batch*5, 128, 64, 64)
e2 = self.output2(feature[1]) # (batch*5, 128, 128, 128)
e1 = self.output1(feature[0]) # (batch*5, 128, 128, 128)
# Split local and global features at deepest level
# Use config value for number of patches
loc_e5, glb_e5 = e5.split(
[batch_size * self.config.num_patches, batch_size], dim=0
)
# Apply multi-field cross attention
e5_cat = self.multifieldcrossatt(loc_e5, glb_e5) # (batch*5, 128, 16, 16)
# Decode through MCRM blocks with skip connections
e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
# Split local and global features
# Use config value for number of patches
loc_e1, glb_e1 = e1.split(
[batch_size * self.config.num_patches, batch_size], dim=0
)
# Merge local patches back to image
output1_cat = patches2image(loc_e1)
# Add global features
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
# Apply instance mask head
final_output = self.insmask_head(output1_cat)
# Merge shallow features
final_output = final_output + resize_as(shallow, final_output)
final_output = self.upsample1(rescale_to(final_output))
final_output = rescale_to(final_output + resize_as(shallow, final_output))
final_output = self.upsample2(final_output)
# Final output (logits before sigmoid)
logits = self.output(final_output)
loss = None
if labels is not None:
# Compute binary cross-entropy loss with logits
# labels should be float with values in [0, 1]
loss_fct = nn.BCEWithLogitsLoss()
# Ensure labels have the same shape as logits
if labels.dim() == 3:
# (B, H, W) -> (B, 1, H, W)
labels = labels.unsqueeze(1)
loss = loss_fct(logits, labels.float())
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return SemanticSegmenterOutput(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None,
)
def _init_weights(self, module):
"""
Initialize weights.
The backbone (SwinB) and other modules handle their own weight initialization,
so we don't need to do anything here.
"""
pass
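if __name__ == "__main__":
    # Minimal smoke-test sketch (illustration only, not part of the model).
    # Assumptions: the default MVANetConfig is usable as-is, a 1024x1024 input
    # matches the shape comments in forward(), and the default
    # global_view_scale keeps the global view the same size as the local
    # patches; backbone_pretrained is disabled so nothing is downloaded.
    config = MVANetConfig()
    config.backbone_pretrained = False  # skip hf_hub_download for this sketch
    model = MVANetForImageSegmentation(config).eval()
    with torch.no_grad():
        pixel_values = torch.randn(1, config.num_channels, 1024, 1024)
        outputs = model(pixel_values=pixel_values)
    print(outputs.logits.shape)  # expected: (1, config.num_labels, 1024, 1024)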