BiliSakura commited on
Commit
bc2f004
·
verified ·
1 Parent(s): fd2c760

Add files using upload-large-folder tool

Browse files
__pycache__/modeling_rsp.cpython-312.pyc ADDED
Binary file (9.55 kB). View file
 
__pycache__/modular_swin.cpython-312.pyc ADDED
Binary file (33.4 kB). View file
 
config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ape": false,
3
+ "architectures": [
4
+ "RSPSwinForImageClassification"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_rsp.RSPSwinConfig",
8
+ "AutoModelForImageClassification": "modeling_rsp.RSPSwinForImageClassification"
9
+ },
10
+ "depths": [
11
+ 2,
12
+ 2,
13
+ 6,
14
+ 2
15
+ ],
16
+ "embed_dim": 96,
17
+ "image_size": 224,
18
+ "mlp_ratio": 4.0,
19
+ "model_type": "rsp_swin",
20
+ "num_channels": 3,
21
+ "num_heads": [
22
+ 3,
23
+ 6,
24
+ 12,
25
+ 24
26
+ ],
27
+ "num_labels": 51,
28
+ "patch_norm": true,
29
+ "patch_size": 4,
30
+ "qkv_bias": true,
31
+ "source_checkpoint": "/mnt/data/projects/model_hubs/raw/rsp-swin-t-ckpt.pth",
32
+ "window_size": 7
33
+ }
configuration_rsp.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration classes for RSP models compatible with transformers"""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class RSPResNetConfig(PretrainedConfig):
    """Configuration for RSP ResNet models.

    Args:
        block (str): Residual block type; the model class currently supports
            only ``"Bottleneck"``. Default: ``"Bottleneck"``.
        layers (list[int] | None): Number of residual blocks per stage.
            ``None`` means the ResNet-50 layout ``[3, 4, 6, 3]``.
        image_size (int): Input image size. Default: 224.
        num_channels (int): Number of input image channels. Default: 3.
        num_labels (int): Number of classification labels. Default: 51.
    """

    model_type = "rsp_resnet"

    def __init__(
        self,
        block="Bottleneck",
        layers=None,
        image_size=224,
        num_channels=3,
        num_labels=51,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.block = block
        # Copy the caller's list (or build a fresh default) so no two config
        # instances ever share the same mutable list object.
        self.layers = list(layers) if layers is not None else [3, 4, 6, 3]
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_labels = num_labels
26
+
27
+
28
class RSPSwinConfig(PretrainedConfig):
    """Configuration for RSP Swin Transformer models.

    Args:
        image_size (int): Input image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        num_channels (int): Number of input image channels. Default: 3.
        embed_dim (int): Embedding dimension of the first stage. Default: 96.
        depths (list[int] | None): Number of blocks per stage; ``None`` means
            the Swin-T layout ``[2, 2, 6, 2]``.
        num_heads (list[int] | None): Attention heads per stage; ``None`` means
            ``[3, 6, 12, 24]``.
        window_size (int): Window size for (shifted) window attention. Default: 7.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.0.
        qkv_bias (bool): Whether to add a learnable bias to q, k, v. Default: True.
        ape (bool): Whether to use absolute position embedding. Default: False.
        patch_norm (bool): Whether to normalize after patch embedding. Default: True.
        num_labels (int): Number of classification labels. Default: 51.
    """

    model_type = "rsp_swin"

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=None,
        num_heads=None,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        ape=False,
        patch_norm=True,
        num_labels=51,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy list arguments (or build fresh defaults) so config instances
        # never share mutable default lists.
        self.depths = list(depths) if depths is not None else [2, 2, 6, 2]
        self.num_heads = list(num_heads) if num_heads is not None else [3, 6, 12, 24]
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_labels = num_labels
62
+
63
+
64
class RSPViTAEConfig(PretrainedConfig):
    """Configuration for RSP ViTAE models.

    All list arguments accept ``None`` as "use the ViTAEv2-S default"; passed
    lists are copied so no two config instances share a mutable list object.

    Args:
        image_size (int): Input image size. Default: 224.
        num_channels (int): Number of input image channels. Default: 3.
        stages (int): Number of stages. Default: 4.
        embed_dims (list[int] | None): Per-stage embedding dims. Default: [64, 64, 128, 256].
        token_dims (list[int] | None): Per-stage token dims. Default: [64, 128, 256, 512].
        downsample_ratios (list[int] | None): Per-stage downsample ratios. Default: [4, 2, 2, 2].
        NC_depth (list[int] | None): Normal-cell depth per stage. Default: [2, 2, 8, 2].
        NC_heads (list[int] | None): Normal-cell heads per stage. Default: [1, 2, 4, 8].
        RC_heads (list[int] | None): Reduction-cell heads per stage. Default: [1, 1, 2, 4].
        NC_group (list[int] | None): Normal-cell groups per stage. Default: [1, 32, 64, 128].
        RC_group (list[int] | None): Reduction-cell groups per stage. Default: [1, 16, 32, 64].
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.0.
        num_labels (int): Number of classification labels. Default: 51.
    """

    model_type = "rsp_vitae"

    def __init__(
        self,
        image_size=224,
        num_channels=3,
        stages=4,
        embed_dims=None,
        token_dims=None,
        downsample_ratios=None,
        NC_depth=None,
        NC_heads=None,
        RC_heads=None,
        NC_group=None,
        RC_group=None,
        mlp_ratio=4.0,
        num_labels=51,
        **kwargs
    ):
        super().__init__(**kwargs)

        def _own_list(value, default):
            # Defensive copy / default factory: avoids the shared-mutable-default bug.
            return list(value) if value is not None else default

        self.image_size = image_size
        self.num_channels = num_channels
        self.stages = stages
        self.embed_dims = _own_list(embed_dims, [64, 64, 128, 256])
        self.token_dims = _own_list(token_dims, [64, 128, 256, 512])
        self.downsample_ratios = _own_list(downsample_ratios, [4, 2, 2, 2])
        self.NC_depth = _own_list(NC_depth, [2, 2, 8, 2])
        self.NC_heads = _own_list(NC_heads, [1, 2, 4, 8])
        self.RC_heads = _own_list(RC_heads, [1, 1, 2, 4])
        self.NC_group = _own_list(NC_group, [1, 32, 64, 128])
        self.RC_group = _own_list(RC_group, [1, 16, 32, 64])
        self.mlp_ratio = mlp_ratio
        self.num_labels = num_labels
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:175ce1246be89814fbd96872eb438367fd2099e01816fab1305d8563fdedf17c
3
+ size 111367532
modeling_rsp.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model classes for RSP models compatible with transformers"""
2
+
3
+ import sys
4
+ import os
5
+ from pathlib import Path
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import PreTrainedModel
9
+ from safetensors.torch import load_file
10
+
11
+ # Import local modular model
12
+ from modular_swin import SwinTransformer
13
+
14
+ # Import other models from sibling directories if needed
15
+ _parent_dir = Path(__file__).parent.parent
16
+ import importlib.util
17
+
18
+ # Import ResNet from RSP-ResNet-50
19
+ _resnet_path = _parent_dir / "RSP-ResNet-50" / "modular_resnet.py"
20
+ if _resnet_path.exists():
21
+ spec = importlib.util.spec_from_file_location("modular_resnet_resnet", _resnet_path)
22
+ resnet_module = importlib.util.module_from_spec(spec)
23
+ spec.loader.exec_module(resnet_module)
24
+ ResNet = resnet_module.ResNet
25
+ Bottleneck = resnet_module.Bottleneck
26
+ else:
27
+ ResNet = None
28
+ Bottleneck = None
29
+
30
+ # Import ViTAE from RSP-ViTAEv2-S
31
+ _vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py"
32
+ if _vitae_path.exists():
33
+ spec = importlib.util.spec_from_file_location("modular_vitae_window_noshift_vitae", _vitae_path)
34
+ vitae_module = importlib.util.module_from_spec(spec)
35
+ spec.loader.exec_module(vitae_module)
36
+ ViTAE_Window_NoShift_12_basic_stages4_14 = vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14
37
+ else:
38
+ ViTAE_Window_NoShift_12_basic_stages4_14 = None
39
+
40
+ # Import configuration - handle both relative and absolute imports
41
+ try:
42
+ from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
43
+ except ImportError:
44
+ # Fallback: import from same directory
45
+ import importlib.util
46
+ config_path = Path(__file__).parent / "configuration_rsp.py"
47
+ spec = importlib.util.spec_from_file_location("configuration_rsp", config_path)
48
+ config_module = importlib.util.module_from_spec(spec)
49
+ spec.loader.exec_module(config_module)
50
+ RSPResNetConfig = config_module.RSPResNetConfig
51
+ RSPSwinConfig = config_module.RSPSwinConfig
52
+ RSPViTAEConfig = config_module.RSPViTAEConfig
53
+
54
+
55
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    Thin ``PreTrainedModel`` wrapper around the project-local ResNet backbone
    so the checkpoint can be loaded through the ``transformers`` auto classes.
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        """Build the ResNet backbone (with classification head) from ``config``."""
        super().__init__(config)

        # The backbone is loaded from a sibling directory at import time; fail
        # with a clear message instead of "NoneType is not callable".
        if ResNet is None or Bottleneck is None:
            raise ImportError(
                "ResNet backbone implementation is unavailable "
                "(modular_resnet.py was not found next to this checkpoint)"
            )
        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")

        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run a forward pass.

        Args:
            pixel_values: Input images of shape (B, C, H, W).
            labels: Optional class labels for loss computation.

        Returns:
            dict with ``"logits"`` and, when labels are given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): unlike the stock transformers implementation this only
        supports a local directory containing ``model.safetensors`` (no hub
        download) — confirm that is intended for this repo.
        """
        import logging

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPResNetConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        state_dict = load_file(str(safetensors_path))
        # Checkpoints may store keys under a top-level "model." prefix; strip it.
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
        # strict=False tolerates naming drift, but surface what was skipped
        # instead of silently loading a partial model.
        result = model.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Partially loaded ResNet weights: %d missing, %d unexpected keys",
                len(result.missing_keys), len(result.unexpected_keys),
            )

        return model
122
+
123
+
124
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    Thin ``PreTrainedModel`` wrapper around the local ``SwinTransformer``
    implementation so the checkpoint can be loaded through the auto classes.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        """Build the Swin backbone (with classification head) from ``config``."""
        super().__init__(config)

        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run a forward pass.

        Args:
            pixel_values: Input images of shape (B, C, H, W).
            labels: Optional class labels for loss computation.

        Returns:
            dict with ``"logits"`` and, when labels are given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): unlike the stock transformers implementation this only
        supports a local directory containing ``model.safetensors`` (no hub
        download) — confirm that is intended for this repo.
        """
        import logging

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPSwinConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        state_dict = load_file(str(safetensors_path))
        # Checkpoints may store keys under a top-level "model." prefix; strip it.
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
        # strict=False tolerates naming drift, but surface what was skipped
        # instead of silently loading a partial model.
        result = model.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Partially loaded Swin weights: %d missing, %d unexpected keys",
                len(result.missing_keys), len(result.unexpected_keys),
            )

        return model
196
+
197
+
198
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    Thin ``PreTrainedModel`` wrapper around the project-local ViTAEv2-S
    backbone so the checkpoint can be loaded through the auto classes.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        """Build the ViTAE backbone (with classification head) from ``config``."""
        super().__init__(config)

        # The backbone is loaded from a sibling directory at import time; fail
        # with a clear message instead of "NoneType is not callable".
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(
                "ViTAE backbone implementation is unavailable "
                "(modular_vitae_window_noshift.py was not found next to this checkpoint)"
            )

        # ViTAE_Window_NoShift_12_basic_stages4_14 already sets most structural
        # parameters as defaults (stages=4, embed_dims=[64, 64, 128, 256],
        # token_dims=[64, 128, 256, 512], downsample_ratios=[4, 2, 2, 2], ...),
        # so only the overridable ones are passed through. window_size is read
        # from the config when present, falling back to the previous fixed 7.
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=getattr(config, "window_size", 7),
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run a forward pass.

        Args:
            pixel_values: Input images of shape (B, C, H, W).
            labels: Optional class labels for loss computation.

        Returns:
            dict with ``"logits"`` and, when labels are given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): unlike the stock transformers implementation this only
        supports a local directory containing ``model.safetensors`` (no hub
        download) — confirm that is intended for this repo.
        """
        import logging

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPViTAEConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        state_dict = load_file(str(safetensors_path))
        # Checkpoints may store keys under a top-level "model." prefix; strip it.
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
        # strict=False tolerates naming drift, but surface what was skipped
        # instead of silently loading a partial model.
        result = model.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Partially loaded ViTAE weights: %d missing, %d unexpected keys",
                len(result.missing_keys), len(result.unexpected_keys),
            )

        return model
modular_swin.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+
12
+ # Use transformers equivalents instead of timm
13
+ from transformers.models.swin.modeling_swin import SwinDropPath as DropPath
14
+ from transformers import initialization as init
15
+
16
+ # Simple to_2tuple replacement (no external dependency needed)
17
def to_2tuple(x):
    """Convert input to a 2-tuple.

    Tuples pass through unchanged, lists are converted to tuples (the original
    version wrapped a list in a (list, list) pair, which breaks downstream
    integer indexing), and any other value is duplicated into (x, x).
    """
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x, x)
20
+
21
+ # Use transformers trunc_normal_ for initialization
22
+ trunc_normal_ = init.trunc_normal_
23
+
24
+
25
+ class Mlp(nn.Module):
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    # Tile the feature map into non-overlapping square windows.
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    tiled = x.view(B, rows, window_size, cols, window_size, C)
    # Bring both per-window spatial dims together, then flatten batch x grid.
    windows = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)
58
+
59
+
60
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # Invert window_partition: stitch the windows back into a 2D feature map.
    rows, cols = H // window_size, W // window_size
    batch = windows.shape[0] // (rows * cols)
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)
76
+
77
+
78
+ class WindowAttention(nn.Module):
79
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
80
+ It supports both of shifted and non-shifted window.
81
+
82
+ Args:
83
+ dim (int): Number of input channels.
84
+ window_size (tuple[int]): The height and width of the window.
85
+ num_heads (int): Number of attention heads.
86
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
87
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
88
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
89
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
90
+ """
91
+
92
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
93
+
94
+ super().__init__()
95
+ self.dim = dim
96
+ self.window_size = window_size # Wh, Ww
97
+ self.num_heads = num_heads
98
+ head_dim = dim // num_heads
99
+ self.scale = qk_scale or head_dim ** -0.5
100
+
101
+ # define a parameter table of relative position bias
102
+ # 相对位置表, 对于每个head, 大小为(2wh-1)*(2ww-1)
103
+ # 大于索引的最大值(2wh-2)*(2ww-1), 可以保证寻址
104
+ self.relative_position_bias_table = nn.Parameter(
105
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
106
+
107
+ # get pair-wise relative position index for each token inside the window
108
+ coords_h = torch.arange(self.window_size[0])
109
+ coords_w = torch.arange(self.window_size[1])
110
+ # coods: (2, windwows_size, windows_size)
111
+ # 在一个windows size的窗口内,生成每个位置的行列号
112
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
113
+ # 窗口内每个位置的行列号展平
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ # 窗口内每个位置的的行列号与窗口内所有位置的行列号的差值
116
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
117
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
118
+ # 差值范围从[1-ws, ws-1]->[0, 2ws-2]
119
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
120
+ relative_coords[:, :, 1] += self.window_size[1] - 1
121
+ # 行号差值范围[0, 2wh-2] -> [0, (2wh-2)(ww-1)]
122
+ # 列号差值范围仍然为[0, 2ww-2]
123
+ # 乘法操作是为了区分沿主对���线对称的像素,此类像素在将行列号转换成一维偏移时值相同
124
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
125
+ # 获得相对位置索引,尺寸为(ws*ws, ws*ws)
126
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
127
+ self.register_buffer("relative_position_index", relative_position_index)
128
+
129
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+
134
+ init.trunc_normal_(self.relative_position_bias_table, std=0.02)
135
+ self.softmax = nn.Softmax(dim=-1)
136
+
137
+ def forward(self, x, mask=None):
138
+ """
139
+ Args:
140
+ x: input features with shape of (num_windows*B, N, C)
141
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
142
+ """
143
+ B_, N, C = x.shape
144
+ # N是一个窗口内的token数
145
+ # qkv: B, N, 3, H, C/H->3, B, H, N, C/H
146
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
147
+ # q/k/v: B, H, N, C', 其中C' = C/H
148
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
149
+
150
+ q = q * self.scale
151
+ attn = (q @ k.transpose(-2, -1)) # B,H,N,N
152
+
153
+ # 索引表: (2wh-1)*(2ww-1), H
154
+ # 相对索引:wh*ww, wh*ww, 索引范围[0,(2wh-2)*(2ww-1)]
155
+ # 取出相对位置向量 (wh*ww, wh*ww, H)
156
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
157
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
158
+ # 相对位置向量尺寸变换 (wh*ww, wh*ww, H) -> (H, wh*ww, wh*ww)
159
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
160
+ attn = attn + relative_position_bias.unsqueeze(0)
161
+
162
+ if mask is not None:
163
+ nW = mask.shape[0]
164
+ # nw是窗口的数量
165
+ # 加了-100之后。softmax生成的权值很小,相当于这部分被忽略掉
166
+ # mask: nW, ws*ws, ws*ws -> 1, nW, 1, ws*ws, ws*ws
167
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
168
+ attn = attn.view(-1, self.num_heads, N, N)
169
+ attn = self.softmax(attn)
170
+ else:
171
+ attn = self.softmax(attn)
172
+
173
+ attn = self.attn_drop(attn)
174
+
175
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
176
+ x = self.proj(x)
177
+ x = self.proj_drop(x)
178
+ return x
179
+
180
+ def extra_repr(self) -> str:
181
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
182
+
183
+ def flops(self, N):
184
+ # calculate flops for 1 window with token length of N
185
+ flops = 0
186
+ # qkv = self.qkv(x)
187
+ flops += N * self.dim * 3 * self.dim
188
+ # attn = (q @ k.transpose(-2, -1))
189
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
190
+ # x = (attn @ v)
191
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
192
+ # x = self.proj(x)
193
+ flops += N * self.dim * self.dim
194
+ return flops
195
+
196
+
197
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size  # cyclic-shift distance for SW-MSA
        self.mlp_ratio = mlp_ratio
        # If the feature map is no larger than one window, there is nothing to
        # partition: disable the shift and clamp the window to the feature size.
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # stochastic depth: randomly drop the residual branch per sample
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # Label the image with 9 regions (3x3 slice grid), simulating the
            # layout after the cyclic shift: shifted windows can contain pixels
            # that were not adjacent in the original image, so such pairs must
            # be masked out of attention.
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # Build one mask per window:
            # (1, H, W, 1) -> (nW, ws, ws, 1) where nW is the number of windows
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            # (nW, ws, ws, 1) -> (nW, ws*ws), one flat mask per window
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairwise label difference inside each window, (nW, ws*ws, ws*ws):
            # any nonzero entry marks a pair crossing a region boundary, which
            # must not attend to each other.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # With this mask, shifted-window attention reuses the same number of
            # windows as plain window attention while remaining equivalent:
            # pairs with equal region labels attend normally (mask 0), pairs
            # with different labels get -100 added so their softmax weight ~0.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            # SW-MSA does not move the windows; it cyclically rolls the image,
            # which lets different regions interact under the same partitioning.
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        # (B, H, W, C) -> (B*nh*nw, ws, ws, C)
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        # (B*nh*nw, ws, ws, C) -> (B*nh*nw, ws*ws, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA: self-attention computed independently inside each window
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        # (B*nh*nw, ws*ws, C) -> (B*nh*nw, ws, ws, C)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # (B*nh*nw, ws, ws, C) -> (B, nh*ws, nw*ws, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # back to the (B, H*W, C) token layout
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
345
+
346
+
347
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Merges each 2x2 neighborhood of tokens into one token: channels go from C
    to 4C by concatenation, then a bias-free linear layer reduces 4C -> 2C
    (as in the original Swin paper).

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # Gather the four pixels of every 2x2 patch (pixel-unshuffle style) and
        # stack them along the channel axis. Offset order matches the original:
        # (even row, even col), (odd, even), (even, odd), (odd, odd).
        quadrants = [grid[:, dh::2, dw::2, :] for dh, dw in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(quadrants, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        norm_flops = H * W * self.dim
        reduction_flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return norm_flops + reduction_flops
396
+
397
+
398
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # even-index blocks use W-MSA, odd-index blocks use SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 # per-block stochastic-depth rate when a schedule list is given
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        # Run every block, optionally under activation checkpointing to trade
        # compute for memory, then downsample if this stage has a merging layer.
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
467
+
468
+
469
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Splits an image into non-overlapping patches and linearly projects each
    patch, implemented as a single strided convolution.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid                 # spatial resolution after embedding
        self.num_patches = grid[0] * grid[1]           # token count = product of the grid dims
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # A convolution whose kernel and stride both equal the patch size
        # performs the patch embedding directly.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        out = self.proj(x)                             # B, embed_dim, Ph, Pw
        out = out.flatten(2).transpose(1, 2)           # B Ph*Pw C
        if self.norm is not None:
            out = self.norm(out)
        return out

    def flops(self):
        """FLOPs of the projection conv plus the optional normalization."""
        Ho, Wo = self.patches_resolution
        per_patch = self.in_chans * (self.patch_size[0] * self.patch_size[1])
        total = Ho * Wo * self.embed_dim * per_patch
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
516
+
517
+
518
class SwinTransformer(nn.Module):
    r""" Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # Channel dimension doubles after every stage, so the final feature
        # width is embed_dim * 2^(num_layers - 1).
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding (optional, off by default)
        if self.ape:
            # 1, Ph*Pw, C
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            init.trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth:
        # linearly spaced per-block drop-path rates from 0 to drop_path_rate,
        # one entry per block across all stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers (stages); stage i halves the spatial resolution i times
        # and widens channels to embed_dim * 2^i
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               # per-block drop-path rates differ within a stage;
                               # later blocks get higher rates
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # downsampling via patch merging is applied after
                               # every stage except the last one
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # classification head; identity when num_classes <= 0 (feature extractor mode)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers; unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            init.constant_(m.bias, 0)
            init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names to exclude from weight decay.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # Parameter-name substrings to exclude from weight decay.
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Backbone forward pass: (B, C, H, W) image -> (B, num_features) pooled feature."""
        x = self.patch_embed(x)  # B Ph*Pw C
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        """Full forward pass: backbone features followed by the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        """Rough FLOPs estimate: patch embed + all stages + final norm + head."""
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        # final LayerNorm over the last stage's token grid
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops