abdulbasitdev commited on
Commit
21598fa
·
verified ·
1 Parent(s): 1756571

Upload MALUNet CVC-ClinicDB weights

Browse files
Files changed (5) hide show
  1. README.md +71 -0
  2. best.pth +3 -0
  3. infer.py +159 -0
  4. models/__init__.py +0 -0
  5. models/malunet.py +317 -0
README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - image-segmentation
5
+ - medical-imaging
6
+ - polyp-segmentation
7
+ - pytorch
8
+ - malunet
9
+ datasets:
10
+ - cvc-clinicdb
11
+ library_name: pytorch
12
+ pipeline_tag: image-segmentation
13
+ ---
14
+
15
+ # MALUNet · CVC-ClinicDB (Polyp Segmentation)
16
+
17
+ Lightweight U-shape segmentation network adapted from
18
+ [jcruan519/MALUNet](https://github.com/jcruan519/MALUNet) and trained on
19
+ [CVC-ClinicDB](https://www.kaggle.com/datasets/balraj98/cvcclinicdb) for
20
+ binary polyp segmentation in colonoscopy frames.
21
+
22
+ ## Model
23
+
24
+ - Architecture: MALUNet (DGA + IEA + CAB + SAB)
25
+ - Channels: `[8, 16, 24, 32, 48, 64]`, `split_att="fc"`, `bridge=True`
26
+ - Input: RGB, 256×256
27
+ - Output: single-channel sigmoid mask (1 = polyp)
28
+ - Parameters: ~0.18 M
29
+
30
+ ## Training
31
+
32
+ - Dataset: CVC-ClinicDB (612 paired image/mask frames)
33
+ - Split: 80% train / 20% val (seeded by filename, `seed=42`)
34
+ - Loss: BCE + Dice
35
+ - Optimizer: AdamW, `lr=1e-3`, `weight_decay=1e-2`
36
+ - Schedule: CosineAnnealingLR, `T_max=50`, `eta_min=1e-5`
37
+ - Augmentations: random h/v flip, random rotation
38
+ - Epochs: 150
39
+
40
+ ## Usage
41
+
42
+ ```python
43
+ import torch
44
+ from huggingface_hub import hf_hub_download
45
+ from infer import load_model, predict_mask # infer.py from this repo
46
+ from PIL import Image
47
+
48
+ model = load_model("YOUR_USERNAME/malunet-cvc")
49
+ mask = predict_mask(model, Image.open("polyp.png"))
50
+ Image.fromarray(mask).save("mask.png")
51
+ ```
52
+
53
+ `infer.py` and `models/malunet.py` are bundled in this repo so you can
54
+ also clone it and run inference without the original training code.
55
+
56
+ ## Limitations
57
+
58
+ - Trained on CVC-ClinicDB only (612 frames, single source). Generalization
59
+ to other colonoscopy systems / patient populations is unverified.
60
+ - Not a medical device. Research / demo use only.
61
+
62
+ ## Citation
63
+
64
+ ```bibtex
65
+ @inproceedings{ruan2023malunet,
66
+ title={MALUNet: A multi-attention and light-weight UNet for skin lesion segmentation},
67
+ author={Ruan, Jiacheng and Xie, Mingye and Xiang, Suncheng and Liu, Ting and Fu, Yongtao},
68
+ booktitle={2022 IEEE International Conference on Bioinformatics and Biomedicine (BIBM)},
69
+ year={2022}
70
+ }
71
+ ```
best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5966e588253cb8c8d4119c10a40fb4ebc60c3cf87fe4d04f4409d03fd271848a
3
+ size 790195
infer.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone inference helpers for MALUNet on CVC-ClinicDB.
2
+
3
+ `load_model` accepts either a local checkpoint path or an "<owner>/<repo>"
4
+ reference to a Hugging Face model repository (it downloads `best.pth`).
5
+
6
+ CLI:
7
+ python infer.py --weights ./best.pth --image polyp.png --out mask.png
8
+ python infer.py --weights jane-l/malunet-cvc --image polyp.png --out mask.png
9
+ """
10
+ import argparse
11
+ import io
12
+ import os
13
+ from pathlib import Path
14
+ from typing import Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ from PIL import Image
19
+
20
+ from models.malunet import MALUNet
21
+
22
+ DEFAULT_MODEL_CONFIG = {
23
+ "num_classes": 1,
24
+ "input_channels": 3,
25
+ "c_list": [8, 16, 24, 32, 48, 64],
26
+ "split_att": "fc",
27
+ "bridge": True,
28
+ }
29
+ INPUT_SIZE = 256
30
+ NORM_MEAN = 109.0
31
+ NORM_STD = 75.0
32
+
33
+
34
+ def _build():
35
+ return MALUNet(
36
+ num_classes=DEFAULT_MODEL_CONFIG["num_classes"],
37
+ input_channels=DEFAULT_MODEL_CONFIG["input_channels"],
38
+ c_list=DEFAULT_MODEL_CONFIG["c_list"],
39
+ split_att=DEFAULT_MODEL_CONFIG["split_att"],
40
+ bridge=DEFAULT_MODEL_CONFIG["bridge"],
41
+ )
42
+
43
+
44
+ def _is_hf_repo_id(s: str) -> bool:
45
+ if os.path.exists(s):
46
+ return False
47
+ return "/" in s and not s.endswith(".pth") and not s.endswith(".pt")
48
+
49
+
50
+ def _strip_module_prefix(state_dict):
51
+ return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}
52
+
53
+
54
+ def load_model(weights: str, device: Union[str, torch.device, None] = None) -> torch.nn.Module:
55
+ if device is None:
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ elif isinstance(device, str):
58
+ device = torch.device(device)
59
+
60
+ if _is_hf_repo_id(weights):
61
+ from huggingface_hub import hf_hub_download
62
+
63
+ weights = hf_hub_download(repo_id=weights, filename="best.pth")
64
+
65
+ state = torch.load(weights, map_location="cpu")
66
+ if isinstance(state, dict) and "model_state_dict" in state:
67
+ state = state["model_state_dict"]
68
+ state = _strip_module_prefix(state)
69
+
70
+ model = _build()
71
+ model.load_state_dict(state, strict=True)
72
+ model.to(device).eval()
73
+ return model
74
+
75
+
76
+ def _preprocess(img: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int]]:
77
+ """RGB PIL image -> normalized (1,3,H,W) tensor. Returns the original (H,W)."""
78
+ img = img.convert("RGB")
79
+ orig_size = img.size[::-1] # (H, W)
80
+ arr = np.asarray(img, dtype=np.float32)
81
+ arr = (arr - NORM_MEAN) / NORM_STD
82
+ lo, hi = arr.min(), arr.max()
83
+ if hi > lo:
84
+ arr = (arr - lo) / (hi - lo) * 255.0
85
+ else:
86
+ arr = np.zeros_like(arr)
87
+ img_resized = Image.fromarray(arr.astype(np.uint8)).resize(
88
+ (INPUT_SIZE, INPUT_SIZE), Image.BILINEAR
89
+ )
90
+ t = torch.from_numpy(np.asarray(img_resized, dtype=np.float32)).permute(2, 0, 1).unsqueeze(0)
91
+ return t, orig_size
92
+
93
+
94
+ @torch.no_grad()
95
+ def predict_mask(
96
+ model: torch.nn.Module,
97
+ image: Union[str, Path, Image.Image, bytes],
98
+ threshold: float = 0.5,
99
+ return_prob: bool = False,
100
+ ) -> np.ndarray:
101
+ """Returns a uint8 mask resized back to the original image resolution."""
102
+ if isinstance(image, (str, Path)):
103
+ img = Image.open(image)
104
+ elif isinstance(image, bytes):
105
+ img = Image.open(io.BytesIO(image))
106
+ elif isinstance(image, Image.Image):
107
+ img = image
108
+ else:
109
+ raise TypeError(f"unsupported image type: {type(image)}")
110
+
111
+ device = next(model.parameters()).device
112
+ t, (h, w) = _preprocess(img)
113
+ t = t.to(device).float()
114
+ out = model(t) # (1,1,256,256), already sigmoid
115
+ prob = out[0, 0].cpu().numpy()
116
+ prob_full = np.array(
117
+ Image.fromarray((prob * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR),
118
+ dtype=np.float32,
119
+ ) / 255.0
120
+ if return_prob:
121
+ return prob_full
122
+ return (prob_full >= threshold).astype(np.uint8) * 255
123
+
124
+
125
+ def overlay(image: Image.Image, mask: np.ndarray, alpha: float = 0.45) -> Image.Image:
126
+ base = image.convert("RGB")
127
+ bw, bh = base.size
128
+ if mask.shape != (bh, bw):
129
+ mask = np.array(Image.fromarray(mask).resize((bw, bh), Image.NEAREST))
130
+ color = np.zeros((bh, bw, 3), dtype=np.uint8)
131
+ color[..., 0] = mask # red
132
+ base_arr = np.asarray(base, dtype=np.float32)
133
+ mask_bool = mask > 0
134
+ blended = base_arr.copy()
135
+ blended[mask_bool] = (1 - alpha) * base_arr[mask_bool] + alpha * color[mask_bool]
136
+ return Image.fromarray(blended.astype(np.uint8))
137
+
138
+
139
+ def main():
140
+ ap = argparse.ArgumentParser()
141
+ ap.add_argument("--weights", required=True, help="Local .pth path OR <owner>/<repo> on HF")
142
+ ap.add_argument("--image", required=True)
143
+ ap.add_argument("--out", default="mask.png")
144
+ ap.add_argument("--overlay-out", default=None, help="optional overlay PNG path")
145
+ ap.add_argument("--threshold", type=float, default=0.5)
146
+ args = ap.parse_args()
147
+
148
+ model = load_model(args.weights)
149
+ img = Image.open(args.image)
150
+ mask = predict_mask(model, img, threshold=args.threshold)
151
+ Image.fromarray(mask).save(args.out)
152
+ print(f"wrote {args.out}")
153
+ if args.overlay_out:
154
+ overlay(img, mask).save(args.overlay_out)
155
+ print(f"wrote {args.overlay_out}")
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
models/__init__.py ADDED
File without changes
models/malunet.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from timm.models.layers import trunc_normal_
6
+ import math
7
+
8
+
9
+ class DepthWiseConv2d(nn.Module):
10
+ def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1, dilation=1):
11
+ super().__init__()
12
+
13
+ self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding,
14
+ stride=stride, dilation=dilation, groups=dim_in)
15
+ self.norm_layer = nn.GroupNorm(4, dim_in)
16
+ self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1)
17
+
18
+ def forward(self, x):
19
+ return self.conv2(self.norm_layer(self.conv1(x)))
20
+
21
+ class GatedAttentionUnit(nn.Module):
22
+ def __init__(self, in_c, out_c, kernel_size):
23
+ super().__init__()
24
+ self.w1 = nn.Sequential(
25
+ DepthWiseConv2d(in_c, in_c, kernel_size, padding=kernel_size//2),
26
+ nn.Sigmoid()
27
+ )
28
+
29
+ self.w2 = nn.Sequential(
30
+ DepthWiseConv2d(in_c, in_c, kernel_size + 2, padding=(kernel_size + 2)//2),
31
+ nn.GELU()
32
+ )
33
+ self.wo = nn.Sequential(
34
+ DepthWiseConv2d(in_c, out_c, kernel_size),
35
+ nn.GELU()
36
+ )
37
+
38
+ self.cw = nn.Conv2d(in_c, out_c, 1)
39
+
40
+ def forward(self, x):
41
+ x1, x2 = self.w1(x), self.w2(x)
42
+ out = self.wo(x1 * x2) + self.cw(x)
43
+ return out
44
+
45
+
46
+ class DilatedGatedAttention(nn.Module):
47
+ def __init__(self, in_c, out_c, k_size=3, dilated_ratio=[7, 5, 2, 1]):
48
+ super().__init__()
49
+
50
+ self.mda0 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
51
+ padding=(k_size+(k_size-1)*(dilated_ratio[0]-1))//2,
52
+ dilation=dilated_ratio[0], groups=in_c//4)
53
+ self.mda1 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
54
+ padding=(k_size+(k_size-1)*(dilated_ratio[1]-1))//2,
55
+ dilation=dilated_ratio[1], groups=in_c//4)
56
+ self.mda2 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
57
+ padding=(k_size+(k_size-1)*(dilated_ratio[2]-1))//2,
58
+ dilation=dilated_ratio[2], groups=in_c//4)
59
+ self.mda3 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
60
+ padding=(k_size+(k_size-1)*(dilated_ratio[3]-1))//2,
61
+ dilation=dilated_ratio[3], groups=in_c//4)
62
+ self.norm_layer = nn.GroupNorm(4, in_c)
63
+ self.conv = nn.Conv2d(in_c, in_c, 1)
64
+
65
+ self.gau = GatedAttentionUnit(in_c, out_c, 3)
66
+
67
+ def forward(self, x):
68
+ x = torch.chunk(x, 4, dim=1)
69
+ x0 = self.mda0(x[0])
70
+ x1 = self.mda1(x[1])
71
+ x2 = self.mda2(x[2])
72
+ x3 = self.mda3(x[3])
73
+ x = F.gelu(self.conv(self.norm_layer(torch.cat((x0, x1, x2, x3), dim=1))))
74
+ x = self.gau(x)
75
+ return x
76
+
77
+
78
+ class EAblock(nn.Module):
79
+ def __init__(self, in_c):
80
+ super().__init__()
81
+
82
+ self.conv1 = nn.Conv2d(in_c, in_c, 1)
83
+
84
+ self.k = in_c * 4
85
+ self.linear_0 = nn.Conv1d(in_c, self.k, 1, bias=False)
86
+
87
+ self.linear_1 = nn.Conv1d(self.k, in_c, 1, bias=False)
88
+ self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)
89
+
90
+ self.conv2 = nn.Conv2d(in_c, in_c, 1, bias=False)
91
+ self.norm_layer = nn.GroupNorm(4, in_c)
92
+
93
+ def forward(self, x):
94
+ idn = x
95
+ x = self.conv1(x)
96
+
97
+ b, c, h, w = x.size()
98
+ x = x.view(b, c, h*w) # b * c * n
99
+
100
+ attn = self.linear_0(x) # b, k, n
101
+ attn = F.softmax(attn, dim=-1) # b, k, n
102
+
103
+ attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) # # b, k, n
104
+ x = self.linear_1(attn) # b, c, n
105
+
106
+ x = x.view(b, c, h, w)
107
+ x = self.norm_layer(self.conv2(x))
108
+ x = x + idn
109
+ x = F.gelu(x)
110
+ return x
111
+
112
+
113
+ class Channel_Att_Bridge(nn.Module):
114
+ def __init__(self, c_list, split_att='fc'):
115
+ super().__init__()
116
+ c_list_sum = sum(c_list) - c_list[-1]
117
+ self.split_att = split_att
118
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
119
+ self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
120
+ self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
121
+ self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
122
+ self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
123
+ self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
124
+ self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
125
+ self.sigmoid = nn.Sigmoid()
126
+
127
+ def forward(self, t1, t2, t3, t4, t5):
128
+ att = torch.cat((self.avgpool(t1),
129
+ self.avgpool(t2),
130
+ self.avgpool(t3),
131
+ self.avgpool(t4),
132
+ self.avgpool(t5)), dim=1)
133
+ att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
134
+ if self.split_att != 'fc':
135
+ att = att.transpose(-1, -2)
136
+ att1 = self.sigmoid(self.att1(att))
137
+ att2 = self.sigmoid(self.att2(att))
138
+ att3 = self.sigmoid(self.att3(att))
139
+ att4 = self.sigmoid(self.att4(att))
140
+ att5 = self.sigmoid(self.att5(att))
141
+ if self.split_att == 'fc':
142
+ att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
143
+ att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
144
+ att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
145
+ att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
146
+ att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
147
+ else:
148
+ att1 = att1.unsqueeze(-1).expand_as(t1)
149
+ att2 = att2.unsqueeze(-1).expand_as(t2)
150
+ att3 = att3.unsqueeze(-1).expand_as(t3)
151
+ att4 = att4.unsqueeze(-1).expand_as(t4)
152
+ att5 = att5.unsqueeze(-1).expand_as(t5)
153
+
154
+ return att1, att2, att3, att4, att5
155
+
156
+
157
+ class Spatial_Att_Bridge(nn.Module):
158
+ def __init__(self):
159
+ super().__init__()
160
+ self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
161
+ nn.Sigmoid())
162
+
163
+ def forward(self, t1, t2, t3, t4, t5):
164
+ t_list = [t1, t2, t3, t4, t5]
165
+ att_list = []
166
+ for t in t_list:
167
+ avg_out = torch.mean(t, dim=1, keepdim=True)
168
+ max_out, _ = torch.max(t, dim=1, keepdim=True)
169
+ att = torch.cat([avg_out, max_out], dim=1)
170
+ att = self.shared_conv2d(att)
171
+ att_list.append(att)
172
+ return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]
173
+
174
+
175
+ class SC_Att_Bridge(nn.Module):
176
+ def __init__(self, c_list, split_att='fc'):
177
+ super().__init__()
178
+
179
+ self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
180
+ self.satt = Spatial_Att_Bridge()
181
+
182
+ def forward(self, t1, t2, t3, t4, t5):
183
+ r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5
184
+
185
+ satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
186
+ t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5
187
+
188
+ r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
189
+ t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5
190
+
191
+ catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
192
+ t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5
193
+
194
+ return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_
195
+
196
+
197
+ class MALUNet(nn.Module):
198
+
199
+ def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64],
200
+ split_att='fc', bridge=True):
201
+ super().__init__()
202
+
203
+ self.bridge = bridge
204
+
205
+ self.encoder1 = nn.Sequential(
206
+ nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
207
+ )
208
+ self.encoder2 =nn.Sequential(
209
+ nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
210
+ )
211
+ self.encoder3 = nn.Sequential(
212
+ nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
213
+ )
214
+ self.encoder4 = nn.Sequential(
215
+ EAblock(c_list[2]),
216
+ DilatedGatedAttention(c_list[2], c_list[3]),
217
+ )
218
+ self.encoder5 = nn.Sequential(
219
+ EAblock(c_list[3]),
220
+ DilatedGatedAttention(c_list[3], c_list[4]),
221
+ )
222
+ self.encoder6 = nn.Sequential(
223
+ EAblock(c_list[4]),
224
+ DilatedGatedAttention(c_list[4], c_list[5]),
225
+ )
226
+
227
+ if bridge:
228
+ self.scab = SC_Att_Bridge(c_list, split_att)
229
+ print('SC_Att_Bridge was used')
230
+
231
+ self.decoder1 = nn.Sequential(
232
+ DilatedGatedAttention(c_list[5], c_list[4]),
233
+ EAblock(c_list[4]),
234
+ )
235
+ self.decoder2 = nn.Sequential(
236
+ DilatedGatedAttention(c_list[4], c_list[3]),
237
+ EAblock(c_list[3]),
238
+ )
239
+ self.decoder3 = nn.Sequential(
240
+ DilatedGatedAttention(c_list[3], c_list[2]),
241
+ EAblock(c_list[2]),
242
+ )
243
+ self.decoder4 = nn.Sequential(
244
+ nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
245
+ )
246
+ self.decoder5 = nn.Sequential(
247
+ nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
248
+ )
249
+ self.ebn1 = nn.GroupNorm(4, c_list[0])
250
+ self.ebn2 = nn.GroupNorm(4, c_list[1])
251
+ self.ebn3 = nn.GroupNorm(4, c_list[2])
252
+ self.ebn4 = nn.GroupNorm(4, c_list[3])
253
+ self.ebn5 = nn.GroupNorm(4, c_list[4])
254
+ self.dbn1 = nn.GroupNorm(4, c_list[4])
255
+ self.dbn2 = nn.GroupNorm(4, c_list[3])
256
+ self.dbn3 = nn.GroupNorm(4, c_list[2])
257
+ self.dbn4 = nn.GroupNorm(4, c_list[1])
258
+ self.dbn5 = nn.GroupNorm(4, c_list[0])
259
+
260
+ self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
261
+
262
+ self.apply(self._init_weights)
263
+
264
+ def _init_weights(self, m):
265
+ if isinstance(m, nn.Linear):
266
+ trunc_normal_(m.weight, std=.02)
267
+ if isinstance(m, nn.Linear) and m.bias is not None:
268
+ nn.init.constant_(m.bias, 0)
269
+ elif isinstance(m, nn.Conv1d):
270
+ n = m.kernel_size[0] * m.out_channels
271
+ m.weight.data.normal_(0, math.sqrt(2. / n))
272
+ elif isinstance(m, nn.Conv2d):
273
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
274
+ fan_out //= m.groups
275
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
276
+ if m.bias is not None:
277
+ m.bias.data.zero_()
278
+
279
+ def forward(self, x):
280
+
281
+ out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
282
+ t1 = out # b, c0, H/2, W/2
283
+
284
+ out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
285
+ t2 = out # b, c1, H/4, W/4
286
+
287
+ out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
288
+ t3 = out # b, c2, H/8, W/8
289
+
290
+ out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2))
291
+ t4 = out # b, c3, H/16, W/16
292
+
293
+ out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2))
294
+ t5 = out # b, c4, H/32, W/32
295
+
296
+ if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
297
+
298
+ out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32
299
+
300
+ out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32
301
+ out5 = torch.add(out5, t5) # b, c4, H/32, W/32
302
+
303
+ out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16
304
+ out4 = torch.add(out4, t4) # b, c3, H/16, W/16
305
+
306
+ out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8
307
+ out3 = torch.add(out3, t3) # b, c2, H/8, W/8
308
+
309
+ out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4
310
+ out2 = torch.add(out2, t2) # b, c1, H/4, W/4
311
+
312
+ out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2
313
+ out1 = torch.add(out1, t1) # b, c0, H/2, W/2
314
+
315
+ out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W
316
+
317
+ return torch.sigmoid(out0)