shunk031 committed on
Commit
d1be74e
·
verified ·
1 Parent(s): 28c79f5

Upload MVANetForImageSegmentation

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. configuration_mvanet.py +109 -0
  3. modeling_mvanet.py +1340 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "MVANetForImageSegmentation"
4
  ],
 
 
 
 
5
  "backbone_out_channels": [
6
  128,
7
  128,
 
2
  "architectures": [
3
  "MVANetForImageSegmentation"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mvanet.MVANetConfig",
7
+ "AutoModel": "modeling_mvanet.MVANetForImageSegmentation"
8
+ },
9
  "backbone_out_channels": [
10
  128,
11
  128,
configuration_mvanet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MVANet model configuration."""
2
+
3
+ from typing import List
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
class MVANetConfig(PretrainedConfig):
    """
    Configuration class for MVANet model.

    Stores all architectural hyper-parameters of a
    :class:`~mvanet.transformers.MVANetForImageSegmentation`; instantiating a
    model from a configuration defines the model architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and
    can be used to control the model outputs. Read the documentation from
    :class:`~transformers.PretrainedConfig` for more information.

    Args:
        embedding_dim (:obj:`int`, `optional`, defaults to 128):
            The embedding dimension used throughout the model.
        backbone_type (:obj:`str`, `optional`, defaults to :obj:`"swinb"`):
            Type of backbone to use. Currently only "swinb" (Swin Transformer Base) is supported.
        backbone_pretrained (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to use pretrained weights for the backbone.
        backbone_out_channels (:obj:`List[int]`, `optional`, defaults to :obj:`[128, 128, 256, 512, 1024]`):
            Output channel dimensions for each backbone level (SwinB specific).
        mclm_num_heads (:obj:`int`, `optional`, defaults to 1):
            Number of attention heads in Multi-field Cross Localization Module (MCLM).
        mclm_pool_ratios (:obj:`List[int]`, `optional`, defaults to :obj:`[1, 4, 8]`):
            Pool ratios for MCLM multi-scale attention.
        mcrm_num_heads (:obj:`int`, `optional`, defaults to 1):
            Number of attention heads in Multi-crop Refinement Module (MCRM).
        mcrm_pool_ratios (:obj:`List[int]`, `optional`, defaults to :obj:`[2, 4, 8]`):
            Pool ratios for MCRM multi-scale attention.
        insmask_hidden_dim (:obj:`int`, `optional`, defaults to 384):
            Hidden dimension in the instance mask head.
        global_view_scale (:obj:`float`, `optional`, defaults to 0.5):
            Scale factor for creating the global view (downsampled version of input).
        num_patches (:obj:`int`, `optional`, defaults to 4):
            Number of local patches (currently only 4 for 2x2 grid is supported).
        image_size (:obj:`int`, `optional`, defaults to 1024):
            Input image size the model was trained on.
        num_channels (:obj:`int`, `optional`, defaults to 3):
            Number of input channels (3 for RGB images).
        num_labels (:obj:`int`, `optional`, defaults to 1):
            Number of output labels (1 for binary segmentation).

    Example::

        >>> from mvanet.transformers import MVANetConfig, MVANetForImageSegmentation

        >>> # Initializing a MVANet configuration
        >>> configuration = MVANetConfig()

        >>> # Initializing a model from the configuration
        >>> model = MVANetForImageSegmentation(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "mvanet"

    def __init__(
        self,
        embedding_dim: int = 128,
        backbone_type: str = "swinb",
        backbone_pretrained: bool = True,
        backbone_out_channels: List[int] | None = None,
        mclm_num_heads: int = 1,
        mclm_pool_ratios: List[int] | None = None,
        mcrm_num_heads: int = 1,
        mcrm_pool_ratios: List[int] | None = None,
        insmask_hidden_dim: int = 384,
        global_view_scale: float = 0.5,
        num_patches: int = 4,
        image_size: int = 1024,
        num_channels: int = 3,
        num_labels: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # List-valued defaults are materialized here (not in the signature) so
        # callers never share a mutable default object.
        if backbone_out_channels is None:
            # SwinB backbone output channels.
            backbone_out_channels = [128, 128, 256, 512, 1024]
        if mclm_pool_ratios is None:
            mclm_pool_ratios = [1, 4, 8]
        if mcrm_pool_ratios is None:
            mcrm_pool_ratios = [2, 4, 8]

        self.embedding_dim = embedding_dim
        self.backbone_type = backbone_type
        self.backbone_pretrained = backbone_pretrained
        self.backbone_out_channels = backbone_out_channels
        self.mclm_num_heads = mclm_num_heads
        self.mclm_pool_ratios = mclm_pool_ratios
        self.mcrm_num_heads = mcrm_num_heads
        self.mcrm_pool_ratios = mcrm_pool_ratios
        self.insmask_hidden_dim = insmask_hidden_dim
        self.global_view_scale = global_view_scale
        self.num_patches = num_patches
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_labels = num_labels
modeling_mvanet.py ADDED
@@ -0,0 +1,1340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch MVANet model for semantic segmentation."""
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
14
+ from timm.models import load_checkpoint
15
+ from transformers import PreTrainedModel
16
+ from transformers.modeling_outputs import SemanticSegmenterOutput
17
+
18
+ from mvanet.transformers.configuration_mvanet import MVANetConfig
19
+
20
+ # ============================================================================
21
+ # Helper Functions
22
+ # ============================================================================
23
+
24
+
25
def get_activation_fn(activation):
    """Return an activation function given a string"""
    known = {"relu": F.relu, "gelu": F.gelu, "glu": F.glu}
    fn = known.get(activation)
    if fn is None:
        raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
    return fn
34
+
35
+
36
def make_cbr(in_dim, out_dim):
    """Conv3x3 -> BatchNorm -> PReLU block ("CBR")."""
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
    bn = nn.BatchNorm2d(out_dim)
    act = nn.PReLU()
    return nn.Sequential(conv, bn, act)
42
+
43
+
44
def make_cbg(in_dim, out_dim):
    """Conv3x3 -> BatchNorm -> GELU block ("CBG")."""
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
    bn = nn.BatchNorm2d(out_dim)
    act = nn.GELU()
    return nn.Sequential(conv, bn, act)
50
+
51
+
52
def rescale_to(x, scale_factor: float = 2, interpolation="nearest"):
    """Resample ``x`` by ``scale_factor`` using the given interpolation mode."""
    resampled = F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
    return resampled
54
+
55
+
56
def resize_as(x, y, interpolation="bilinear"):
    """Resize ``x`` so its spatial size matches that of ``y``."""
    target_hw = y.shape[-2:]
    return F.interpolate(x, size=target_hw, mode=interpolation)
58
+
59
+
60
def image2patches(x):
    """Split an image batch into a 2x2 grid of patches.

    Equivalent of einops ``b c (hg h) (wg w) -> (hg wg b) c h w`` with
    ``hg = wg = 2``, implemented with pure torch ops so the helper has no
    einops dependency. Odd spatial dims are first upsampled to the next even
    size so the grid divides evenly.

    Args:
        x: tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (4*B, C, H/2, W/2); patch order is row-major over the
        2x2 grid, batch-major within each grid cell.
    """
    b, c, h, w = x.shape
    if h % 2 != 0 or w % 2 != 0:
        # make both spatial dims even before splitting into the 2x2 grid
        x = F.interpolate(
            x, size=(h + h % 2, w + w % 2), mode="bilinear", align_corners=False
        )
        b, c, h, w = x.shape
    ph, pw = h // 2, w // 2
    x = x.view(b, c, 2, ph, 2, pw)
    x = x.permute(2, 4, 0, 1, 3, 5).contiguous()
    return x.view(4 * b, c, ph, pw)
69
+
70
+
71
def patches2image(x):
    """Reassemble a 2x2 grid of patches back into full images.

    Inverse of :func:`image2patches` — the einops pattern
    ``(hg wg b) c h w -> b c (hg h) (wg w)`` with ``hg = wg = 2``, implemented
    with pure torch ops so the helper has no einops dependency.

    Args:
        x: tensor of shape (4*B, C, h, w) in image2patches' patch order.

    Returns:
        Tensor of shape (B, C, 2*h, 2*w).
    """
    patches_b, c, h, w = x.shape
    actual_b = patches_b // 4
    x = x.view(2, 2, actual_b, c, h, w)
    x = x.permute(2, 3, 0, 4, 1, 5).contiguous()
    return x.view(actual_b, c, 2 * h, 2 * w)
77
+
78
+
79
+ # ============================================================================
80
+ # Position Embedding
81
+ # ============================================================================
82
+
83
+
84
class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2-D position embedding (DETR-style).

    Produces a (B, 2*num_pos_feats, H, W) tensor of interleaved sine/cosine
    features over the y and x coordinates.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
        # NOTE(review): device is fixed at construction time from global CUDA
        # availability; all tensors produced later follow self.dim_t.device.
        self.dim_t = torch.arange(
            0,
            self.num_pos_feats,
            dtype=torch.float32,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    # NOTE(review): overrides __call__ directly instead of defining forward(),
    # which bypasses nn.Module hooks; presumably intentional since this module
    # has no learnable parameters -- confirm.
    def __call__(self, b, h, w):
        mask = torch.zeros([b, h, w], dtype=torch.bool, device=self.dim_t.device)
        assert mask is not None
        not_mask = ~mask
        # cumulative sums over an all-ones mask give 1-based y/x coordinates
        y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
        x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # rescale coordinates so they span [0, scale]
            y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale).to(
                mask.device
            )
            x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale).to(
                mask.device
            )

        # geometric frequency ladder: temperature^(2*floor(i/2)/num_pos_feats)
        dim_t = self.temperature ** (2 * (self.dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # interleave: sin on even channels, cos on odd channels
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
130
+
131
+
132
+ # ============================================================================
133
+ # Swin Transformer Components
134
+ # ============================================================================
135
+
136
+
137
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # fall back to in_features when widths are not given
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
163
+
164
+
165
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C), row-major over
        the window grid.
    """
    B, H, W, C = x.shape
    grid_h, grid_w = H // window_size, W // window_size
    x = x.view(B, grid_h, window_size, grid_w, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size, window_size, C)
180
+
181
+
182
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # recover the batch size from the total window count
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
199
+
200
+
201
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(
            torch.meshgrid([coords_h, coords_w], indexing="ij")
        )  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so each (dy, dx) pair maps to a unique flat index
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # single linear projection produces q, k and v; then split per head
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_index = self.relative_position_index
        assert isinstance(relative_position_index, torch.Tensor)
        # gather the learned bias for every (query, key) position pair
        relative_position_bias = self.relative_position_bias_table[
            relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # add the per-window shift mask, then restore the flat batch layout
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
316
+
317
+
318
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, (
            "shift_size must in 0-window_size"
        )

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Spatial resolution; the owning layer must assign these before forward().
        self.H: int | None = None
        self.W: int | None = None

    def forward(self, x, mask_matrix):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert H is not None and W is not None, "H and W must be set before forward"
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift (SW-MSA); the mask matrix blocks cross-region attention
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x

        # drop the rows/cols that were padded on before windowing
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
454
+
455
+
456
class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Downsamples by 2x: every 2x2 spatial neighborhood is concatenated along
    the channel axis (4C), normalized, and linearly projected down to 2C.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge patches.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # pad bottom/right so both spatial dims are even
        if (H % 2 == 1) or (W % 2 == 1):
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # gather the four corners of each 2x2 neighborhood; order matters and
        # must stay (0,0), (1,0), (0,1), (1,1) for weight compatibility
        corners = [x[:, i::2, j::2, :] for i, j in ((0, 0), (1, 0), (0, 1), (1, 1))]
        x = torch.cat(corners, dim=-1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(x))
498
+
499
+
500
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks use regular windows, odd-indexed
        # blocks use shifted windows (shift = window_size // 2)
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        # label the 3x3 grid of regions produced by the cyclic shift so tokens
        # from different regions cannot attend to each other
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # -100.0 is effectively -inf after softmax; 0.0 leaves the logit unchanged
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            # no downsampling: mirror the 6-tuple shape with the same tensor
            return x, H, W, x, H, W
618
+
619
+
620
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Projects the image into non-overlapping patch tokens with a strided
    convolution, padding bottom/right first so both spatial dims divide evenly.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Pad to a patch-size multiple, project, optionally layer-normalize."""
        _, _, H, W = x.size()
        ph, pw = self.patch_size
        if W % pw != 0:
            x = F.pad(x, (0, pw - W % pw))
        if H % ph != 0:
            x = F.pad(x, (0, 0, 0, ph - H % ph))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            # LayerNorm operates on (B, N, C); flatten, normalize, restore
            tokens = x.flatten(2).transpose(1, 2)
            tokens = self.norm(tokens)
            x = tokens.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
663
+
664
+
665
+ class SwinTransformer(nn.Module):
666
+ """Swin Transformer backbone.
667
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
668
+ https://arxiv.org/pdf/2103.14030
669
+
670
+ Args:
671
+ pretrain_img_size (int): Input image size for training the pretrained model,
672
+ used in absolute postion embedding. Default 224.
673
+ patch_size (int | tuple(int)): Patch size. Default: 4.
674
+ in_chans (int): Number of input image channels. Default: 3.
675
+ embed_dim (int): Number of linear projection output channels. Default: 96.
676
+ depths (tuple[int]): Depths of each Swin Transformer stage.
677
+ num_heads (tuple[int]): Number of attention head of each stage.
678
+ window_size (int): Window size. Default: 7.
679
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
680
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
681
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
682
+ drop_rate (float): Dropout rate.
683
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
684
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
685
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
686
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
687
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
688
+ out_indices (Sequence[int]): Output from which stages.
689
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
690
+ -1 means not freezing any parameters.
691
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
692
+ """
693
+
694
    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    ):
        """Build the Swin Transformer backbone.

        Constructs the patch embedding, optional absolute position embedding,
        the stacked ``BasicLayer`` stages (with per-block stochastic-depth
        rates), and one norm layer per requested output stage, then freezes
        the configured stages.

        NOTE(review): ``depths`` and ``num_heads`` use mutable list defaults;
        they are only read here, but callers should not mutate them in place.
        """
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding (sized for the pretraining resolution;
        # interpolated in forward() if the input resolution differs)
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: drop-path rate increases linearly from 0 to
        # drop_path_rate across all blocks of all stages
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers; channel width doubles at each stage, and every stage
        # except the last ends with a PatchMerging downsample
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output stage; registered as attributes
        # "norm0".."norm3" and looked up via getattr in forward()
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()
785
+
786
+ def _freeze_stages(self):
787
+ if self.frozen_stages >= 0:
788
+ self.patch_embed.eval()
789
+ for param in self.patch_embed.parameters():
790
+ param.requires_grad = False
791
+
792
+ if self.frozen_stages >= 1 and self.ape:
793
+ self.absolute_pos_embed.requires_grad = False
794
+
795
+ if self.frozen_stages >= 2:
796
+ self.pos_drop.eval()
797
+ for i in range(0, self.frozen_stages - 1):
798
+ m = self.layers[i]
799
+ m.eval()
800
+ for param in m.parameters():
801
+ param.requires_grad = False
802
+
803
+ def init_weights(self, pretrained=None):
804
+ """Initialize the weights in backbone.
805
+
806
+ Args:
807
+ pretrained (str, optional): Path to pre-trained weights.
808
+ Defaults to None.
809
+ """
810
+
811
+ def _init_weights(m):
812
+ if isinstance(m, nn.Linear):
813
+ trunc_normal_(m.weight, std=0.02)
814
+ if isinstance(m, nn.Linear) and m.bias is not None:
815
+ nn.init.constant_(m.bias, 0)
816
+ elif isinstance(m, nn.LayerNorm):
817
+ nn.init.constant_(m.bias, 0)
818
+ nn.init.constant_(m.weight, 1.0)
819
+
820
+ if isinstance(pretrained, str):
821
+ self.apply(_init_weights)
822
+ load_checkpoint(self, pretrained, strict=False)
823
+ elif pretrained is None:
824
+ self.apply(_init_weights)
825
+ else:
826
+ raise TypeError("pretrained must be a str or None")
827
+
828
    def forward(self, x):
        """Run the backbone and collect multi-scale feature maps.

        Args:
            x (Tensor): Input images of shape (B, C, H, W).

        Returns:
            tuple[Tensor]: The raw patch-embedding feature map followed by
            the normalized output of each stage listed in
            ``self.out_indices``, all in (B, C, H, W) layout.
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = x + absolute_pos_embed  # still (B, C, Wh, Ww) at this point

        # the patch-embedding feature itself is also returned as outs[0]
        outs = [x.contiguous()]
        # flatten to (B, Wh*Ww, C) token sequence for the transformer stages
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        for i in range(self.num_layers):
            layer = self.layers[i]
            # each layer returns its (pre-downsample) output plus the
            # downsampled tokens and spatial size for the next stage
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                # tokens (B, H*W, C) -> feature map (B, C, H, W)
                out = (
                    x_out.view(-1, H, W, self.num_features[i])
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                outs.append(out)

        return tuple(outs)
858
+
859
+ def train(self, mode=True):
860
+ """Convert the model into training mode while keep layers freezed."""
861
+ super(SwinTransformer, self).train(mode)
862
+ self._freeze_stages()
863
+
864
+
865
def SwinB(pretrained=True):
    """Build a Swin-B backbone (128-dim embedding, window size 12).

    Args:
        pretrained (bool): When truthy, download the
            ``swin_base_patch4_window12_384_22kto1k`` checkpoint from the
            Hugging Face Hub and load it non-strictly into the model.

    Returns:
        SwinTransformer: The constructed backbone.
    """
    model = SwinTransformer(
        embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
    )
    # Fix: the original `pretrained is True` identity check silently skipped
    # loading for any truthy value other than the literal True (e.g. 1).
    if pretrained:
        state_dict_path = hf_hub_download(
            repo_id="creative-graphic-design/MVANet-checkpoints",
            filename="swin_base_patch4_window12_384_22kto1k.pth",
        )
        state_dict = torch.load(state_dict_path, map_location="cpu")
        # non-strict: checkpoint head keys are absent from this backbone
        model.load_state_dict(state_dict["model"], strict=False)

    return model
878
+
879
+
880
+ # ============================================================================
881
+ # Multi-field Cross Localization Module (MCLM)
882
+ # ============================================================================
883
+
884
+
885
class inf_MCLM(nn.Module):
    """Multi-field Cross Localization Module (inference variant).

    Exchanges information between the global view and the four local
    patch views: the global tokens attend to multi-scale pools of the
    stitched local views, then each local patch attends to its matching
    quadrant of the refreshed global tokens.
    """

    def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
        super(inf_MCLM, self).__init__()
        # attention[0]: global -> pooled locals; attention[1..4]: one per patch
        self.attention = nn.ModuleList(
            [
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            ]
        )

        # linear1/linear2: FFN for the global branch; linear3/linear4: local branch.
        # norm1/norm2 and the dropouts are shared between both branches.
        self.linear1 = nn.Linear(d_model, d_model * 2)
        self.linear2 = nn.Linear(d_model * 2, d_model)
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = get_activation_fn("relu")
        self.pool_ratios = pool_ratios
        # NOTE(review): p_poses / g_pos are never read elsewhere in this class
        # (forward() uses locals) — presumably leftovers from a stateful version.
        self.p_poses = None
        self.g_pos = None
        self.positional_encoding = PositionEmbeddingSine(
            num_pos_feats=d_model // 2, normalize=True
        )

    def forward(self, l, g):
        """
        l: 4,c,h,w
        g: 1,c,h,w
        """
        b, c, h, w = l.size()
        # 4,c,h,w -> 1,c,2h,2w
        concated_locs = rearrange(l, "(hg wg b) c h w -> b c (hg h) (wg w)", hg=2, wg=2)
        pools = []
        p_poses_list = []
        for pool_ratio in self.pool_ratios:
            # pool the stitched locals to h/ratio x w/ratio and flatten to tokens
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
            pools.append(rearrange(pool, "b c h w -> (h w) b c"))
            pos_emb = self.positional_encoding(
                pool.shape[0], pool.shape[2], pool.shape[3]
            )
            pos_emb = rearrange(pos_emb, "b c h w -> (h w) b c")
            p_poses_list.append(pos_emb)
        pools = torch.cat(pools, 0)
        p_poses = torch.cat(p_poses_list, dim=0)
        pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
        g_pos = rearrange(pos_emb, "b c h w -> (h w) b c")

        # attention between glb (q) & multisensory concated-locs (k,v)
        g_hw_b_c = rearrange(g, "b c h w -> (h w) b c")
        g_hw_b_c = g_hw_b_c + self.dropout1(
            self.attention[0](g_hw_b_c + g_pos, pools + p_poses, pools)[0]
        )
        g_hw_b_c = self.norm1(g_hw_b_c)
        g_hw_b_c = g_hw_b_c + self.dropout2(
            self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone()))
        )
        g_hw_b_c = self.norm2(g_hw_b_c)

        # attention between origin locs (q) & refreshed glb (k,v):
        # split the refreshed global map into its 2x2 quadrants so that
        # patch i attends only to quadrant i
        l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
        _g_hw_b_c = rearrange(g_hw_b_c, "(h w) b c -> h w b c", h=h, w=w)
        _g_hw_b_c = rearrange(
            _g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2
        )
        outputs_re = []
        for i, (_l, _g) in enumerate(
            zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))
        ):
            outputs_re.append(self.attention[i + 1](_l, _g, _g)[0])  # (h w) 1 c
        outputs_re = torch.cat(outputs_re, 1)  # (h w) 4 c

        l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
        l_hw_b_c = self.norm1(l_hw_b_c)
        l_hw_b_c = l_hw_b_c + self.dropout2(
            self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone()))
        )
        l_hw_b_c = self.norm2(l_hw_b_c)

        # stack refreshed locals (4) and global (1) along the batch axis
        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # hw,b(5),c
        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  ## (5,c,h*w)
973
+
974
+
975
+ # ============================================================================
976
+ # Multi-crop Refinement Module (MCRM)
977
+ # ============================================================================
978
+
979
+
980
class inf_MCRM(nn.Module):
    """Multi-crop Refinement Module (inference variant).

    Refines the four local patch features using multi-scale pools of the
    matching global-view quadrant, gated by a saliency attention map
    predicted from the global features, then folds the refreshed locals
    back into the global map.
    """

    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
        # NOTE(review): the `h` parameter is accepted but never used here —
        # presumably kept for signature compatibility; confirm with callers.
        super(inf_MCRM, self).__init__()
        # one attention module per local patch (2x2 grid)
        self.attention = nn.ModuleList(
            [
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
                nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            ]
        )

        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.activation = get_activation_fn("relu")
        # 1x1 conv producing the single-channel saliency/attention logits
        self.sal_conv = nn.Conv2d(d_model, 1, 1)
        self.pool_ratios = pool_ratios
        self.positional_encoding = PositionEmbeddingSine(
            num_pos_feats=d_model // 2, normalize=True
        )

    def forward(self, x):
        total_b, c, h, w = x.size()
        # Total batch is 5*batch_size (4 local + 1 global)
        batch_size = total_b // 5

        # Split into local (4*batch_size) and global (batch_size)
        loc, glb = x.split([4 * batch_size, batch_size], dim=0)
        # loc: (4*batch_size, c, h, w), glb: (batch_size, c, h, w)
        patched_glb = rearrange(glb, "b c (hg h) (wg w) -> (hg wg b) c h w", hg=2, wg=2)

        # generate token attention map and gate the local features with it
        token_attention_map = self.sigmoid(self.sal_conv(glb))
        token_attention_map = F.interpolate(
            token_attention_map, size=patches2image(loc).shape[-2:], mode="nearest"
        )
        loc = loc * rearrange(
            token_attention_map, "b c (hg h) (wg w) -> (hg wg b) c h w", hg=2, wg=2
        )
        pools = []
        for pool_ratio in self.pool_ratios:
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
            pools.append(rearrange(pool, "nl c h w -> nl c (h w)"))
        # pools: (4*batch_size, c, nphw) -> (4*batch_size, nphw, 1, c)
        pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
        # Reshape to separate batch and patch dimensions: (4, batch_size, nphw, 1, c)
        # Note: image2patches outputs in order (hg wg b) where b changes fastest
        # So the order is: [p0_b0, p0_b1, ..., p1_b0, p1_b1, ..., p3_b0, p3_b1]
        pools = rearrange(pools, "(p b) nphw 1 c -> p b nphw 1 c", p=4, b=batch_size)

        # loc_: (4*batch_size, hw, 1, c) -> (4, batch_size, hw, 1, c)
        loc_ = rearrange(loc, "nl c h w -> nl (h w) 1 c")
        loc_ = rearrange(loc_, "(p b) hw 1 c -> p b hw 1 c", p=4, b=batch_size)

        # Apply attention for each of 4 patches (only 4 iterations, not batch_size!)
        # Each iteration processes all batch items simultaneously
        outputs = []
        for i in range(4):  # Only 4 iterations regardless of batch_size!
            # Extract patch i across all batch items: (batch_size, hw, 1, c)
            q = loc_[i, :, :, :, :]  # (b, hw, 1, c)
            v = pools[i, :, :, :, :]  # (b, nphw, 1, c)
            k = v

            # Reshape for MultiheadAttention: (seq, batch, dim)
            q = rearrange(q, "b hw 1 c -> hw b c")
            k = rearrange(k, "b nphw 1 c -> nphw b c")
            v = rearrange(v, "b nphw 1 c -> nphw b c")

            # Apply attention (processes all batch_size items in parallel)
            attn_out = self.attention[i](q, k, v)[0]  # (hw, b, c)
            outputs.append(attn_out)

        # Concatenate outputs: list of 4 x (hw, b, c) -> (hw, p*b, c)
        # Interleave to match (p b) order: [p0_b0, p0_b1, ..., p1_b0, p1_b1, ...]
        outputs = torch.stack(outputs, dim=2)  # (hw, b, 4, c)
        outputs = rearrange(outputs, "hw b p c -> hw (p b) c")  # (hw, 4*b, c)

        # residual + FFN over the (seq, batch, channel) local tokens
        src = loc.view(4 * batch_size, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
        src = self.norm1(src)
        src = src + self.dropout2(
            self.linear4(self.dropout(self.activation(self.linear3(src)).clone()))
        )
        src = self.norm2(src)

        src = src.permute(1, 2, 0).reshape(4 * batch_size, c, h, w)  # freshed loc
        glb = glb + F.interpolate(
            patches2image(src), size=glb.shape[-2:], mode="nearest"
        )  # freshed glb
        return torch.cat((src, glb), 0)
1077
+
1078
+
1079
+ # ============================================================================
1080
+ # MVANet Model for Image Segmentation
1081
+ # ============================================================================
1082
+
1083
+
1084
class MVANetForImageSegmentation(PreTrainedModel):
    """
    MVANet Model for image segmentation.

    This model is a direct reimplementation of inf_MVANet with transformers-compatible
    interface for semantic segmentation tasks.

    Args:
        config (:class:`~mvanet.transformers.MVANetConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.

    Example::

        >>> from transformers import AutoModel, AutoImageProcessor
        >>> from PIL import Image

        >>> # Load model and processor
        >>> model = AutoModel.from_pretrained("creative-graphic-design/mvanet")
        >>> processor = AutoImageProcessor.from_pretrained("creative-graphic-design/mvanet")

        >>> # Load image
        >>> image = Image.open("image.png")

        >>> # Preprocess
        >>> inputs = processor(image, return_tensors="pt")

        >>> # Forward pass
        >>> outputs = model(**inputs)

        >>> # Post-process
        >>> masks = processor.post_process_semantic_segmentation(
        ...     outputs, target_sizes=[image.size[::-1]]
        ... )
    """

    config_class = MVANetConfig
    base_model_prefix = "mvanet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    _no_split_modules = []

    def __init__(self, config: MVANetConfig):
        super().__init__(config)
        self.config = config

        emb_dim = config.embedding_dim

        # Backbone: Swin Transformer
        self.backbone = SwinB(pretrained=config.backbone_pretrained)

        # Feature projection layers - use config values
        out_channels = config.backbone_out_channels
        self.output5 = make_cbr(out_channels[4], emb_dim)  # 1024 -> 128
        self.output4 = make_cbr(out_channels[3], emb_dim)  # 512 -> 128
        self.output3 = make_cbr(out_channels[2], emb_dim)  # 256 -> 128
        self.output2 = make_cbr(out_channels[1], emb_dim)  # 128 -> 128
        self.output1 = make_cbr(out_channels[0], emb_dim)  # 128 -> 128

        # Multi-field Cross Localization Module
        self.multifieldcrossatt = inf_MCLM(
            emb_dim, config.mclm_num_heads, config.mclm_pool_ratios
        )

        # Convolution blocks for decoder
        self.conv1 = make_cbr(emb_dim, emb_dim)
        self.conv2 = make_cbr(emb_dim, emb_dim)
        self.conv3 = make_cbr(emb_dim, emb_dim)
        self.conv4 = make_cbr(emb_dim, emb_dim)

        # Multi-crop Refinement Module decoder blocks (one per decoder level)
        self.dec_blk1 = inf_MCRM(
            emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
        )
        self.dec_blk2 = inf_MCRM(
            emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
        )
        self.dec_blk3 = inf_MCRM(
            emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
        )
        self.dec_blk4 = inf_MCRM(
            emb_dim, config.mcrm_num_heads, config.mcrm_pool_ratios
        )

        # Instance mask head - use config value
        hidden_dim = config.insmask_hidden_dim
        self.insmask_head = nn.Sequential(
            nn.Conv2d(emb_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.PReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.PReLU(),
            nn.Conv2d(hidden_dim, emb_dim, kernel_size=3, padding=1),
        )

        # Shallow feature extraction - use config value
        self.shallow = nn.Sequential(
            nn.Conv2d(config.num_channels, emb_dim, kernel_size=3, padding=1)
        )

        # Upsampling layers
        self.upsample1 = make_cbg(emb_dim, emb_dim)
        self.upsample2 = make_cbg(emb_dim, emb_dim)

        # Final output layer - use config value
        self.output = nn.Sequential(
            nn.Conv2d(emb_dim, config.num_labels, kernel_size=3, padding=1)
        )

        # Set inplace operations for ReLU and Dropout (saves activation memory)
        for m in self.modules():
            if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
                m.inplace = True

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """
        Forward pass of the model.

        Args:
            pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
                Pixel values. Pixel values can be obtained using :class:`~mvanet.transformers.MVANetImageProcessor`.
                See :meth:`~mvanet.transformers.MVANetImageProcessor.preprocess` for details.
            labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
                Ground truth semantic segmentation maps for computing the loss.
            output_hidden_states (:obj:`bool`, `optional`):
                Whether or not to return the hidden states of all layers. Currently not supported.
            return_dict (:obj:`bool`, `optional`):
                Whether or not to return a :class:`~transformers.modeling_outputs.SemanticSegmenterOutput` instead of
                a plain tuple.

        Returns:
            :class:`~transformers.modeling_outputs.SemanticSegmenterOutput` or :obj:`tuple`:
                A :class:`~transformers.modeling_outputs.SemanticSegmenterOutput` (if ``return_dict=True`` is passed or
                when ``config.use_return_dict=True``) or a tuple of :obj:`torch.FloatTensor`.

        Example::

            >>> from mvanet.transformers import MVANetForImageSegmentation, MVANetImageProcessor
            >>> import torch
            >>> from PIL import Image

            >>> processor = MVANetImageProcessor()
            >>> model = MVANetForImageSegmentation.from_pretrained("creative-graphic-design/mvanet")

            >>> image = Image.open("image.png")
            >>> inputs = processor(image, return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> logits = outputs.logits  # (batch_size, num_labels, height, width)
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size = pixel_values.shape[0]

        # Extract shallow features
        shallow = self.shallow(pixel_values)

        # Create multi-view input: 4 local patches + 1 global view
        # Use config value for global view scale
        glb = rescale_to(
            pixel_values,
            scale_factor=self.config.global_view_scale,
            interpolation="bilinear",
        )
        loc = image2patches(pixel_values)
        input_views = torch.cat((loc, glb), dim=0)

        # Extract features through backbone (all views in one batched pass)
        feature = self.backbone(input_views)

        # Project features to embedding dimension
        e5 = self.output5(feature[4])  # (batch*5, 128, 16, 16)
        e4 = self.output4(feature[3])  # (batch*5, 128, 32, 32)
        e3 = self.output3(feature[2])  # (batch*5, 128, 64, 64)
        e2 = self.output2(feature[1])  # (batch*5, 128, 128, 128)
        e1 = self.output1(feature[0])  # (batch*5, 128, 128, 128)

        # Split local and global features at deepest level
        # Use config value for number of patches
        loc_e5, glb_e5 = e5.split(
            [batch_size * self.config.num_patches, batch_size], dim=0
        )

        # Apply multi-field cross attention
        e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)  # (batch*5, 128, 16, 16)

        # Decode through MCRM blocks with skip connections
        e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
        e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
        e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
        e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))

        # Split local and global features
        # Use config value for number of patches
        loc_e1, glb_e1 = e1.split(
            [batch_size * self.config.num_patches, batch_size], dim=0
        )

        # Merge local patches back to image
        output1_cat = patches2image(loc_e1)

        # Add global features
        output1_cat = output1_cat + resize_as(glb_e1, output1_cat)

        # Apply instance mask head
        final_output = self.insmask_head(output1_cat)

        # Merge shallow features, then upsample twice back toward input size
        final_output = final_output + resize_as(shallow, final_output)
        final_output = self.upsample1(rescale_to(final_output))
        final_output = rescale_to(final_output + resize_as(shallow, final_output))
        final_output = self.upsample2(final_output)

        # Final output (logits before sigmoid)
        logits = self.output(final_output)

        loss = None
        if labels is not None:
            # Compute binary cross-entropy loss with logits
            # labels should be float with values in [0, 1]
            loss_fct = nn.BCEWithLogitsLoss()
            # Ensure labels have the same shape as logits
            if labels.dim() == 3:
                # (B, H, W) -> (B, 1, H, W)
                labels = labels.unsqueeze(1)
            loss = loss_fct(logits, labels.float())

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def _init_weights(self, module):
        """
        Initialize weights.

        The backbone (SwinB) and other modules handle their own weight initialization,
        so we don't need to do anything here.
        """
        pass