Upload MALUNet CVC-ClinicDB weights
Browse files- README.md +71 -0
- best.pth +3 -0
- infer.py +159 -0
- models/__init__.py +0 -0
- models/malunet.py +317 -0
README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- image-segmentation
|
| 5 |
+
- medical-imaging
|
| 6 |
+
- polyp-segmentation
|
| 7 |
+
- pytorch
|
| 8 |
+
- malunet
|
| 9 |
+
datasets:
|
| 10 |
+
- cvc-clinicdb
|
| 11 |
+
library_name: pytorch
|
| 12 |
+
pipeline_tag: image-segmentation
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# MALUNet · CVC-ClinicDB (Polyp Segmentation)
|
| 16 |
+
|
| 17 |
+
Lightweight U-shape segmentation network adapted from
|
| 18 |
+
[jcruan519/MALUNet](https://github.com/jcruan519/MALUNet) and trained on
|
| 19 |
+
[CVC-ClinicDB](https://www.kaggle.com/datasets/balraj98/cvcclinicdb) for
|
| 20 |
+
binary polyp segmentation in colonoscopy frames.
|
| 21 |
+
|
| 22 |
+
## Model
|
| 23 |
+
|
| 24 |
+
- Architecture: MALUNet (DGA + IEA + CAB + SAB)
|
| 25 |
+
- Channels: `[8, 16, 24, 32, 48, 64]`, `split_att="fc"`, `bridge=True`
|
| 26 |
+
- Input: RGB, 256×256
|
| 27 |
+
- Output: single-channel sigmoid mask (1 = polyp)
|
| 28 |
+
- Parameters: ~0.18 M
|
| 29 |
+
|
| 30 |
+
## Training
|
| 31 |
+
|
| 32 |
+
- Dataset: CVC-ClinicDB (612 paired image/mask frames)
|
| 33 |
+
- Split: 80% train / 20% val (seeded by filename, `seed=42`)
|
| 34 |
+
- Loss: BCE + Dice
|
| 35 |
+
- Optimizer: AdamW, `lr=1e-3`, `weight_decay=1e-2`
|
| 36 |
+
- Schedule: CosineAnnealingLR, `T_max=50`, `eta_min=1e-5`
|
| 37 |
+
- Augmentations: random h/v flip, random rotation
|
| 38 |
+
- Epochs: 150
|
| 39 |
+
|
| 40 |
+
## Usage
|
| 41 |
+
|
| 42 |
+
```python
|
| 43 |
+
import torch
|
| 44 |
+
from huggingface_hub import hf_hub_download
|
| 45 |
+
from infer import load_model, predict_mask # infer.py from this repo
|
| 46 |
+
from PIL import Image
|
| 47 |
+
|
| 48 |
+
model = load_model("YOUR_USERNAME/malunet-cvc")
|
| 49 |
+
mask = predict_mask(model, Image.open("polyp.png"))
|
| 50 |
+
Image.fromarray(mask).save("mask.png")
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
`infer.py` and `models/malunet.py` are bundled in this repo so you can
|
| 54 |
+
also clone it and run inference without the original training code.
|
| 55 |
+
|
| 56 |
+
## Limitations
|
| 57 |
+
|
| 58 |
+
- Trained on CVC-ClinicDB only (612 frames, single source). Generalization
|
| 59 |
+
to other colonoscopy systems / patient populations is unverified.
|
| 60 |
+
- Not a medical device. Research / demo use only.
|
| 61 |
+
|
| 62 |
+
## Citation
|
| 63 |
+
|
| 64 |
+
```bibtex
|
| 65 |
+
@inproceedings{ruan2023malunet,
|
| 66 |
+
title={MALUNet: A multi-attention and light-weight UNet for skin lesion segmentation},
|
| 67 |
+
author={Ruan, Jiacheng and Xie, Mingye and Xiang, Suncheng and Liu, Ting and Fu, Yongtao},
|
| 68 |
+
booktitle={2022 IEEE International Conference on Bioinformatics and Biomedicine (BIBM)},
|
| 69 |
+
year={2022}
|
| 70 |
+
}
|
| 71 |
+
```
|
best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5966e588253cb8c8d4119c10a40fb4ebc60c3cf87fe4d04f4409d03fd271848a
|
| 3 |
+
size 790195
|
infer.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standalone inference helpers for MALUNet on CVC-ClinicDB.
|
| 2 |
+
|
| 3 |
+
`load_model` accepts either a local checkpoint path or an "<owner>/<repo>"
|
| 4 |
+
reference to a Hugging Face model repository (it downloads `best.pth`).
|
| 5 |
+
|
| 6 |
+
CLI:
|
| 7 |
+
python infer.py --weights ./best.pth --image polyp.png --out mask.png
|
| 8 |
+
python infer.py --weights jane-l/malunet-cvc --image polyp.png --out mask.png
|
| 9 |
+
"""
|
| 10 |
+
import argparse
|
| 11 |
+
import io
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Tuple, Union
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from models.malunet import MALUNet
|
| 21 |
+
|
| 22 |
+
DEFAULT_MODEL_CONFIG = {
|
| 23 |
+
"num_classes": 1,
|
| 24 |
+
"input_channels": 3,
|
| 25 |
+
"c_list": [8, 16, 24, 32, 48, 64],
|
| 26 |
+
"split_att": "fc",
|
| 27 |
+
"bridge": True,
|
| 28 |
+
}
|
| 29 |
+
INPUT_SIZE = 256
|
| 30 |
+
NORM_MEAN = 109.0
|
| 31 |
+
NORM_STD = 75.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _build():
|
| 35 |
+
return MALUNet(
|
| 36 |
+
num_classes=DEFAULT_MODEL_CONFIG["num_classes"],
|
| 37 |
+
input_channels=DEFAULT_MODEL_CONFIG["input_channels"],
|
| 38 |
+
c_list=DEFAULT_MODEL_CONFIG["c_list"],
|
| 39 |
+
split_att=DEFAULT_MODEL_CONFIG["split_att"],
|
| 40 |
+
bridge=DEFAULT_MODEL_CONFIG["bridge"],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _is_hf_repo_id(s: str) -> bool:
|
| 45 |
+
if os.path.exists(s):
|
| 46 |
+
return False
|
| 47 |
+
return "/" in s and not s.endswith(".pth") and not s.endswith(".pt")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _strip_module_prefix(state_dict):
|
| 51 |
+
return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_model(weights: str, device: Union[str, torch.device, None] = None) -> torch.nn.Module:
|
| 55 |
+
if device is None:
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
elif isinstance(device, str):
|
| 58 |
+
device = torch.device(device)
|
| 59 |
+
|
| 60 |
+
if _is_hf_repo_id(weights):
|
| 61 |
+
from huggingface_hub import hf_hub_download
|
| 62 |
+
|
| 63 |
+
weights = hf_hub_download(repo_id=weights, filename="best.pth")
|
| 64 |
+
|
| 65 |
+
state = torch.load(weights, map_location="cpu")
|
| 66 |
+
if isinstance(state, dict) and "model_state_dict" in state:
|
| 67 |
+
state = state["model_state_dict"]
|
| 68 |
+
state = _strip_module_prefix(state)
|
| 69 |
+
|
| 70 |
+
model = _build()
|
| 71 |
+
model.load_state_dict(state, strict=True)
|
| 72 |
+
model.to(device).eval()
|
| 73 |
+
return model
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _preprocess(img: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
| 77 |
+
"""RGB PIL image -> normalized (1,3,H,W) tensor. Returns the original (H,W)."""
|
| 78 |
+
img = img.convert("RGB")
|
| 79 |
+
orig_size = img.size[::-1] # (H, W)
|
| 80 |
+
arr = np.asarray(img, dtype=np.float32)
|
| 81 |
+
arr = (arr - NORM_MEAN) / NORM_STD
|
| 82 |
+
lo, hi = arr.min(), arr.max()
|
| 83 |
+
if hi > lo:
|
| 84 |
+
arr = (arr - lo) / (hi - lo) * 255.0
|
| 85 |
+
else:
|
| 86 |
+
arr = np.zeros_like(arr)
|
| 87 |
+
img_resized = Image.fromarray(arr.astype(np.uint8)).resize(
|
| 88 |
+
(INPUT_SIZE, INPUT_SIZE), Image.BILINEAR
|
| 89 |
+
)
|
| 90 |
+
t = torch.from_numpy(np.asarray(img_resized, dtype=np.float32)).permute(2, 0, 1).unsqueeze(0)
|
| 91 |
+
return t, orig_size
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@torch.no_grad()
|
| 95 |
+
def predict_mask(
|
| 96 |
+
model: torch.nn.Module,
|
| 97 |
+
image: Union[str, Path, Image.Image, bytes],
|
| 98 |
+
threshold: float = 0.5,
|
| 99 |
+
return_prob: bool = False,
|
| 100 |
+
) -> np.ndarray:
|
| 101 |
+
"""Returns a uint8 mask resized back to the original image resolution."""
|
| 102 |
+
if isinstance(image, (str, Path)):
|
| 103 |
+
img = Image.open(image)
|
| 104 |
+
elif isinstance(image, bytes):
|
| 105 |
+
img = Image.open(io.BytesIO(image))
|
| 106 |
+
elif isinstance(image, Image.Image):
|
| 107 |
+
img = image
|
| 108 |
+
else:
|
| 109 |
+
raise TypeError(f"unsupported image type: {type(image)}")
|
| 110 |
+
|
| 111 |
+
device = next(model.parameters()).device
|
| 112 |
+
t, (h, w) = _preprocess(img)
|
| 113 |
+
t = t.to(device).float()
|
| 114 |
+
out = model(t) # (1,1,256,256), already sigmoid
|
| 115 |
+
prob = out[0, 0].cpu().numpy()
|
| 116 |
+
prob_full = np.array(
|
| 117 |
+
Image.fromarray((prob * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR),
|
| 118 |
+
dtype=np.float32,
|
| 119 |
+
) / 255.0
|
| 120 |
+
if return_prob:
|
| 121 |
+
return prob_full
|
| 122 |
+
return (prob_full >= threshold).astype(np.uint8) * 255
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def overlay(image: Image.Image, mask: np.ndarray, alpha: float = 0.45) -> Image.Image:
|
| 126 |
+
base = image.convert("RGB")
|
| 127 |
+
bw, bh = base.size
|
| 128 |
+
if mask.shape != (bh, bw):
|
| 129 |
+
mask = np.array(Image.fromarray(mask).resize((bw, bh), Image.NEAREST))
|
| 130 |
+
color = np.zeros((bh, bw, 3), dtype=np.uint8)
|
| 131 |
+
color[..., 0] = mask # red
|
| 132 |
+
base_arr = np.asarray(base, dtype=np.float32)
|
| 133 |
+
mask_bool = mask > 0
|
| 134 |
+
blended = base_arr.copy()
|
| 135 |
+
blended[mask_bool] = (1 - alpha) * base_arr[mask_bool] + alpha * color[mask_bool]
|
| 136 |
+
return Image.fromarray(blended.astype(np.uint8))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def main():
|
| 140 |
+
ap = argparse.ArgumentParser()
|
| 141 |
+
ap.add_argument("--weights", required=True, help="Local .pth path OR <owner>/<repo> on HF")
|
| 142 |
+
ap.add_argument("--image", required=True)
|
| 143 |
+
ap.add_argument("--out", default="mask.png")
|
| 144 |
+
ap.add_argument("--overlay-out", default=None, help="optional overlay PNG path")
|
| 145 |
+
ap.add_argument("--threshold", type=float, default=0.5)
|
| 146 |
+
args = ap.parse_args()
|
| 147 |
+
|
| 148 |
+
model = load_model(args.weights)
|
| 149 |
+
img = Image.open(args.image)
|
| 150 |
+
mask = predict_mask(model, img, threshold=args.threshold)
|
| 151 |
+
Image.fromarray(mask).save(args.out)
|
| 152 |
+
print(f"wrote {args.out}")
|
| 153 |
+
if args.overlay_out:
|
| 154 |
+
overlay(img, mask).save(args.overlay_out)
|
| 155 |
+
print(f"wrote {args.overlay_out}")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
main()
|
models/__init__.py
ADDED
|
File without changes
|
models/malunet.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from timm.models.layers import trunc_normal_
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DepthWiseConv2d(nn.Module):
|
| 10 |
+
def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1, dilation=1):
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding,
|
| 14 |
+
stride=stride, dilation=dilation, groups=dim_in)
|
| 15 |
+
self.norm_layer = nn.GroupNorm(4, dim_in)
|
| 16 |
+
self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return self.conv2(self.norm_layer(self.conv1(x)))
|
| 20 |
+
|
| 21 |
+
class GatedAttentionUnit(nn.Module):
|
| 22 |
+
def __init__(self, in_c, out_c, kernel_size):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.w1 = nn.Sequential(
|
| 25 |
+
DepthWiseConv2d(in_c, in_c, kernel_size, padding=kernel_size//2),
|
| 26 |
+
nn.Sigmoid()
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.w2 = nn.Sequential(
|
| 30 |
+
DepthWiseConv2d(in_c, in_c, kernel_size + 2, padding=(kernel_size + 2)//2),
|
| 31 |
+
nn.GELU()
|
| 32 |
+
)
|
| 33 |
+
self.wo = nn.Sequential(
|
| 34 |
+
DepthWiseConv2d(in_c, out_c, kernel_size),
|
| 35 |
+
nn.GELU()
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
self.cw = nn.Conv2d(in_c, out_c, 1)
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
x1, x2 = self.w1(x), self.w2(x)
|
| 42 |
+
out = self.wo(x1 * x2) + self.cw(x)
|
| 43 |
+
return out
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DilatedGatedAttention(nn.Module):
|
| 47 |
+
def __init__(self, in_c, out_c, k_size=3, dilated_ratio=[7, 5, 2, 1]):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.mda0 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
|
| 51 |
+
padding=(k_size+(k_size-1)*(dilated_ratio[0]-1))//2,
|
| 52 |
+
dilation=dilated_ratio[0], groups=in_c//4)
|
| 53 |
+
self.mda1 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
|
| 54 |
+
padding=(k_size+(k_size-1)*(dilated_ratio[1]-1))//2,
|
| 55 |
+
dilation=dilated_ratio[1], groups=in_c//4)
|
| 56 |
+
self.mda2 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
|
| 57 |
+
padding=(k_size+(k_size-1)*(dilated_ratio[2]-1))//2,
|
| 58 |
+
dilation=dilated_ratio[2], groups=in_c//4)
|
| 59 |
+
self.mda3 = nn.Conv2d(in_c//4, in_c//4, kernel_size=k_size, stride=1,
|
| 60 |
+
padding=(k_size+(k_size-1)*(dilated_ratio[3]-1))//2,
|
| 61 |
+
dilation=dilated_ratio[3], groups=in_c//4)
|
| 62 |
+
self.norm_layer = nn.GroupNorm(4, in_c)
|
| 63 |
+
self.conv = nn.Conv2d(in_c, in_c, 1)
|
| 64 |
+
|
| 65 |
+
self.gau = GatedAttentionUnit(in_c, out_c, 3)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
x = torch.chunk(x, 4, dim=1)
|
| 69 |
+
x0 = self.mda0(x[0])
|
| 70 |
+
x1 = self.mda1(x[1])
|
| 71 |
+
x2 = self.mda2(x[2])
|
| 72 |
+
x3 = self.mda3(x[3])
|
| 73 |
+
x = F.gelu(self.conv(self.norm_layer(torch.cat((x0, x1, x2, x3), dim=1))))
|
| 74 |
+
x = self.gau(x)
|
| 75 |
+
return x
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class EAblock(nn.Module):
|
| 79 |
+
def __init__(self, in_c):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
self.conv1 = nn.Conv2d(in_c, in_c, 1)
|
| 83 |
+
|
| 84 |
+
self.k = in_c * 4
|
| 85 |
+
self.linear_0 = nn.Conv1d(in_c, self.k, 1, bias=False)
|
| 86 |
+
|
| 87 |
+
self.linear_1 = nn.Conv1d(self.k, in_c, 1, bias=False)
|
| 88 |
+
self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)
|
| 89 |
+
|
| 90 |
+
self.conv2 = nn.Conv2d(in_c, in_c, 1, bias=False)
|
| 91 |
+
self.norm_layer = nn.GroupNorm(4, in_c)
|
| 92 |
+
|
| 93 |
+
def forward(self, x):
|
| 94 |
+
idn = x
|
| 95 |
+
x = self.conv1(x)
|
| 96 |
+
|
| 97 |
+
b, c, h, w = x.size()
|
| 98 |
+
x = x.view(b, c, h*w) # b * c * n
|
| 99 |
+
|
| 100 |
+
attn = self.linear_0(x) # b, k, n
|
| 101 |
+
attn = F.softmax(attn, dim=-1) # b, k, n
|
| 102 |
+
|
| 103 |
+
attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) # # b, k, n
|
| 104 |
+
x = self.linear_1(attn) # b, c, n
|
| 105 |
+
|
| 106 |
+
x = x.view(b, c, h, w)
|
| 107 |
+
x = self.norm_layer(self.conv2(x))
|
| 108 |
+
x = x + idn
|
| 109 |
+
x = F.gelu(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Channel_Att_Bridge(nn.Module):
|
| 114 |
+
def __init__(self, c_list, split_att='fc'):
|
| 115 |
+
super().__init__()
|
| 116 |
+
c_list_sum = sum(c_list) - c_list[-1]
|
| 117 |
+
self.split_att = split_att
|
| 118 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
| 119 |
+
self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
| 120 |
+
self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
|
| 121 |
+
self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
|
| 122 |
+
self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
|
| 123 |
+
self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
|
| 124 |
+
self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
|
| 125 |
+
self.sigmoid = nn.Sigmoid()
|
| 126 |
+
|
| 127 |
+
def forward(self, t1, t2, t3, t4, t5):
|
| 128 |
+
att = torch.cat((self.avgpool(t1),
|
| 129 |
+
self.avgpool(t2),
|
| 130 |
+
self.avgpool(t3),
|
| 131 |
+
self.avgpool(t4),
|
| 132 |
+
self.avgpool(t5)), dim=1)
|
| 133 |
+
att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
|
| 134 |
+
if self.split_att != 'fc':
|
| 135 |
+
att = att.transpose(-1, -2)
|
| 136 |
+
att1 = self.sigmoid(self.att1(att))
|
| 137 |
+
att2 = self.sigmoid(self.att2(att))
|
| 138 |
+
att3 = self.sigmoid(self.att3(att))
|
| 139 |
+
att4 = self.sigmoid(self.att4(att))
|
| 140 |
+
att5 = self.sigmoid(self.att5(att))
|
| 141 |
+
if self.split_att == 'fc':
|
| 142 |
+
att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
|
| 143 |
+
att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
|
| 144 |
+
att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
|
| 145 |
+
att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
|
| 146 |
+
att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
|
| 147 |
+
else:
|
| 148 |
+
att1 = att1.unsqueeze(-1).expand_as(t1)
|
| 149 |
+
att2 = att2.unsqueeze(-1).expand_as(t2)
|
| 150 |
+
att3 = att3.unsqueeze(-1).expand_as(t3)
|
| 151 |
+
att4 = att4.unsqueeze(-1).expand_as(t4)
|
| 152 |
+
att5 = att5.unsqueeze(-1).expand_as(t5)
|
| 153 |
+
|
| 154 |
+
return att1, att2, att3, att4, att5
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class Spatial_Att_Bridge(nn.Module):
|
| 158 |
+
def __init__(self):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
|
| 161 |
+
nn.Sigmoid())
|
| 162 |
+
|
| 163 |
+
def forward(self, t1, t2, t3, t4, t5):
|
| 164 |
+
t_list = [t1, t2, t3, t4, t5]
|
| 165 |
+
att_list = []
|
| 166 |
+
for t in t_list:
|
| 167 |
+
avg_out = torch.mean(t, dim=1, keepdim=True)
|
| 168 |
+
max_out, _ = torch.max(t, dim=1, keepdim=True)
|
| 169 |
+
att = torch.cat([avg_out, max_out], dim=1)
|
| 170 |
+
att = self.shared_conv2d(att)
|
| 171 |
+
att_list.append(att)
|
| 172 |
+
return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class SC_Att_Bridge(nn.Module):
|
| 176 |
+
def __init__(self, c_list, split_att='fc'):
|
| 177 |
+
super().__init__()
|
| 178 |
+
|
| 179 |
+
self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
|
| 180 |
+
self.satt = Spatial_Att_Bridge()
|
| 181 |
+
|
| 182 |
+
def forward(self, t1, t2, t3, t4, t5):
|
| 183 |
+
r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5
|
| 184 |
+
|
| 185 |
+
satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
|
| 186 |
+
t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5
|
| 187 |
+
|
| 188 |
+
r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
|
| 189 |
+
t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5
|
| 190 |
+
|
| 191 |
+
catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
|
| 192 |
+
t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5
|
| 193 |
+
|
| 194 |
+
return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class MALUNet(nn.Module):
|
| 198 |
+
|
| 199 |
+
def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64],
|
| 200 |
+
split_att='fc', bridge=True):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
self.bridge = bridge
|
| 204 |
+
|
| 205 |
+
self.encoder1 = nn.Sequential(
|
| 206 |
+
nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
|
| 207 |
+
)
|
| 208 |
+
self.encoder2 =nn.Sequential(
|
| 209 |
+
nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
|
| 210 |
+
)
|
| 211 |
+
self.encoder3 = nn.Sequential(
|
| 212 |
+
nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
|
| 213 |
+
)
|
| 214 |
+
self.encoder4 = nn.Sequential(
|
| 215 |
+
EAblock(c_list[2]),
|
| 216 |
+
DilatedGatedAttention(c_list[2], c_list[3]),
|
| 217 |
+
)
|
| 218 |
+
self.encoder5 = nn.Sequential(
|
| 219 |
+
EAblock(c_list[3]),
|
| 220 |
+
DilatedGatedAttention(c_list[3], c_list[4]),
|
| 221 |
+
)
|
| 222 |
+
self.encoder6 = nn.Sequential(
|
| 223 |
+
EAblock(c_list[4]),
|
| 224 |
+
DilatedGatedAttention(c_list[4], c_list[5]),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if bridge:
|
| 228 |
+
self.scab = SC_Att_Bridge(c_list, split_att)
|
| 229 |
+
print('SC_Att_Bridge was used')
|
| 230 |
+
|
| 231 |
+
self.decoder1 = nn.Sequential(
|
| 232 |
+
DilatedGatedAttention(c_list[5], c_list[4]),
|
| 233 |
+
EAblock(c_list[4]),
|
| 234 |
+
)
|
| 235 |
+
self.decoder2 = nn.Sequential(
|
| 236 |
+
DilatedGatedAttention(c_list[4], c_list[3]),
|
| 237 |
+
EAblock(c_list[3]),
|
| 238 |
+
)
|
| 239 |
+
self.decoder3 = nn.Sequential(
|
| 240 |
+
DilatedGatedAttention(c_list[3], c_list[2]),
|
| 241 |
+
EAblock(c_list[2]),
|
| 242 |
+
)
|
| 243 |
+
self.decoder4 = nn.Sequential(
|
| 244 |
+
nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
|
| 245 |
+
)
|
| 246 |
+
self.decoder5 = nn.Sequential(
|
| 247 |
+
nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
|
| 248 |
+
)
|
| 249 |
+
self.ebn1 = nn.GroupNorm(4, c_list[0])
|
| 250 |
+
self.ebn2 = nn.GroupNorm(4, c_list[1])
|
| 251 |
+
self.ebn3 = nn.GroupNorm(4, c_list[2])
|
| 252 |
+
self.ebn4 = nn.GroupNorm(4, c_list[3])
|
| 253 |
+
self.ebn5 = nn.GroupNorm(4, c_list[4])
|
| 254 |
+
self.dbn1 = nn.GroupNorm(4, c_list[4])
|
| 255 |
+
self.dbn2 = nn.GroupNorm(4, c_list[3])
|
| 256 |
+
self.dbn3 = nn.GroupNorm(4, c_list[2])
|
| 257 |
+
self.dbn4 = nn.GroupNorm(4, c_list[1])
|
| 258 |
+
self.dbn5 = nn.GroupNorm(4, c_list[0])
|
| 259 |
+
|
| 260 |
+
self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
|
| 261 |
+
|
| 262 |
+
self.apply(self._init_weights)
|
| 263 |
+
|
| 264 |
+
def _init_weights(self, m):
|
| 265 |
+
if isinstance(m, nn.Linear):
|
| 266 |
+
trunc_normal_(m.weight, std=.02)
|
| 267 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 268 |
+
nn.init.constant_(m.bias, 0)
|
| 269 |
+
elif isinstance(m, nn.Conv1d):
|
| 270 |
+
n = m.kernel_size[0] * m.out_channels
|
| 271 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 272 |
+
elif isinstance(m, nn.Conv2d):
|
| 273 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 274 |
+
fan_out //= m.groups
|
| 275 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 276 |
+
if m.bias is not None:
|
| 277 |
+
m.bias.data.zero_()
|
| 278 |
+
|
| 279 |
+
def forward(self, x):
|
| 280 |
+
|
| 281 |
+
out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
|
| 282 |
+
t1 = out # b, c0, H/2, W/2
|
| 283 |
+
|
| 284 |
+
out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
|
| 285 |
+
t2 = out # b, c1, H/4, W/4
|
| 286 |
+
|
| 287 |
+
out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
|
| 288 |
+
t3 = out # b, c2, H/8, W/8
|
| 289 |
+
|
| 290 |
+
out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2))
|
| 291 |
+
t4 = out # b, c3, H/16, W/16
|
| 292 |
+
|
| 293 |
+
out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2))
|
| 294 |
+
t5 = out # b, c4, H/32, W/32
|
| 295 |
+
|
| 296 |
+
if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
|
| 297 |
+
|
| 298 |
+
out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32
|
| 299 |
+
|
| 300 |
+
out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32
|
| 301 |
+
out5 = torch.add(out5, t5) # b, c4, H/32, W/32
|
| 302 |
+
|
| 303 |
+
out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16
|
| 304 |
+
out4 = torch.add(out4, t4) # b, c3, H/16, W/16
|
| 305 |
+
|
| 306 |
+
out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8
|
| 307 |
+
out3 = torch.add(out3, t3) # b, c2, H/8, W/8
|
| 308 |
+
|
| 309 |
+
out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4
|
| 310 |
+
out2 = torch.add(out2, t2) # b, c1, H/4, W/4
|
| 311 |
+
|
| 312 |
+
out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2
|
| 313 |
+
out1 = torch.add(out1, t1) # b, c0, H/2, W/2
|
| 314 |
+
|
| 315 |
+
out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W
|
| 316 |
+
|
| 317 |
+
return torch.sigmoid(out0)
|