Kim Mạnh Hưng committed on
Commit
aa04f76
·
1 Parent(s): 0e003ab

Add U-Net app and weights

Browse files
Files changed (41) hide show
  1. .gitignore +3 -0
  2. app.py +118 -0
  3. configs/isic/isic2018_attunet.yaml +50 -0
  4. configs/isic/isic2018_missformer.yaml +52 -0
  5. configs/isic/isic2018_multiresunet.yaml +51 -0
  6. configs/isic/isic2018_resunet.yaml +50 -0
  7. configs/isic/isic2018_transunet.yaml +52 -0
  8. configs/isic/isic2018_uctransnet.yaml +50 -0
  9. configs/isic/isic2018_unet.yaml +51 -0
  10. configs/isic/isic2018_unetpp.yaml +51 -0
  11. configs/segpc/segpc2021_attunet.yaml +47 -0
  12. configs/segpc/segpc2021_missformer.yaml +49 -0
  13. configs/segpc/segpc2021_multiresunet.yaml +53 -0
  14. configs/segpc/segpc2021_resunet.yaml +47 -0
  15. configs/segpc/segpc2021_transunet.yaml +52 -0
  16. configs/segpc/segpc2021_uctransnet.yaml +47 -0
  17. configs/segpc/segpc2021_unet.yaml +48 -0
  18. configs/segpc/segpc2021_unetpp.yaml +48 -0
  19. models/__init__.py +0 -0
  20. models/_missformer/MISSFormer.py +398 -0
  21. models/_missformer/__init__.py +0 -0
  22. models/_missformer/segformer.py +557 -0
  23. models/_resunet/__init__.py +0 -0
  24. models/_resunet/modules.py +143 -0
  25. models/_resunet/res_unet.py +65 -0
  26. models/_transunet/vit_seg_configs.py +130 -0
  27. models/_transunet/vit_seg_modeling.py +453 -0
  28. models/_transunet/vit_seg_modeling_c4.py +453 -0
  29. models/_transunet/vit_seg_modeling_resnet_skip.py +160 -0
  30. models/_transunet/vit_seg_modeling_resnet_skip_c4.py +160 -0
  31. models/_uctransnet/CTrans.py +365 -0
  32. models/_uctransnet/Config.py +72 -0
  33. models/_uctransnet/UCTransNet.py +139 -0
  34. models/_uctransnet/UNet.py +111 -0
  35. models/attunet.py +427 -0
  36. models/multiresunet.py +190 -0
  37. models/unet.py +64 -0
  38. models/unetpp.py +141 -0
  39. requirements.txt +6 -0
  40. saved_models/isic2018_unet/best_model_state_dict.pt +3 -0
  41. saved_models/segpc2021_unet/best_model_state_dict.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import yaml
6
+ import os
7
+ from models.unet import UNet
8
+
9
+ # Configuration
10
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+ # Map dataset names to config/model paths
12
+ CONFIG_PATHS = {
13
+ 'isic': './configs/isic/isic2018_unet.yaml',
14
+ 'segpc': './configs/segpc/segpc2021_unet.yaml'
15
+ }
16
+ MODEL_PATHS = {
17
+ 'isic': './saved_models/isic2018_unet/best_model_state_dict.pt',
18
+ 'segpc': './saved_models/segpc2021_unet/best_model_state_dict.pt'
19
+ }
20
+
21
+ def load_config(config_path):
22
+ with open(config_path, 'r') as f:
23
+ return yaml.safe_load(f)
24
+
25
+ def load_model(dataset_name):
26
+ config = load_config(CONFIG_PATHS[dataset_name])
27
+ model = UNet(
28
+ in_channels=config['model']['in_channels'],
29
+ out_channels=config['model']['out_channels']
30
+ )
31
+ model_path = MODEL_PATHS[dataset_name]
32
+ if os.path.exists(model_path):
33
+ state_dict = torch.load(model_path, map_location=DEVICE)
34
+ model.load_state_dict(state_dict)
35
+ print(f"Loaded model for {dataset_name} from {model_path}")
36
+ else:
37
+ print(f"Warning: Model weights not found for {dataset_name} at {model_path}")
38
+
39
+ model.to(DEVICE)
40
+ model.eval()
41
+ return model
42
+
43
+ # Load models once (cache them)
44
+ models = {}
45
+ for ds in ['isic', 'segpc']:
46
+ try:
47
+ models[ds] = load_model(ds)
48
+ except Exception as e:
49
+ print(f"Error loading model {ds}: {e}")
50
+
51
+ def predict(image, dataset_choice):
52
+ if image is None:
53
+ return None
54
+
55
+ if dataset_choice not in models:
56
+ return None
57
+
58
+ model = models[dataset_choice]
59
+
60
+ # Preprocess
61
+ # Resize to 224x224 as per config
62
+ img_resized = image.resize((224, 224))
63
+ img_np = np.array(img_resized).astype(np.float32) / 255.0
64
+
65
+ # Handle channels
66
+ if dataset_choice == 'isic':
67
+ # ISIC: 3 channels (RGB)
68
+ if img_np.shape[-1] == 4:
69
+ img_np = img_np[:, :, :3]
70
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()
71
+ else:
72
+ # SegPC: 4 channels (BMP input often loaded as RGB, need to assume/check)
73
+ if img_np.shape[-1] == 3:
74
+ # Create fake 4th channel
75
+ padding = np.zeros((224, 224, 1), dtype=np.float32)
76
+ img_np = np.concatenate([img_np, padding], axis=-1)
77
+
78
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()
79
+
80
+ img_tensor = img_tensor.to(DEVICE)
81
+
82
+ with torch.no_grad():
83
+ output = model(img_tensor)
84
+ probs = torch.sigmoid(output)
85
+ pred_mask = (probs > 0.5).float().cpu().numpy()[0, 0]
86
+
87
+ # Post-process for visualization
88
+ # Create an overlay
89
+ base_img = np.array(img_resized)
90
+ overlay = base_img.copy()
91
+
92
+ # Green mask
93
+ mask_bool = pred_mask > 0
94
+ overlay[mask_bool] = [0, 255, 0] # Make Green
95
+
96
+ # Blend
97
+ final_img = (0.6 * base_img + 0.4 * overlay).astype(np.uint8)
98
+
99
+ return final_img
100
+
101
+ # Interface
102
+ iface = gr.Interface(
103
+ fn=predict,
104
+ inputs=[
105
+ gr.Image(type="pil", label="Input Image"),
106
+ gr.Radio(["isic", "segpc"], label="Dataset Model", value="isic")
107
+ ],
108
+ outputs=gr.Image(type="numpy", label="Prediction Overlay"),
109
+ title="Medical Image Segmentation (Awesome-U-Net)",
110
+ description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
111
+ examples=[
112
+ # Add example paths if available
113
+ # ["dataset_examples/isic_sample.jpg", "isic"]
114
+ ]
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ iface.launch()
configs/isic/isic2018_attunet.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0001
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 100
43
+ model:
44
+ save_dir: '../../saved_models/isic2018_attunet'
45
+ load_weights: false
46
+ name: 'AttU_Net'
47
+ params:
48
+ img_ch: 3
49
+ output_ch: 2
50
+ # preprocess:
configs/isic/isic2018_missformer.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'SGD'
34
+ params:
35
+ lr: 0.0001
36
+ momentum: 0.9
37
+ weight_decay: 0.0001
38
+ criterion:
39
+ name: "DiceLoss"
40
+ params: {}
41
+ scheduler:
42
+ factor: 0.5
43
+ patience: 10
44
+ epochs: 300
45
+ model:
46
+ save_dir: '../../saved_models/isic2018_missformer'
47
+ load_weights: false
48
+ name: "MISSFormer"
49
+ params:
50
+ in_ch: 3
51
+ num_classes: 2
52
+ # preprocess:
configs/isic/isic2018_multiresunet.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 2
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 2
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 2
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0005
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 100
43
+ model:
44
+ save_dir: '../../saved_models/isic2018_multiresunet'
45
+ load_weights: false
46
+ name: 'MultiResUnet'
47
+ params:
48
+ channels: 3
49
+ filters: 32
50
+ nclasses: 2
51
+ # preprocess:
configs/isic/isic2018_resunet.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0001
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 100
43
+ model:
44
+ save_dir: '../../saved_models/isic2018_resunet'
45
+ load_weights: false
46
+ name: 'ResUnet'
47
+ params:
48
+ in_ch: 3
49
+ out_ch: 2
50
+ # preprocess:
configs/isic/isic2018_transunet.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'SGD'
34
+ params:
35
+ lr: 0.0001
36
+ momentum: 0.9
37
+ weight_decay: 0.0001
38
+ criterion:
39
+ name: "DiceLoss"
40
+ params: {}
41
+ scheduler:
42
+ factor: 0.5
43
+ patience: 10
44
+ epochs: 100
45
+ model:
46
+ save_dir: '../../saved_models/isic2018_transunet'
47
+ load_weights: false
48
+ name: 'VisionTransformer'
49
+ params:
50
+ img_size: 224
51
+ num_classes: 2
52
+ # preprocess:
configs/isic/isic2018_uctransnet.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0001
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 100
43
+ model:
44
+ save_dir: '../../saved_models/isic2018_uctransnet'
45
+ load_weights: false
46
+ name: "UCTransNet"
47
+ params:
48
+ n_channels: 3
49
+ n_classes: 2
50
+ # preprocess:
configs/isic/isic2018_unet.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "./datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "./datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 0
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 0
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 0
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0001
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 2
43
+ model:
44
+ save_dir: './saved_models/isic2018_unet'
45
+ load_weights: false
46
+ name: 'UNet'
47
+ params:
48
+ in_channels: 3
49
+ out_channels: 2
50
+ with_bn: false
51
+ # preprocess:
configs/isic/isic2018_unetpp.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "ISIC2018Dataset"
7
+ input_size: 224
8
+ training:
9
+ params:
10
+ data_dir: "/path/to/datasets/ISIC2018"
11
+ validation:
12
+ params:
13
+ data_dir: "/path/to/datasets/ISIC2018"
14
+ number_classes: 2
15
+ data_loader:
16
+ train:
17
+ batch_size: 16
18
+ shuffle: true
19
+ num_workers: 8
20
+ pin_memory: true
21
+ validation:
22
+ batch_size: 16
23
+ shuffle: false
24
+ num_workers: 8
25
+ pin_memory: true
26
+ test:
27
+ batch_size: 16
28
+ shuffle: false
29
+ num_workers: 4
30
+ pin_memory: false
31
+ training:
32
+ optimizer:
33
+ name: 'Adam'
34
+ params:
35
+ lr: 0.0001
36
+ criterion:
37
+ name: "DiceLoss"
38
+ params: {}
39
+ scheduler:
40
+ factor: 0.5
41
+ patience: 10
42
+ epochs: 100
43
+ model:
44
+ save_dir: '../../saved_models/isic2018_unetpp'
45
+ load_weights: false
46
+ name: 'NestedUNet'
47
+ params:
48
+ num_classes: 2
49
+ input_channels: 3
50
+ deep_supervision: false
51
+ # preprocess:
configs/segpc/segpc2021_attunet.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ criterion:
34
+ name: "DiceLoss"
35
+ params: {}
36
+ scheduler:
37
+ factor: 0.5
38
+ patience: 10
39
+ epochs: 100
40
+ model:
41
+ save_dir: '../../saved_models/segpc2021_attunet'
42
+ load_weights: false
43
+ name: 'AttU_Net'
44
+ params:
45
+ img_ch: 4
46
+ output_ch: 2
47
+ # preprocess:
configs/segpc/segpc2021_missformer.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'SGD'
31
+ params:
32
+ lr: 0.0001
33
+ momentum: 0.9
34
+ weight_decay: 0.0001
35
+ criterion:
36
+ name: "DiceLoss"
37
+ params: {}
38
+ scheduler:
39
+ factor: 0.5
40
+ patience: 10
41
+ epochs: 500
42
+ model:
43
+ save_dir: '../../saved_models/segpc2021_missformer'
44
+ load_weights: false
45
+ name: 'MISSFormer'
46
+ params:
47
+ in_ch: 4
48
+ num_classes: 2
49
+ # preprocess:
configs/segpc/segpc2021_multiresunet.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ # name: "SGD"
34
+ # params:
35
+ # lr: 0.0001
36
+ # momentum: 0.9
37
+ # weight_decay: 0.0001
38
+ criterion:
39
+ name: "DiceLoss"
40
+ params: {}
41
+ scheduler:
42
+ factor: 0.5
43
+ patience: 10
44
+ epochs: 100
45
+ model:
46
+ save_dir: '../../saved_models/segpc2021_multiresunet'
47
+ load_weights: false
48
+ name: 'MultiResUnet'
49
+ params:
50
+ channels: 4
51
+ filters: 32
52
+ nclasses: 2
53
+ # preprocess:
configs/segpc/segpc2021_resunet.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ criterion:
34
+ name: "DiceLoss"
35
+ params: {}
36
+ scheduler:
37
+ factor: 0.5
38
+ patience: 10
39
+ epochs: 100
40
+ model:
41
+ save_dir: '../../saved_models/segpc2021_resunet'
42
+ load_weights: false
43
+ name: 'ResUnet'
44
+ params:
45
+ in_ch: 4
46
+ out_ch: 2
47
+ # preprocess:
configs/segpc/segpc2021_transunet.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ # name: 'Adam'
31
+ # params:
32
+ # lr: 0.0001
33
+ name: "SGD"
34
+ params:
35
+ lr: 0.0001
36
+ momentum: 0.9
37
+ weight_decay: 0.0001
38
+ criterion:
39
+ name: "DiceLoss"
40
+ params: {}
41
+ scheduler:
42
+ factor: 0.5
43
+ patience: 10
44
+ epochs: 100
45
+ model:
46
+ save_dir: '../../saved_models/segpc2021_transunet'
47
+ load_weights: false
48
+ name: 'VisionTransformer'
49
+ params:
50
+ img_size: 224
51
+ num_classes: 2
52
+ # preprocess:
configs/segpc/segpc2021_uctransnet.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ criterion:
34
+ name: "DiceLoss"
35
+ params: {}
36
+ scheduler:
37
+ factor: 0.5
38
+ patience: 10
39
+ epochs: 100
40
+ model:
41
+ save_dir: '../../saved_models/segpc2021_uctransnet'
42
+ load_weights: false
43
+ name: 'UCTransNet'
44
+ params:
45
+ n_channels: 4
46
+ n_classes: 2
47
+ # preprocess:
configs/segpc/segpc2021_unet.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "./datasets/SegPC2021/np"
10
+ dataset_dir: "./datasets/SegPC2021/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 8
15
+ shuffle: true
16
+ num_workers: 0
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 8
20
+ shuffle: false
21
+ num_workers: 0
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 8
25
+ shuffle: false
26
+ num_workers: 0
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ criterion:
34
+ name: "DiceLoss"
35
+ params: {}
36
+ scheduler:
37
+ factor: 0.5
38
+ patience: 10
39
+ epochs: 2
40
+ model:
41
+ save_dir: './saved_models/segpc2021_unet'
42
+ load_weights: false
43
+ name: 'UNet'
44
+ params:
45
+ in_channels: 4
46
+ out_channels: 2
47
+ with_bn: false
48
+ # preprocess:
configs/segpc/segpc2021_unetpp.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run:
2
+ mode: 'train'
3
+ device: 'gpu'
4
+ transforms: none
5
+ dataset:
6
+ class_name: "SegPC2021Dataset"
7
+ input_size: 224
8
+ scale: 2.5
9
+ data_dir: "/path/to/datasets/segpc/np"
10
+ dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11
+ number_classes: 2
12
+ data_loader:
13
+ train:
14
+ batch_size: 16
15
+ shuffle: true
16
+ num_workers: 4
17
+ pin_memory: true
18
+ validation:
19
+ batch_size: 16
20
+ shuffle: false
21
+ num_workers: 4
22
+ pin_memory: true
23
+ test:
24
+ batch_size: 16
25
+ shuffle: false
26
+ num_workers: 4
27
+ pin_memory: false
28
+ training:
29
+ optimizer:
30
+ name: 'Adam'
31
+ params:
32
+ lr: 0.0001
33
+ criterion:
34
+ name: "DiceLoss"
35
+ params: {}
36
+ scheduler:
37
+ factor: 0.5
38
+ patience: 10
39
+ epochs: 100
40
+ model:
41
+ save_dir: '../../saved_models/segpc2021_unetpp'
42
+ load_weights: false
43
+ name: 'NestedUNet'
44
+ params:
45
+ num_classes: 2
46
+ input_channels: 4
47
+ deep_supervision: false
48
+ # preprocess:
models/__init__.py ADDED
File without changes
models/_missformer/MISSFormer.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .segformer import *
4
+ from typing import Tuple
5
+ from einops import rearrange
6
+
7
+ class PatchExpand(nn.Module):
8
+ def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
9
+ super().__init__()
10
+ self.input_resolution = input_resolution
11
+ self.dim = dim
12
+ self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
13
+ self.norm = norm_layer(dim // dim_scale)
14
+
15
+ def forward(self, x):
16
+ """
17
+ x: B, H*W, C
18
+ """
19
+ # print("x_shape-----",x.shape)
20
+ H, W = self.input_resolution
21
+ x = self.expand(x)
22
+
23
+ B, L, C = x.shape
24
+ # print(x.shape)
25
+ assert L == H * W, "input feature has wrong size"
26
+
27
+ x = x.view(B, H, W, C)
28
+ x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
29
+ x = x.view(B,-1,C//4)
30
+ x= self.norm(x.clone())
31
+
32
+ return x
33
+
34
+ class FinalPatchExpand_X4(nn.Module):
35
+ def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
36
+ super().__init__()
37
+ self.input_resolution = input_resolution
38
+ self.dim = dim
39
+ self.dim_scale = dim_scale
40
+ self.expand = nn.Linear(dim, 16*dim, bias=False)
41
+ self.output_dim = dim
42
+ self.norm = norm_layer(self.output_dim)
43
+
44
+ def forward(self, x):
45
+ """
46
+ x: B, H*W, C
47
+ """
48
+ H, W = self.input_resolution
49
+ x = self.expand(x)
50
+ B, L, C = x.shape
51
+ assert L == H * W, "input feature has wrong size"
52
+
53
+ x = x.view(B, H, W, C)
54
+ x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
55
+ x = x.view(B,-1,self.output_dim)
56
+ x= self.norm(x.clone())
57
+
58
+ return x
59
+
60
+
61
+ class SegU_decoder(nn.Module):
62
+ def __init__(self, input_size, in_out_chan, heads, reduction_ratios, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
63
+ super().__init__()
64
+ dims = in_out_chan[0]
65
+ out_dim = in_out_chan[1]
66
+ if not is_last:
67
+ self.concat_linear = nn.Linear(dims*2, out_dim)
68
+ # transformer decoder
69
+ self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
70
+ self.last_layer = None
71
+ else:
72
+ self.concat_linear = nn.Linear(dims*4, out_dim)
73
+ # transformer decoder
74
+ self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
75
+ # self.last_layer = nn.Linear(out_dim, n_class)
76
+ self.last_layer = nn.Conv2d(out_dim, n_class,1)
77
+ # self.last_layer = None
78
+
79
+ self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios)
80
+ self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios)
81
+
82
+
83
+ def init_weights(self):
84
+ for m in self.modules():
85
+ if isinstance(m, nn.Linear):
86
+ nn.init.xavier_uniform_(m.weight)
87
+ if m.bias is not None:
88
+ nn.init.zeros_(m.bias)
89
+ elif isinstance(m, nn.LayerNorm):
90
+ nn.init.ones_(m.weight)
91
+ nn.init.zeros_(m.bias)
92
+ elif isinstance(m, nn.Conv2d):
93
+ nn.init.xavier_uniform_(m.weight)
94
+ if m.bias is not None:
95
+ nn.init.zeros_(m.bias)
96
+
97
+ init_weights(self)
98
+
99
+
100
+
101
+ def forward(self, x1, x2=None):
102
+ if x2 is not None:
103
+ b, h, w, c = x2.shape
104
+ x2 = x2.view(b, -1, c)
105
+ # print("------",x1.shape, x2.shape)
106
+ cat_x = torch.cat([x1, x2], dim=-1)
107
+ # print("-----catx shape", cat_x.shape)
108
+ cat_linear_x = self.concat_linear(cat_x)
109
+ tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
110
+ tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)
111
+
112
+ if self.last_layer:
113
+ out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
114
+ else:
115
+ out = self.layer_up(tran_layer_2)
116
+ else:
117
+ # if len(x1.shape)>3:
118
+ # x1 = x1.permute(0,2,3,1)
119
+ # b, h, w, c = x1.shape
120
+ # x1 = x1.view(b, -1, c)
121
+ out = self.layer_up(x1)
122
+ return out
123
+
124
+
125
class BridgeLayer_4(nn.Module):
    """Bridge (bottleneck) mixing layer over all four encoder stages.

    Flattens the four stage feature maps into one long token sequence,
    applies shared efficient self-attention across every scale at once,
    then a per-scale MixFFN, with residual connections around both parts.
    """

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        # One MixFFN per stage; widths follow the stage channel
        # multipliers (1x, 2x, 5x, 8x of dims), each expanded 4x inside.
        self.mixffn1 = MixFFN_skip(dims,dims*4)
        self.mixffn2 = MixFFN_skip(dims*2,dims*8)
        self.mixffn3 = MixFFN_skip(dims*5,dims*20)
        self.mixffn4 = MixFFN_skip(dims*8,dims*32)


    def forward(self, inputs):
        # inputs is either a list of four stage maps [c1..c4] with shapes
        # (B, C_i, H_i, W_i), or an already-flattened (B, L, C) tensor
        # produced by a previous bridge layer.
        B = inputs[0].shape[0]
        C = 64
        if (type(inputs) == list):
            # print("-----1-----")
            c1, c2, c3, c4 = inputs
            B, C, _, _= c1.shape
            # Flatten each map to (B, H_i*W_i, C_base); the token counts in
            # the trailing comments assume a 224x224 input (56/28/14/7 maps).
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64

            # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
            inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
        else:
            B,_,C = inputs.shape

        # Pre-norm attention with a residual connection.
        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)


        # Split the fused sequence back into per-stage segments; the token
        # offsets (3136/4704/5684/6076) are hard-coded for 224x224 inputs.
        tem1 = tx[:,:3136,:].reshape(B, -1, C)
        tem2 = tx[:,3136:4704,:].reshape(B, -1, C*2)
        tem3 = tx[:,4704:5684,:].reshape(B, -1, C*5)
        tem4 = tx[:,5684:6076,:].reshape(B, -1, C*8)

        # Per-stage MixFFN at each stage's own spatial resolution, then
        # re-flatten to the shared (B, -1, C) token layout.
        m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
        m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
        m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
        m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)

        t1 = torch.cat([m1f, m2f, m3f, m4f], -2)

        # Residual connection around the FFN branch.
        tx2 = tx1 + t1


        return tx2
175
+
176
+
177
class BridgeLayer_3(nn.Module):
    """Bridge mixing layer over the last three encoder stages (c2..c4).

    Same structure as BridgeLayer_4 but stage 1 is excluded from the
    fused token sequence; uses plain MixFFN instead of MixFFN_skip.
    """

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        # self.mixffn1 = MixFFN(dims,dims*4)
        # One MixFFN per used stage (2x, 5x, 8x of dims).
        self.mixffn2 = MixFFN(dims*2,dims*8)
        self.mixffn3 = MixFFN(dims*5,dims*20)
        self.mixffn4 = MixFFN(dims*8,dims*32)


    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # inputs: list of four stage maps (B, C_i, H_i, W_i) or an
        # already-flattened (B, L, C) tensor from a previous bridge layer.
        B = inputs[0].shape[0]
        C = 64
        if (type(inputs) == list):
            # print("-----1-----")
            c1, c2, c3, c4 = inputs
            B, C, _, _= c1.shape
            # NOTE(review): c1f is computed but never used below — only
            # c2f..c4f enter the fused sequence. Kept as in the original.
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64

            # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
            inputs = torch.cat([c2f, c3f, c4f], -2)
        else:
            B,_,C = inputs.shape

        # Pre-norm attention with a residual connection.
        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)


        # Split back into per-stage segments; offsets (1568/2548/2940)
        # are hard-coded for 224x224 inputs without the stage-1 tokens.
        # tem1 = tx[:,:3136,:].reshape(B, -1, C)
        tem2 = tx[:,:1568,:].reshape(B, -1, C*2)
        tem3 = tx[:,1568:2548,:].reshape(B, -1, C*5)
        tem4 = tx[:,2548:2940,:].reshape(B, -1, C*8)

        # Per-stage MixFFN at each stage's own spatial resolution.
        # m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
        m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
        m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
        m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)

        t1 = torch.cat([m2f, m3f, m4f], -2)

        # Residual connection around the FFN branch.
        tx2 = tx1 + t1


        return tx2
227
+
228
+
229
+
230
class BridegeBlock_4(nn.Module):
    """Stack of four 4-scale bridge layers; reshapes the fused token
    sequence back into the four NCHW skip-connection feature maps."""
    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios)

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        bridge1 = self.bridge_layer1(x)
        bridge2 = self.bridge_layer2(bridge1)
        bridge3 = self.bridge_layer3(bridge2)
        bridge4 = self.bridge_layer4(bridge3)

        B,_,C = bridge4.shape
        outs = []

        # Unflatten the 6076-token sequence into the four pyramid levels
        # (56x56xC, 28x28x2C, 14x14x5C, 7x7x8C) expected by the decoder.
        sk1 = bridge4[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2)
        sk2 = bridge4[:,3136:4704,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
        sk3 = bridge4[:,4704:5684,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
        sk4 = bridge4[:,5684:6076,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)

        outs.append(sk1)
        outs.append(sk2)
        outs.append(sk3)
        outs.append(sk4)

        return outs
258
+
259
+
260
class BridegeBlock_3(nn.Module):
    """Stack of four 3-scale bridge layers (the finest scale is not
    bridged).  When called with the encoder's list of feature maps, the
    untouched finest map x[0] is passed through as the first output."""
    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        self.bridge_layer1 = BridgeLayer_3(dims, head, reduction_ratios)
        self.bridge_layer2 = BridgeLayer_3(dims, head, reduction_ratios)
        self.bridge_layer3 = BridgeLayer_3(dims, head, reduction_ratios)
        self.bridge_layer4 = BridgeLayer_3(dims, head, reduction_ratios)

    def forward(self, x: "torch.Tensor | list[torch.Tensor]") -> "list[torch.Tensor]":
        outs = []
        if (type(x) == list):
            # Forward the finest-resolution skip unchanged.
            outs.append(x[0])
        bridge1 = self.bridge_layer1(x)
        bridge2 = self.bridge_layer2(bridge1)
        bridge3 = self.bridge_layer3(bridge2)
        bridge4 = self.bridge_layer4(bridge3)

        B,_,C = bridge4.shape

        # Unflatten the 2940-token sequence into the three coarser pyramid
        # levels (28x28x2C, 14x14x5C, 7x7x8C).
        sk2 = bridge4[:,:1568,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
        sk3 = bridge4[:,1568:2548,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
        sk4 = bridge4[:,2548:2940,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)

        outs.append(sk2)
        outs.append(sk3)
        outs.append(sk4)

        return outs
292
+
293
+
294
class MyDecoderLayer(nn.Module):
    """One MISSFormer decoder stage: fuse the incoming tokens with a skip
    connection, refine with two transformer blocks, then upsample
    (PatchExpand x2, or FinalPatchExpand x4 plus a 1x1 class head for the
    last stage)."""
    def __init__(self, input_size, in_out_chan, heads, reduction_ratios,token_mlp_mode, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
        super().__init__()
        dims = in_out_chan[0]
        out_dim = in_out_chan[1]
        if not is_last:
            self.concat_linear = nn.Linear(dims*2, out_dim)
            # transformer decoder
            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None
        else:
            self.concat_linear = nn.Linear(dims*4, out_dim)
            # transformer decoder
            self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
            # Final stage gets a 1x1 conv class head.
            self.last_layer = nn.Conv2d(out_dim, n_class,1)

        self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
        self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)

        # Local helper, applied once at construction time: Xavier init for
        # linear/conv weights, unit weight / zero bias for LayerNorm.
        def init_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

        init_weights(self)

    def forward(self, x1, x2=None):
        if x2 is not None:
            # x2 is the skip feature map in (B, H, W, C); flatten to tokens
            # and fuse with the previous decoder output x1.
            b, h, w, c = x2.shape
            x2 = x2.view(b, -1, c)
            cat_x = torch.cat([x1, x2], dim=-1)
            cat_linear_x = self.concat_linear(cat_x)
            tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
            tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)

            if self.last_layer:
                # Final stage: 4x expand, back to NCHW, then the class head.
                out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
            else:
                out = self.layer_up(tran_layer_2)
        else:
            # Deepest stage has no skip: just expand.
            out = self.layer_up(x1)
        return out
354
+
355
class MISSFormer(nn.Module):
    """MISSFormer segmentation network: a MiT encoder, a multi-scale
    attention bridge over the four skip connections, and four transformer
    decoder stages ending in a `num_classes`-channel prediction map."""

    def __init__(self, num_classes=9, in_ch=3, token_mlp_mode="mix_skip", encoder_pretrained=True):
        super().__init__()
        # NOTE: `encoder_pretrained` is accepted for API compatibility but is
        # not consumed in this constructor.

        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]
        feat = 7  # base decoder feature size: 7 for a 224 input, 16 for 512
        in_out_chan = [[32, 64], [144, 128], [288, 320], [512, 512]]

        dims = [64, 128, 320, 512]
        layers = [2, 2, 2, 2]
        self.backbone = MiT(224, dims, layers, in_ch, token_mlp_mode)

        self.reduction_ratios = [1, 2, 4, 8]
        self.bridge = BridegeBlock_4(64, 1, self.reduction_ratios)

        self.decoder_3 = MyDecoderLayer((feat, feat), in_out_chan[3], heads[3],
                                        reduction_ratios[3], token_mlp_mode, n_class=num_classes)
        self.decoder_2 = MyDecoderLayer((feat * 2, feat * 2), in_out_chan[2], heads[2],
                                        reduction_ratios[2], token_mlp_mode, n_class=num_classes)
        self.decoder_1 = MyDecoderLayer((feat * 4, feat * 4), in_out_chan[1], heads[1],
                                        reduction_ratios[1], token_mlp_mode, n_class=num_classes)
        self.decoder_0 = MyDecoderLayer((feat * 8, feat * 8), in_out_chan[0], heads[0],
                                        reduction_ratios[0], token_mlp_mode, n_class=num_classes,
                                        is_last=True)

    def forward(self, x):
        # Grayscale inputs are broadcast to three channels for the backbone.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        # Encoder followed by the bridge; `skips` is a list of four maps.
        skips = self.bridge(self.backbone(x))

        b, c, _, _ = skips[3].shape
        # Decode coarse-to-fine, fusing one skip per stage.
        d3 = self.decoder_3(skips[3].permute(0, 2, 3, 1).view(b, -1, c))
        d2 = self.decoder_2(d3, skips[2].permute(0, 2, 3, 1))
        d1 = self.decoder_1(d2, skips[1].permute(0, 2, 3, 1))
        return self.decoder_0(d1, skips[0].permute(0, 2, 3, 1))
397
+
398
+
models/_missformer/__init__.py ADDED
File without changes
models/_missformer/segformer.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from typing import Tuple
5
+
6
+
7
class EfficientSelfAtten(nn.Module):
    """SegFormer-style multi-head self-attention with optional spatial
    reduction: when ``reduction_ratio > 1`` the key/value sequence is
    shortened by a strided conv before attention, cutting its cost."""

    def __init__(self, dim, head, reduction_ratio):
        super().__init__()
        self.head = head
        self.reduction_ratio = reduction_ratio
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim*2, bias=True)
        self.proj = nn.Linear(dim, dim)

        if reduction_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, reduction_ratio, reduction_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        batch, n_tok, ch = x.shape
        heads = self.head
        query = self.q(x).reshape(batch, n_tok, heads, ch // heads).permute(0, 2, 1, 3)

        if self.reduction_ratio > 1:
            # Fold tokens back to (H, W), downsample, and re-flatten.
            spatial = x.clone().permute(0, 2, 1).reshape(batch, ch, H, W)
            reduced = self.sr(spatial).reshape(batch, ch, -1).permute(0, 2, 1)
            x = self.norm(reduced)

        keyval = self.kv(x).reshape(batch, -1, 2, heads, ch // heads).permute(2, 0, 3, 1, 4)
        key, value = keyval[0], keyval[1]

        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        fused = (weights @ value).transpose(1, 2).reshape(batch, n_tok, ch)
        return self.proj(fused)
40
+
41
+
42
class SelfAtten(nn.Module):
    """Plain multi-head self-attention (no spatial reduction)."""

    def __init__(self, dim, head):
        super().__init__()
        self.head = head
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim*2, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_tok, ch = x.shape
        heads = self.head
        query = self.q(x).reshape(batch, n_tok, heads, ch // heads).permute(0, 2, 1, 3)

        keyval = self.kv(x).reshape(batch, -1, 2, heads, ch // heads).permute(2, 0, 3, 1, 4)
        key, value = keyval[0], keyval[1]

        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        fused = (weights @ value).transpose(1, 2).reshape(batch, n_tok, ch)
        return self.proj(fused)
66
+
67
class Scale_reduce(nn.Module):
    """Shorten a multi-scale token sequence before the K/V projection.

    The input is the concatenation of per-scale token maps flattened to a
    common width ``dim`` (4 scales: 3136+1568+980+392 = 6076 tokens for a
    224x224 input; 3 scales: 1568+980+392 = 2940).  Every scale except
    the coarsest is downsampled to 7x7 by a strided conv so attention
    cost no longer depends on the finer resolutions.

    Args:
        dim: base channel width C of the flattened token sequence.
        reduction_ratio: per-scale reduction factors, coarsest-last;
            only lengths 3 and 4 are supported.
    """

    def __init__(self, dim, reduction_ratio):
        super().__init__()
        self.dim = dim
        self.reduction_ratio = reduction_ratio
        if len(self.reduction_ratio) == 4:
            self.sr0 = nn.Conv2d(dim, dim, reduction_ratio[3], reduction_ratio[3])
            self.sr1 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
            self.sr2 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])
        elif len(self.reduction_ratio) == 3:
            self.sr0 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
            self.sr1 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        if len(self.reduction_ratio) == 4:
            # Per-scale chunks: 56x56@C, 28x28@2C, 14x14@5C; the 7x7@8C
            # tail (tem3) is already small and passes through unchanged.
            tem0 = x[:, :3136, :].reshape(B, 56, 56, C).permute(0, 3, 1, 2)
            tem1 = x[:, 3136:4704, :].reshape(B, 28, 28, C*2).permute(0, 3, 1, 2)
            tem2 = x[:, 4704:5684, :].reshape(B, 14, 14, C*5).permute(0, 3, 1, 2)
            tem3 = x[:, 5684:6076, :]

            sr_0 = self.sr0(tem0).reshape(B, C, -1).permute(0, 2, 1)
            sr_1 = self.sr1(tem1).reshape(B, C, -1).permute(0, 2, 1)
            sr_2 = self.sr2(tem2).reshape(B, C, -1).permute(0, 2, 1)

            reduce_out = self.norm(torch.cat([sr_0, sr_1, sr_2, tem3], -2))
        elif len(self.reduction_ratio) == 3:
            tem0 = x[:, :1568, :].reshape(B, 28, 28, C*2).permute(0, 3, 1, 2)
            tem1 = x[:, 1568:2548, :].reshape(B, 14, 14, C*5).permute(0, 3, 1, 2)
            tem2 = x[:, 2548:2940, :]

            sr_0 = self.sr0(tem0).reshape(B, C, -1).permute(0, 2, 1)
            sr_1 = self.sr1(tem1).reshape(B, C, -1).permute(0, 2, 1)

            reduce_out = self.norm(torch.cat([sr_0, sr_1, tem2], -2))
        else:
            # The original fell through to a NameError on `reduce_out` here;
            # fail with an explicit, actionable error instead.
            raise ValueError(
                f"reduction_ratio must have length 3 or 4, got {len(self.reduction_ratio)}"
            )

        return reduce_out
108
+
109
+
110
+
111
+
112
+
113
class M_EfficientSelfAtten(nn.Module):
    """Multi-scale efficient self-attention used by the bridge: queries keep
    the full multi-scale sequence while keys/values may be shortened by a
    ``Scale_reduce`` stage (skipped entirely when ``reduction_ratio`` is
    None)."""

    def __init__(self, dim, head, reduction_ratio):
        super().__init__()
        self.head = head
        self.reduction_ratio = reduction_ratio  # e.g. [1, 2, 4, 8]
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim*2, bias=True)
        self.proj = nn.Linear(dim, dim)

        if reduction_ratio is not None:
            self.scale_reduce = Scale_reduce(dim, reduction_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_tok, ch = x.shape
        heads = self.head
        query = self.q(x).reshape(batch, n_tok, heads, ch // heads).permute(0, 2, 1, 3)

        if self.reduction_ratio is not None:
            # Downsample the K/V token sequence scale-by-scale.
            x = self.scale_reduce(x)

        keyval = self.kv(x).reshape(batch, -1, 2, heads, ch // heads).permute(2, 0, 3, 1, 4)
        key, value = keyval[0], keyval[1]

        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        fused = (weights @ value).transpose(1, 2).reshape(batch, n_tok, ch)
        return self.proj(fused)
144
+
145
+
146
class LocalEnhance_EfficientSelfAtten(nn.Module):
    """Efficient self-attention with a depth-wise-conv local position branch
    added to the attention output.

    NOTE(review): with ``reduction_ratio > 1`` the value tensor has fewer
    than N tokens, so the local branch's reshape below cannot work; as
    written this module is only usable with ``reduction_ratio == 1``.
    """

    def __init__(self, dim, head, reduction_ratio):
        super().__init__()
        self.head = head
        self.reduction_ratio = reduction_ratio
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim*2, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.local_pos = DWConv(dim)

        if reduction_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, reduction_ratio, reduction_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)

        if self.reduction_ratio > 1:
            p_x = x.clone().permute(0, 2, 1).reshape(B, C, H, W)
            sp_x = self.sr(p_x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(sp_x)

        kv = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn_score = attn.softmax(dim=-1)
        local_v = v.permute(0, 2, 1, 3).reshape(B, N, C)
        # BUG FIX: DWConv.forward requires the spatial size; the original
        # call omitted (H, W) and therefore always raised TypeError.
        local_pos = self.local_pos(local_v, H, W).reshape(B, -1, self.head, C // self.head).permute(0, 2, 1, 3)
        x_atten = ((attn_score @ v) + local_pos).transpose(1, 2).reshape(B, N, C)
        return self.proj(x_atten)
181
+
182
+
183
class DWConv(nn.Module):
    """3x3 depth-wise convolution on a token sequence: tokens are folded
    back to an (H, W) grid, convolved per channel, and re-flattened."""

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        batch, _, ch = x.shape
        grid = x.transpose(1, 2).view(batch, ch, H, W)
        mixed = self.dwconv(grid)
        return mixed.flatten(2).transpose(1, 2)
193
+
194
+
195
class MixFFN(nn.Module):
    """SegFormer feed-forward block: linear expansion, depth-wise 3x3 conv
    for positional mixing, GELU, then linear projection back."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        expanded = self.fc1(x)
        mixed = self.dwconv(expanded, H, W)
        return self.fc2(self.act(mixed))
207
+
208
class MixFFN_skip(nn.Module):
    """MixFFN variant with a skip connection around the depth-wise conv:
    fc2(GELU(LayerNorm(DWConv(fc1(x)) + fc1(x))))."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        self.norm1 = nn.LayerNorm(c2)
        # norm2/norm3 are never used in forward, but they register
        # parameters; kept so existing checkpoints still load cleanly.
        self.norm2 = nn.LayerNorm(c2)
        self.norm3 = nn.LayerNorm(c2)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        # PERF: the original evaluated self.fc1(x) twice per call;
        # compute the expansion once and reuse it for the skip.
        fx = self.fc1(x)
        ax = self.act(self.norm1(self.dwconv(fx, H, W) + fx))
        return self.fc2(ax)
222
+
223
class MLP_FFN(nn.Module):
    """Plain two-layer feed-forward block with GELU activation."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
235
+
236
class MixD_FFN(nn.Module):
    """MixFFN whose conv branch and linear branch are fused either by
    addition ("add") or channel concatenation before the output
    projection."""

    def __init__(self, c1, c2, fuse_mode = "add"):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        # Concatenation doubles the width entering the output projection.
        self.fc2 = nn.Linear(c2, c1) if fuse_mode=="add" else nn.Linear(c2*2, c1)
        self.fuse_mode = fuse_mode

    def forward(self, x, H, W):
        # BUG FIX: the original signature lacked (H, W) although DWConv
        # needs them (and FuseTransformerBlock already passes them), so
        # every call raised NameError.  It also returned fc2(ax),
        # discarding the fused activation — which made the "cat" mode
        # dimensionally impossible.  Fused output is now projected.
        fx = self.fc1(x)
        ax = self.dwconv(fx, H, W)
        fuse = self.act(ax + fx) if self.fuse_mode=="add" else self.act(torch.cat([ax, fx], 2))
        return self.fc2(fuse)
250
+
251
+
252
class OverlapPatchEmbeddings(nn.Module):
    """Overlapping patch embedding: a strided conv tokenizer followed by
    LayerNorm.

    Returns:
        (tokens, H, W): the normalized (B, H*W, dim) token sequence plus
        the spatial size of the embedded map.  (The original return
        annotation claimed a bare Tensor; it has always been this tuple.)
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, padding=1, in_ch=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_ch, dim, patch_size, stride, padding)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, int, int]":
        px = self.proj(x)
        _, _, H, W = px.shape
        # NCHW -> (B, H*W, dim) token layout.
        fx = px.flatten(2).transpose(1, 2)
        nfx = self.norm(fx)
        return nfx, H, W
265
+
266
+
267
+
268
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: efficient self-attention followed by a
    token-mixing FFN chosen via `token_mlp` ('mix', 'mix_skip', or a plain
    MLP otherwise), each with a residual connection."""

    def __init__(self, dim, head, reduction_ratio=1, token_mlp='mix'):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAtten(dim, head, reduction_ratio)
        self.norm2 = nn.LayerNorm(dim)
        if token_mlp == 'mix':
            self.mlp = MixFFN(dim, int(dim*4))
        elif token_mlp == 'mix_skip':
            self.mlp = MixFFN_skip(dim, int(dim*4))
        else:
            self.mlp = MLP_FFN(dim, int(dim*4))

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        attended = x + self.attn(self.norm1(x), H, W)
        return attended + self.mlp(self.norm2(attended), H, W)
285
+
286
+
287
class FuseTransformerBlock(nn.Module):
    """Pre-norm transformer block whose FFN is a MixD_FFN with the given
    fuse mode ('add' or concatenation)."""

    def __init__(self, dim, head, reduction_ratio=1, fuse_mode = "add"):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAtten(dim, head, reduction_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MixD_FFN(dim, int(dim*4), fuse_mode)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        attended = x + self.attn(self.norm1(x), H, W)
        return attended + self.mlp(self.norm2(attended), H, W)
299
+
300
+
301
class MLP(nn.Module):
    """Project an NCHW feature map into a token sequence of width
    `embed_dim` (flatten spatially, then a single linear layer)."""

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.flatten(2).transpose(1, 2)
        return self.proj(tokens)
309
+
310
+
311
class ConvModule(nn.Module):
    """Conv2d (bias-free) -> BatchNorm -> ReLU fuse unit."""

    def __init__(self, c1, c2, k):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.activate = nn.ReLU(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.bn(y)
        return self.activate(y)
320
+
321
+
322
class MiT(nn.Module):
    """Mix Transformer (MiT) encoder backbone (SegFormer).

    Four stages, each an overlapping patch embedding followed by
    transformer blocks and a LayerNorm; stage outputs are NCHW feature
    maps at 1/4, 1/8, 1/16 and 1/32 of the input resolution.

    Args:
        image_size: input resolution (assumed square).
        dims: per-stage channel widths.
        layers: per-stage number of transformer blocks.
        in_ch: number of input channels.
        token_mlp: FFN variant name passed to TransformerBlock.
    """
    def __init__(self, image_size, dims, layers, in_ch=3, token_mlp='mix_skip'):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        # Per-stage K/V spatial reduction; attention is cheaper early on.
        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], in_ch, dims[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size//4, patch_sizes[1], strides[1], padding_sizes[1],dims[0], dims[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size//8, patch_sizes[2], strides[2], padding_sizes[2],dims[1], dims[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size//16, patch_sizes[3], strides[3], padding_sizes[3],dims[2], dims[3])

        # transformer encoder
        self.block1 = nn.ModuleList([
            TransformerBlock(dims[0], heads[0], reduction_ratios[0],token_mlp)
        for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(dims[0])

        self.block2 = nn.ModuleList([
            TransformerBlock(dims[1], heads[1], reduction_ratios[1],token_mlp)
        for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(dims[1])

        self.block3 = nn.ModuleList([
            TransformerBlock(dims[2], heads[2], reduction_ratios[2], token_mlp)
        for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(dims[2])

        self.block4 = nn.ModuleList([
            TransformerBlock(dims[3], heads[3], reduction_ratios[3], token_mlp)
        for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(dims[3])

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Return the list of four per-stage NCHW feature maps."""
        B = x.shape[0]
        outs = []

        # stage 1: embed, run blocks on tokens, then restore NCHW layout.
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
398
+
399
+
400
class FuseMiT(nn.Module):
    """MiT encoder variant built from FuseTransformerBlock (MixD_FFN with
    an 'add' or concat fuse mode) instead of TransformerBlock.  Input is
    fixed to 3 channels; otherwise identical to MiT."""
    def __init__(self, image_size, dims, layers, fuse_mode='add'):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        # Per-stage K/V spatial reduction ratios.
        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3, dims[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size//4, patch_sizes[1], strides[1], padding_sizes[1],dims[0], dims[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size//8, patch_sizes[2], strides[2], padding_sizes[2],dims[1], dims[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size//16, patch_sizes[3], strides[3], padding_sizes[3],dims[2], dims[3])

        # transformer encoder
        self.block1 = nn.ModuleList([
            FuseTransformerBlock(dims[0], heads[0], reduction_ratios[0],fuse_mode)
        for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(dims[0])

        self.block2 = nn.ModuleList([
            FuseTransformerBlock(dims[1], heads[1], reduction_ratios[1],fuse_mode)
        for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(dims[1])

        self.block3 = nn.ModuleList([
            FuseTransformerBlock(dims[2], heads[2], reduction_ratios[2], fuse_mode)
        for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(dims[2])

        self.block4 = nn.ModuleList([
            FuseTransformerBlock(dims[3], heads[3], reduction_ratios[3], fuse_mode)
        for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(dims[3])

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Return the list of four per-stage NCHW feature maps."""
        B = x.shape[0]
        outs = []

        # stage 1: embed, run blocks on tokens, then restore NCHW layout.
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
476
+
477
+
478
+
479
+
480
class Decoder(nn.Module):
    """SegFormer all-MLP decode head: project each encoder scale to
    `embed_dim`, upsample everything to the finest scale, fuse with a 1x1
    ConvModule, apply dropout, and classify with a 1x1 conv."""

    def __init__(self, dims, embed_dim, num_classes):
        super().__init__()

        self.linear_c1 = MLP(dims[0], embed_dim)
        self.linear_c2 = MLP(dims[1], embed_dim)
        self.linear_c3 = MLP(dims[2], embed_dim)
        self.linear_c4 = MLP(dims[3], embed_dim)

        self.linear_fuse = ConvModule(embed_dim*4, embed_dim, 1)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)

        # NOTE(review): conv_seg is never used by forward; kept so existing
        # checkpoints still load.
        self.conv_seg = nn.Conv2d(128, num_classes, 1)

        self.dropout = nn.Dropout2d(0.1)

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        c1, c2, c3, c4 = inputs
        n = c1.shape[0]
        target_hw = c1.shape[2:]

        def to_map(linear, feat):
            # Project to embed_dim and restore the NCHW layout of `feat`.
            return linear(feat).permute(0, 2, 1).reshape(n, -1, feat.shape[2], feat.shape[3])

        c1f = to_map(self.linear_c1, c1)
        c2f = F.interpolate(to_map(self.linear_c2, c2), size=target_hw, mode='bilinear', align_corners=False)
        c3f = F.interpolate(to_map(self.linear_c3, c3), size=target_hw, mode='bilinear', align_corners=False)
        c4f = F.interpolate(to_map(self.linear_c4, c4), size=target_hw, mode='bilinear', align_corners=False)

        fused = self.linear_fuse(torch.cat([c4f, c3f, c2f, c1f], dim=1))
        return self.linear_pred(self.dropout(fused))
513
+
514
+
515
# Preset backbone configurations keyed by MiT variant:
# [per-stage channel dims, per-stage encoder depths, decoder embed dim].
segformer_settings = {
    'B0': [[32, 64, 160, 256], [2, 2, 2, 2], 256], # [channel dimensions, num encoder layers, embed dim]
    'B1': [[64, 128, 320, 512], [2, 2, 2, 2], 256],
    'B2': [[64, 128, 320, 512], [3, 4, 6, 3], 768],
    'B3': [[64, 128, 320, 512], [3, 4, 18, 3], 768],
    'B4': [[64, 128, 320, 512], [3, 8, 27, 3], 768],
    'B5': [[64, 128, 320, 512], [3, 6, 40, 3], 768]
}
523
+
524
+
525
class SegFormer(nn.Module):
    """SegFormer semantic segmentation model: a MiT backbone selected by
    `model_name` plus the all-MLP decode head."""

    def __init__(self, model_name: str = 'B0', num_classes: int = 19, image_size: int = 224) -> None:
        super().__init__()
        assert model_name in segformer_settings.keys(), f"SegFormer model name should be in {list(segformer_settings.keys())}"
        dims, layers, embed_dim = segformer_settings[model_name]

        self.backbone = MiT(image_size, dims, layers)
        self.decode_head = Decoder(dims, embed_dim, num_classes)

    def init_weights(self, pretrained: str = None) -> None:
        """Load backbone weights from `pretrained` if given, otherwise
        Xavier-initialize all linear/conv layers and reset LayerNorms."""
        if pretrained:
            self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
            return
        for m in self.modules():
            # Linear and Conv2d share the same init recipe.
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.size()[1] == 1:
            # Broadcast grayscale input to three channels.
            x = x.repeat(1, 3, 1, 1)
        return self.decode_head(self.backbone(x))
models/_resunet/__init__.py ADDED
File without changes
models/_resunet/modules.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/rishikksh20/ResUnet/blob/master/core/modules.py
2
+
3
+
4
+ import torch.nn as nn
5
+ import torch
6
+
7
+
8
class ResidualConv(nn.Module):
    """Pre-activation residual unit: a (BN-ReLU-Conv) x2 main path plus a
    strided 3x3 conv + BN shortcut, summed."""

    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        shortcut = self.conv_skip(x)
        return residual + shortcut
30
+
31
+
32
class Upsample(nn.Module):
    """Learned upsampling via a transposed convolution."""

    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()
        self.upsample = nn.ConvTranspose2d(input_dim, output_dim,
                                           kernel_size=kernel, stride=stride)

    def forward(self, x):
        return self.upsample(x)
42
+
43
+
44
class Squeeze_Excite_Block(nn.Module):
    """Squeeze-and-Excitation: global-average-pool per channel, pass
    through a bottleneck MLP with a sigmoid gate, and rescale the input
    channel-wise."""

    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
60
+
61
+
62
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: three parallel dilated 3x3 conv
    branches (each Conv -> ReLU -> BN), concatenated and fused by a 1x1
    conv.  Convs get Kaiming init, BatchNorms identity init."""

    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        def branch(dilation):
            # padding == dilation keeps the spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(in_dims, out_dims, 3, stride=1,
                          padding=dilation, dilation=dilation),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            )

        # Registration order matches the original (block1..3, then output).
        self.aspp_block1 = branch(rate[0])
        self.aspp_block2 = branch(rate[1])
        self.aspp_block3 = branch(rate[2])

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        pyramid = [self.aspp_block1(x), self.aspp_block2(x), self.aspp_block3(x)]
        return self.output(torch.cat(pyramid, dim=1))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
105
+
106
+
107
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed `scale` factor."""

    def __init__(self, scale=2):
        super(Upsample_, self).__init__()
        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        return self.upsample(x)
115
+
116
+
117
class AttentionBlock(nn.Module):
    """Additive attention gate: combine an encoder feature (downsampled
    2x via max-pool) with a decoder feature, produce a single-channel
    attention map, and gate the decoder feature with it."""

    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        combined = self.conv_encoder(x1) + self.conv_decoder(x2)
        mask = self.conv_attn(combined)
        return mask * x2
models/_resunet/res_unet.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/rishikksh20/ResUnet
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from .modules import ResidualConv, Upsample
6
+
7
+
8
class ResUnet(nn.Module):
    """Residual U-Net: 3-stage residual encoder, bridge, and skip-connected decoder."""

    def __init__(self, in_ch, out_ch, filters=[64, 128, 256, 512]):
        super().__init__()

        # Stem: plain conv stack plus a 3x3 projection skip (residual input block).
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1)
        )

        # Encoder: each residual block halves the spatial size (stride 2).
        self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.bridge = ResidualConv(filters[2], filters[3], 2, 1)

        # Decoder: upsample, concatenate the matching encoder feature, then fuse.
        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], out_ch, 1, 1),
        )

    def forward(self, x):
        # --- Encoder ---
        skip0 = self.input_layer(x) + self.input_skip(x)
        skip1 = self.residual_conv_1(skip0)
        skip2 = self.residual_conv_2(skip1)

        # --- Bridge ---
        feat = self.bridge(skip2)

        # --- Decoder (deepest to shallowest) ---
        feat = self.upsample_1(feat)
        feat = self.up_residual_conv1(torch.cat([feat, skip2], dim=1))

        feat = self.upsample_2(feat)
        feat = self.up_residual_conv2(torch.cat([feat, skip1], dim=1))

        feat = self.upsample_3(feat)
        feat = self.up_residual_conv3(torch.cat([feat, skip0], dim=1))

        return self.output_layer(feat)
models/_transunet/vit_seg_configs.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ml_collections
2
+
3
def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})  # 16x16 input patches
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1

    config.classifier = 'seg'  # segmentation head (vs. 'token' classification)
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
    config.patch_size = 16

    config.decoder_channels = (256, 128, 64, 16)  # DecoderCup output widths, deep->shallow
    config.n_classes = 2
    config.activation = 'softmax'
    return config
25
+
26
+
27
def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    # All transformer dims set to 1 so the model is tiny and fast to build.
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
41
+
42
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)  # presence of 'grid' selects the hybrid ResNet path
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]  # encoder skip widths fed to DecoderCup
    config.n_classes = 2
    config.n_skip = 3  # number of skip connections actually used
    config.activation = 'softmax'

    return config
59
+
60
+
61
def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    # Same as B/16 except for the larger patch size and checkpoint path.
    config = get_b16_config()
    config.patches.size = (32, 32)
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
    return config
67
+
68
+
69
def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.representation_size = None

    # custom segmentation additions on top of the classification config
    config.classifier = 'seg'
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config
90
+
91
+
92
def get_r50_l16_config():
    """Returns the Resnet50 + ViT-L/16 configuration. customized """
    config = get_l16_config()
    config.patches.grid = (16, 16)  # presence of 'grid' selects the hybrid ResNet path
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    # NOTE(review): points at the R50+ViT-B_16 checkpoint for an L/16 model —
    # looks intentional upstream, but verify before loading pretrained weights.
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config
107
+
108
+
109
def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    # Same as L/16 except for the larger patch size.
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config
114
+
115
+
116
def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'  # classification variant; no seg decoder fields set
    config.representation_size = None

    return config
models/_transunet/vit_seg_modeling.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import copy
7
+ import logging
8
+ import math
9
+
10
+ from os.path import join as pjoin
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+
16
+ from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17
+ from torch.nn.modules.utils import _pair
18
+ from scipy import ndimage
19
+ from . import vit_seg_configs as configs
20
+ from .vit_seg_modeling_resnet_skip import ResNetV2
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
27
+ ATTENTION_K = "MultiHeadDotProductAttention_1/key"
28
+ ATTENTION_V = "MultiHeadDotProductAttention_1/value"
29
+ ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
30
+ FC_0 = "MlpBlock_3/Dense_0"
31
+ FC_1 = "MlpBlock_3/Dense_1"
32
+ ATTENTION_NORM = "LayerNorm_0"
33
+ MLP_NORM = "LayerNorm_2"
34
+
35
+
36
def np2th(weights, conv=False):
    """Convert a numpy checkpoint array to a torch tensor.

    For conv kernels (``conv=True``) the layout is changed from TF's HWIO
    to PyTorch's OIHW before conversion.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
41
+
42
+
43
def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x


# Lookup from activation name to callable.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
48
+
49
+
50
class Attention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder.

    Submodule names (query/key/value/out, ...) are kept exactly as in the
    reference implementation so checkpoint loading via ``Block.load_from``
    keeps working.
    """

    def __init__(self, config, vis):
        super().__init__()
        self.vis = vis  # when True, forward() also returns the attention maps
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """(B, N, all_head) -> (B, heads, N, head_size)."""
        split = x.view(x.size()[:-1] + (self.num_attention_heads, self.attention_head_size))
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # Project and split into heads.
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # Recombine heads and project back to hidden_size.
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        attention_output = self.proj_dropout(self.out(context))
        return attention_output, weights
95
+
96
+
97
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier for the weights, tiny Gaussian for the biases (as in the JAX ViT code).
        for fc in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight)
        for fc in (self.fc1, self.fc2):
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
120
+
121
+
122
class Embeddings(nn.Module):
    """Patch + learned position embeddings, optionally over a ResNet hybrid backbone."""

    def __init__(self, config, img_size, in_channels=3):
        super().__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid (ResNet) mode: patches are taken from the 1/16-resolution feature map.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            # Pure ViT mode: patches are cut directly from the image.
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(
                block_units=config.resnet.num_layers,
                width_factor=config.resnet.width_factor,
            )
            in_channels = self.hybrid_model.width * 16
        # Patch embedding as a strided conv over the image / backbone features.
        self.patch_embeddings = Conv2d(
            in_channels=in_channels,
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # One learned position vector per patch (no class token in this variant).
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        features = None
        if self.hybrid:
            x, features = self.hybrid_model(x)
        # (B, hidden, gh, gw) -> (B, n_patches, hidden)
        patches = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        return self.dropout(patches + self.position_embeddings), features
166
+
167
+
168
class Block(nn.Module):
    """One ViT encoder block: pre-norm self-attention and pre-norm MLP, each with a residual."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # Pre-norm attention with residual connection.
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        # Pre-norm feed-forward with residual connection.
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        """Copy weights for encoder block ``n_block`` from a JAX/TF ViT checkpoint dict.

        Dense kernels are stored (in, out) in the checkpoint, hence the ``.t()``
        transposes; attention q/k/v/out kernels are stored per-head and reshaped
        to (hidden, hidden) first.
        """
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            # LayerNorm scale is called "scale" in the checkpoint, "weight" in torch.
            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
225
+
226
+
227
class Encoder(nn.Module):
    """Stack of transformer blocks followed by a final LayerNorm."""

    def __init__(self, config, vis):
        super().__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # A fresh Block is built per layer (deepcopy kept from the reference code).
            self.layer.append(copy.deepcopy(Block(config, vis)))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        return self.encoder_norm(hidden_states), attn_weights
245
+
246
+
247
class Transformer(nn.Module):
    """Embeddings + encoder; returns tokens, attention maps, and hybrid skip features."""

    def __init__(self, config, img_size, vis):
        super().__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        tokens, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)  # (B, n_patch, hidden)
        return encoded, attn_weights, features
257
+
258
+
259
class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU (conv bias disabled when batch norm is used)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        layers = (
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        super().__init__(*layers)
282
+
283
+
284
class DecoderBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, optional skip concat, two conv-bn-relu layers."""

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        return self.conv2(self.conv1(x))
316
+
317
+
318
class SegmentationHead(nn.Sequential):
    """Final 3x3 conv to class logits, with optional bilinear upsampling."""

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            scale = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            scale = nn.Identity()
        super().__init__(conv2d, scale)
324
+
325
+
326
class DecoderCup(nn.Module):
    """Reshapes encoder tokens back to a 2D map and runs the cascaded decoder blocks.

    Args:
        config: model config providing ``hidden_size``, ``decoder_channels``,
            ``n_skip`` and (when skips are used) ``skip_channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        # Project transformer hidden states to the decoder's starting width.
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Work on a copy: the original code zeroed config.skip_channels in
            # place, corrupting shared/cached config objects (e.g. the module's
            # CONFIGS registry) so a second model built from the same config
            # would lose its skip connections.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # disable the unused (shallowest) skips
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        """(B, n_patch, hidden) tokens -> (B, C, H, W) decoder feature map."""
        B, n_patch, hidden = hidden_states.size()
        # Assumes a square patch grid: n_patch must be a perfect square.
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                # Only the first n_skip stages receive encoder skip features.
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
368
+
369
+
370
class VisionTransformer(nn.Module):
    """TransUNet: ViT (optionally hybrid ResNet) encoder + CNN decoder + segmentation head."""

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # Grayscale inputs are replicated to 3 channels for the RGB backbone.
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        """Load a JAX/TF ViT checkpoint (npz-style dict) into this model.

        Position embeddings are adapted when shapes differ: the class token is
        dropped and/or the patch-grid embedding is bilinearly resized.
        """
        with torch.no_grad():

            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                # Exact shape match: copy directly.
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1]-1 == posemb_new.size()[1]:
                # Checkpoint includes a class token; this model has none — drop it.
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                # Grid size differs: interpolate the patch-grid embeddings.
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    # Split off the class token; keep only the grid part.
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                # Bilinear resize of the (gs_old, gs_old) grid to (gs_new, gs_new).
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole: delegate per-block loading to Block.load_from.
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                # Hybrid backbone: load the ResNet root conv + group norm, then each unit.
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
441
+
442
# Registry of named model configurations. 'R50-*' variants use the hybrid
# ResNet50 backbone; 'testing' is a minimal config for unit tests. Note the
# dict holds single shared ConfigDict instances, built once at import time.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}
452
+
453
+
models/_transunet/vit_seg_modeling_c4.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import copy
7
+ import logging
8
+ import math
9
+
10
+ from os.path import join as pjoin
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+
16
+ from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17
+ from torch.nn.modules.utils import _pair
18
+ from scipy import ndimage
19
+ from . import vit_seg_configs as configs
20
+ from .vit_seg_modeling_resnet_skip_c4 import ResNetV2
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
27
+ ATTENTION_K = "MultiHeadDotProductAttention_1/key"
28
+ ATTENTION_V = "MultiHeadDotProductAttention_1/value"
29
+ ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
30
+ FC_0 = "MlpBlock_3/Dense_0"
31
+ FC_1 = "MlpBlock_3/Dense_1"
32
+ ATTENTION_NORM = "LayerNorm_0"
33
+ MLP_NORM = "LayerNorm_2"
34
+
35
+
36
def np2th(weights, conv=False):
    """Convert a numpy checkpoint array to a torch tensor.

    Conv kernels (``conv=True``) are reordered from TF's HWIO to torch's OIHW.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
41
+
42
+
43
def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x


# Lookup from activation name to callable.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
48
+
49
+
50
class Attention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder (4-channel variant file).

    Submodule names (query/key/value/out, ...) are kept exactly as in the
    reference implementation so checkpoint loading via ``Block.load_from``
    keeps working.
    """

    def __init__(self, config, vis):
        super().__init__()
        self.vis = vis  # when True, forward() also returns the attention maps
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """(B, N, all_head) -> (B, heads, N, head_size)."""
        split = x.view(x.size()[:-1] + (self.num_attention_heads, self.attention_head_size))
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # Project and split into heads.
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # Recombine heads and project back to hidden_size.
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        attention_output = self.proj_dropout(self.out(context))
        return attention_output, weights
95
+
96
+
97
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier for the weights, tiny Gaussian for the biases (as in the JAX ViT code).
        for fc in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight)
        for fc in (self.fc1, self.fc2):
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
120
+
121
+
122
class Embeddings(nn.Module):
    """Patch + learned position embeddings for 4-channel inputs, optionally over a ResNet backbone."""

    def __init__(self, config, img_size, in_channels=4):
        super().__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid (ResNet) mode: patches are taken from the 1/16-resolution feature map.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            # Pure ViT mode: patches are cut directly from the image.
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(
                block_units=config.resnet.num_layers,
                width_factor=config.resnet.width_factor,
            )
            in_channels = self.hybrid_model.width * 16
        # Patch embedding as a strided conv over the image / backbone features.
        self.patch_embeddings = Conv2d(
            in_channels=in_channels,
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # One learned position vector per patch (no class token in this variant).
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        features = None
        if self.hybrid:
            x, features = self.hybrid_model(x)
        # (B, hidden, gh, gw) -> (B, n_patches, hidden)
        patches = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        return self.dropout(patches + self.position_embeddings), features
166
+
167
+
168
class Block(nn.Module):
    """One Transformer encoder block: pre-norm multi-head self-attention and
    a pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return (block output, attention weights from the attention layer)."""
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        """Copy pretrained JAX/TF checkpoint tensors for encoder block
        ``n_block`` into this module's parameters.

        Linear kernels are stored transposed in the checkpoint, hence the
        ``.t()`` calls; LayerNorm "scale"/"bias" map to weight/bias.
        """
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
225
+
226
+
227
class Encoder(nn.Module):
    """Stack of Transformer blocks followed by a final LayerNorm.

    Args:
        config: model config providing ``hidden_size`` and
            ``transformer["num_layers"]``.
        vis: when True, forward() also collects per-layer attention weights.
    """
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # Each Block is constructed fresh inside the loop, so the
            # copy.deepcopy used by the upstream code was redundant overhead;
            # append the new module directly.
            self.layer.append(Block(config, vis))

    def forward(self, hidden_states):
        """Run all blocks in order.

        Returns:
            (encoded, attn_weights): the LayerNorm-ed final hidden states and
            the list of per-layer attention weights (empty unless vis=True).
        """
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
245
+
246
+
247
class Transformer(nn.Module):
    """Patch/position embedding followed by the Transformer encoder."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Embed the input and encode it.

        Returns the encoded tokens (B, n_patch, hidden), the attention
        weights, and the CNN skip features (None in pure-ViT mode).
        """
        tokens, skip_features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)
        return encoded, attn_weights, skip_features
257
+
258
+
259
class Conv2dReLU(nn.Sequential):
    """Conv2d -> (BatchNorm2d | Identity) -> ReLU.

    Bug fix vs. the upstream TransUNet code: previously BatchNorm2d was
    applied even when ``use_batchnorm=False`` (the flag only removed the
    conv bias). Now ``use_batchnorm=False`` really skips normalization,
    while the conv keeps its bias in that case.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),  # BN makes a conv bias redundant
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)
282
+
283
+
284
class DecoderBlock(nn.Module):
    """One decoder stage: upsample by 2, optionally concatenate an encoder
    skip feature map, then apply two 3x3 Conv-BN-ReLU layers."""

    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        """Decode one stage; ``skip`` is the encoder feature map to fuse, if any."""
        upsampled = self.up(x)
        fused = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(fused))
316
+
317
+
318
class SegmentationHead(nn.Sequential):
    """Final conv mapping decoder features to per-class logits, followed by
    optional bilinear upsampling of the result."""

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            scale = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            scale = nn.Identity()
        super().__init__(conv2d, scale)
324
+
325
+
326
class DecoderCup(nn.Module):
    """CUP decoder: reshape transformer tokens back to a square 2D feature
    map and decode with a cascade of DecoderBlocks, fusing CNN skip features
    where configured (config.n_skip of the four stages)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Copy before editing: the upstream code zeroed entries of
            # config.skip_channels in place, silently mutating the shared
            # config object for every other consumer of it.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        """Decode ``hidden_states`` (B, n_patch, hidden); n_patch must be a
        perfect square. ``features`` are encoder skips, deepest first."""
        B, n_patch, hidden = hidden_states.size()  # (B, n_patch, hidden) -> (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
368
+
369
+
370
class VisionTransformer(nn.Module):
    """TransUNet-style segmentation model: Transformer encoder (optionally a
    hybrid ResNet+ViT backbone) + CUP decoder + segmentation head.

    This is the 4-channel ("c4") variant: single-channel input is tiled to
    4 channels in forward().
    """
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        """Return per-pixel class logits for input x of shape (B, C, H, W)."""
        if x.size()[1] == 1:
            # Tile grayscale input to the 4 channels this variant expects.
            x = x.repeat(1,4,1,1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        """Load a pretrained JAX/TF ViT checkpoint (mapping of numpy arrays).

        Position embeddings are copied verbatim when shapes match, have the
        class token dropped when off by one, or are bilinearly resized via
        scipy.ndimage.zoom otherwise. The hybrid ResNet stem is loaded only
        when the embeddings use it.
        """
        with torch.no_grad():

            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1]-1 == posemb_new.size()[1]:
                # Checkpoint includes a [CLS] token; segmentation drops it.
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                # Interpolate the 2D grid of position embeddings to the new
                # token-grid resolution.
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
441
+
442
# Registry of ready-made ViT configurations, keyed by model-size name.
# The R50-* entries use the hybrid ResNet50 + ViT backbone.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}
452
+
453
+
models/_transunet/vit_seg_modeling_resnet_skip.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from os.path import join as pjoin
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
def np2th(weights, conv=False):
    """Convert a numpy array to a torch tensor; conv kernels go HWIO -> OIHW."""
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)
16
+
17
+
18
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: the kernel is normalized to zero
    mean and unit variance per output channel before every convolution."""

    def forward(self, x):
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3],
                                   keepdim=True, unbiased=False)
        weight = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
26
+
27
+
28
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """3x3 weight-standardized conv with padding 1."""
    return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1,
                     bias=bias, groups=groups)
31
+
32
+
33
def conv1x1(cin, cout, stride=1, bias=False):
    """1x1 weight-standardized conv (channel projection), no padding."""
    return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0,
                     bias=bias)
36
+
37
+
38
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand, each followed by
    GroupNorm and ReLU on the unit branch; a projection shortcut is added
    when the stride or channel count changes.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        """Return relu(shortcut(x) + bottleneck(x))."""

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        """Copy pretrained checkpoint tensors (numpy, TF layout) into this
        unit; conv kernels are converted HWIO -> OIHW by np2th.

        NOTE(review): uses copy_ on leaf parameters without its own no_grad
        guard — callers appear to wrap this in torch.no_grad(); confirm.
        """
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
111
+
112
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    forward() returns the final feature map plus intermediate feature maps
    (deepest first) for use as U-Net style skip connections.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
                ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
                ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
                ))),
        ]))

    def forward(self, x):
        """Return (deepest feature map, skip features ordered deepest-first).

        Feature maps that come out one or two pixels short of the expected
        size (odd input sizes) are zero-padded up to it; the assert guards
        larger mismatches. Assumes a square input (only one side is read).
        """
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        # NOTE(review): the pooling module is rebuilt on every call; it is
        # stateless, so behavior is unchanged — just minor overhead.
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
models/_transunet/vit_seg_modeling_resnet_skip_c4.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from os.path import join as pjoin
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
def np2th(weights, conv=False):
    """Convert a numpy array to a torch tensor; conv kernels go HWIO -> OIHW."""
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)
16
+
17
+
18
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: the kernel is normalized to zero
    mean and unit variance per output channel before every convolution."""

    def forward(self, x):
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3],
                                   keepdim=True, unbiased=False)
        weight = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
26
+
27
+
28
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """3x3 weight-standardized conv with padding 1."""
    return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1,
                     bias=bias, groups=groups)
31
+
32
+
33
def conv1x1(cin, cout, stride=1, bias=False):
    """1x1 weight-standardized conv (channel projection), no padding."""
    return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0,
                     bias=bias)
36
+
37
+
38
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand with GroupNorm+ReLU;
    a projection shortcut is added when stride or channel count changes.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        """Return relu(shortcut(x) + bottleneck(x))."""

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        """Copy pretrained checkpoint tensors (numpy, TF layout) into this
        unit; conv kernels are converted HWIO -> OIHW by np2th.

        NOTE(review): no torch.no_grad() guard here — callers appear to
        provide it; confirm.
        """
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
111
+
112
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    4-channel ("c4") variant: the root conv takes 4 input channels. Returns
    the final feature map plus intermediate maps (deepest first) as skips.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(4, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
                ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
                ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
                ))),
        ]))

    def forward(self, x):
        """Return (deepest feature map, skip features ordered deepest-first).

        Feature maps one or two pixels short of the expected size are
        zero-padded up to it (assert guards larger mismatches); assumes a
        square input.
        """
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        # Stateless pooling module rebuilt each call — behaviorally benign.
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
models/_uctransnet/CTrans.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : Haonan Wang
3
+ # @File : CTrans.py
4
+ # @Software: PyCharm
5
+ # coding=utf-8
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+ import copy
10
+ import logging
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
16
+ from torch.nn.modules.utils import _pair
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
class Channel_Embeddings(nn.Module):
    """Patchify a feature map with a strided conv (channel count preserved)
    and add learned position embeddings; None input passes through as None."""

    def __init__(self, config, patchsize, img_size, in_channels):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patchsize)
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=in_channels,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(config.transformer["embeddings_dropout_rate"])

    def forward(self, x):
        """Return (B, n_patches, C) token embeddings, or None for None input."""
        if x is None:
            return None
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        tokens = tokens + self.position_embeddings
        return self.dropout(tokens)
46
+
47
class Reconstruct(nn.Module):
    """Map (B, n_patch, C_in) tokens back to a (B, C_out, H, W) feature map:
    reshape to a square grid, upsample, then Conv -> BatchNorm -> ReLU.
    None input passes through as None."""

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        if kernel_size == 3:
            padding = 1
        else:
            padding = 0
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Build the upsampler once instead of re-instantiating nn.Upsample
        # on every forward call (same default nearest-neighbour mode, no
        # parameters, so state_dict is unchanged).
        self.up = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x):
        """x: (B, n_patch, hidden) with n_patch a perfect square."""
        if x is None:
            return None

        B, n_patch, hidden = x.size()  # (B, n_patch, hidden) -> (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.up(x)

        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out
73
+
74
class Attention_org(nn.Module):
    """Multi-head channel-wise cross attention over four multi-scale streams.

    Queries come from each scale's own embedding (emb1..emb4); keys and
    values come from the channel-wise concatenation of all scales (emb_all,
    last dim = config.KV_size). Heads are separate per-head Linear modules
    held in ModuleLists; queries are transposed so attention is taken over
    channels, scores are instance-normalized (psi) before softmax, and the
    per-head contexts are averaged (mean over dim 3) rather than
    concatenated. Any emb_i may be None, in which case its output is None.
    """
    def __init__(self, config, vis, channel_num):
        super(Attention_org, self).__init__()
        self.vis = vis
        self.KV_size = config.KV_size
        self.channel_num = channel_num
        self.num_attention_heads = config.transformer["num_heads"]

        self.query1 = nn.ModuleList()
        self.query2 = nn.ModuleList()
        self.query3 = nn.ModuleList()
        self.query4 = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        # One independent linear projection per head, per stream.
        for _ in range(config.transformer["num_heads"]):
            query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
            query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
            query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
            query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
            key = nn.Linear(self.KV_size, self.KV_size, bias=False)
            value = nn.Linear(self.KV_size, self.KV_size, bias=False)
            self.query1.append(copy.deepcopy(query1))
            self.query2.append(copy.deepcopy(query2))
            self.query3.append(copy.deepcopy(query3))
            self.query4.append(copy.deepcopy(query4))
            self.key.append(copy.deepcopy(key))
            self.value.append(copy.deepcopy(value))
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
        self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
        self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
        self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

    def forward(self, emb1, emb2, emb3, emb4, emb_all):
        """Return (O1, O2, O3, O4, weights); O_i is None when emb_i is None.

        NOTE(review): when vis=True the weights collection dereferences all
        four attention_probs unconditionally, so vis=True combined with any
        emb_i=None would raise on the None — confirm callers always pass all
        four streams when visualizing.
        """
        multi_head_Q1_list = []
        multi_head_Q2_list = []
        multi_head_Q3_list = []
        multi_head_Q4_list = []
        multi_head_K_list = []
        multi_head_V_list = []
        if emb1 is not None:
            for query1 in self.query1:
                Q1 = query1(emb1)
                multi_head_Q1_list.append(Q1)
        if emb2 is not None:
            for query2 in self.query2:
                Q2 = query2(emb2)
                multi_head_Q2_list.append(Q2)
        if emb3 is not None:
            for query3 in self.query3:
                Q3 = query3(emb3)
                multi_head_Q3_list.append(Q3)
        if emb4 is not None:
            for query4 in self.query4:
                Q4 = query4(emb4)
                multi_head_Q4_list.append(Q4)
        for key in self.key:
            K = key(emb_all)
            multi_head_K_list.append(K)
        for value in self.value:
            V = value(emb_all)
            multi_head_V_list.append(V)

        # Stack per-head projections: (B, heads, n_patch, C).
        multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
        multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
        multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
        multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
        multi_head_K = torch.stack(multi_head_K_list, dim=1)
        multi_head_V = torch.stack(multi_head_V_list, dim=1)

        # Transpose queries so the attention map is channel x channel.
        multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
        multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
        multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
        multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None

        attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
        attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
        attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
        attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None

        attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
        attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
        attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
        attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None

        # Instance-normalize the score maps across heads, then softmax.
        attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
        attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
        attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
        attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None

        if self.vis:
            weights = []
            weights.append(attention_probs1.mean(1))
            weights.append(attention_probs2.mean(1))
            weights.append(attention_probs3.mean(1))
            weights.append(attention_probs4.mean(1))
        else: weights=None

        attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
        attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
        attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
        attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None

        multi_head_V = multi_head_V.transpose(-1, -2)
        context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
        context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
        context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
        context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None

        context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
        context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
        context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
        context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
        # Average head contexts instead of concatenating them.
        context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
        context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
        context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
        context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None

        O1 = self.out1(context_layer1) if emb1 is not None else None
        O2 = self.out2(context_layer2) if emb2 is not None else None
        O3 = self.out3(context_layer3) if emb3 is not None else None
        O4 = self.out4(context_layer4) if emb4 is not None else None
        O1 = self.proj_dropout(O1) if emb1 is not None else None
        O2 = self.proj_dropout(O2) if emb2 is not None else None
        O3 = self.proj_dropout(O3) if emb3 is not None else None
        O4 = self.proj_dropout(O4) if emb4 is not None else None
        return O1,O2,O3,O4, weights
209
+
210
+
211
+
212
+
213
class Mlp(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config, in_channel, mlp_channel):
        super(Mlp, self).__init__()
        # Expand to the hidden width, then project back to the input width.
        self.fc1 = nn.Linear(in_channel, mlp_channel)
        self.fc2 = nn.Linear(mlp_channel, in_channel)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        # Xavier for weights, near-zero normal for biases (ViT convention).
        for fc in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight)
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
235
+
236
class Block_ViT(nn.Module):
    """One channel-transformer layer: pre-norm multi-scale channel attention,
    then a pre-norm MLP per scale, each wrapped in a residual connection.

    Any of ``emb1..emb4`` may be None; a missing scale is skipped everywhere
    and returned as None in the corresponding output slot.
    """

    def __init__(self, config, vis, channel_num):
        super(Block_ViT, self).__init__()
        expand_ratio = config.expand_ratio
        # Per-scale norms feeding the shared channel attention.
        self.attn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.attn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.attn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.attn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        # Norm over the concatenation of all present scales (shared K/V input).
        self.attn_norm = LayerNorm(config.KV_size, eps=1e-6)
        self.channel_attn = Attention_org(config, vis, channel_num)

        self.ffn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.ffn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.ffn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.ffn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.ffn1 = Mlp(config, channel_num[0], channel_num[0] * expand_ratio)
        self.ffn2 = Mlp(config, channel_num[1], channel_num[1] * expand_ratio)
        self.ffn3 = Mlp(config, channel_num[2], channel_num[2] * expand_ratio)
        self.ffn4 = Mlp(config, channel_num[3], channel_num[3] * expand_ratio)

    def forward(self, emb1, emb2, emb3, emb4):
        # Gather the scales that are present directly, instead of the original
        # locals()["emb" + str(i)] lookup, which is fragile: it is undefined
        # behavior under optimizing interpreters and silently breaks if a
        # parameter is renamed.
        embcat = [emb for emb in (emb1, emb2, emb3, emb4) if emb is not None]
        org1, org2, org3, org4 = emb1, emb2, emb3, emb4

        emb_all = self.attn_norm(torch.cat(embcat, dim=2))
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
        # Residual connection around the attention.
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1, org2, org3, org4 = cx1, cx2, cx3, cx4
        x1 = self.ffn1(self.ffn_norm1(cx1)) if emb1 is not None else None
        x2 = self.ffn2(self.ffn_norm2(cx2)) if emb2 is not None else None
        x3 = self.ffn3(self.ffn_norm3(cx3)) if emb3 is not None else None
        x4 = self.ffn4(self.ffn_norm4(cx4)) if emb4 is not None else None
        # Residual connection around the MLP.
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4, weights
299
+
300
+
301
class Encoder(nn.Module):
    """Stack of Block_ViT layers followed by a per-scale output LayerNorm.

    Scales that arrive as None stay None throughout.
    """

    def __init__(self, config, vis, channel_num):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.encoder_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.encoder_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.encoder_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            self.layer.append(copy.deepcopy(Block_ViT(config, vis, channel_num)))

    def forward(self, emb1, emb2, emb3, emb4):
        attn_weights = []
        for block in self.layer:
            emb1, emb2, emb3, emb4, weights = block(emb1, emb2, emb3, emb4)
            if self.vis:
                attn_weights.append(weights)
        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
        return emb1, emb2, emb3, emb4, attn_weights
325
+
326
+
327
class ChannelTransformer(nn.Module):
    """Channel transformer over four encoder scales.

    Each scale is patch-embedded, jointly encoded, reconstructed back to a
    feature map, and added residually onto its input.
    """

    def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        # One embedding per scale; spatial size halves at every deeper scale.
        self.embeddings_1 = Channel_Embeddings(config, self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(config, self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(config, self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(config, self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])
        self.encoder = Encoder(config, vis, channel_num)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=(self.patchSize_1, self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=(self.patchSize_2, self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=(self.patchSize_3, self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=(self.patchSize_4, self.patchSize_4))

    def forward(self, en1, en2, en3, en4):
        # Tokenize each encoder scale into patch embeddings.
        emb1 = self.embeddings_1(en1)
        emb2 = self.embeddings_2(en2)
        emb3 = self.embeddings_3(en3)
        emb4 = self.embeddings_4(en4)

        encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)

        # Map tokens back to feature maps and add the residual encoder input.
        x1 = self.reconstruct_1(encoded1) + en1 if en1 is not None else None
        x2 = self.reconstruct_2(encoded2) + en2 if en2 is not None else None
        x3 = self.reconstruct_3(encoded3) + en3 if en3 is not None else None
        x4 = self.reconstruct_4(encoded4) + en4 if en4 is not None else None

        return x1, x2, x3, x4, attn_weights
365
+
models/_uctransnet/Config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/6/19 2:44 下午
3
+ # @Author : Haonan Wang
4
+ # @File : Config.py
5
+ # @Software: PyCharm
6
+ import os
7
+ import torch
8
+ import time
9
+ import ml_collections
10
+
11
+ ## PARAMETERS OF THE MODEL
12
+ save_model = True
13
+ tensorboard = True
14
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
15
+ use_cuda = torch.cuda.is_available()
16
+ seed = 666
17
+ os.environ['PYTHONHASHSEED'] = str(seed)
18
+
19
+ cosineLR = True # whether use cosineLR or not
20
+ n_channels = 3
21
+ n_labels = 1
22
+ epochs = 2000
23
+ img_size = 224
24
+ print_frequency = 1
25
+ save_frequency = 5000
26
+ vis_frequency = 10
27
+ early_stopping_patience = 50
28
+
29
+ pretrain = False
30
+ task_name = 'MoNuSeg' # GlaS MoNuSeg
31
+ # task_name = 'GlaS'
32
+ learning_rate = 1e-3
33
+ batch_size = 4
34
+
35
+
36
+ # model_name = 'UCTransNet'
37
+ model_name = 'UCTransNet_pretrain'
38
+
39
+ train_dataset = './datasets/'+ task_name+ '/Train_Folder/'
40
+ val_dataset = './datasets/'+ task_name+ '/Val_Folder/'
41
+ test_dataset = './datasets/'+ task_name+ '/Test_Folder/'
42
+ session_name = 'Test_session' + '_' + time.strftime('%m.%d_%Hh%M')
43
+ save_path = task_name +'/'+ model_name +'/' + session_name + '/'
44
+ model_path = save_path + 'models/'
45
+ tensorboard_folder = save_path + 'tensorboard_logs/'
46
+ logger_path = save_path + session_name + ".log"
47
+ visualize_path = save_path + 'visualize_val/'
48
+
49
+
50
+ ##########################################################################
51
+ # CTrans configs
52
+ ##########################################################################
53
def get_CTranS_config():
    """Return the default CTrans (channel transformer) configuration.

    Note: KV_size must equal the sum of the four per-scale query widths
    (Q1 + Q2 + Q3 + Q4) consumed by the channel attention.
    """
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    config.KV_size = 960           # = Q1 + Q2 + Q3 + Q4
    config.expand_ratio = 4        # MLP hidden-width multiplier
    config.patch_sizes = [16, 8, 4, 2]
    config.base_channel = 64       # base channel count of the U-Net
    config.n_classes = 1
    return config
67
+
68
+
69
+
70
+
71
+ # used in testing phase, copy the session name in training phase
72
+ test_session = "Test_session_07.03_20h39"
models/_uctransnet/UCTransNet.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/7/8 8:59 上午
3
+ # @File : UCTransNet.py
4
+ # @Software: PyCharm
5
+ import torch.nn as nn
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from .CTrans import ChannelTransformer
9
+
10
def get_activation(activation_type):
    """Return an instance of the ``torch.nn`` activation named ``activation_type``.

    The lookup is case-insensitive ('relu', 'ReLU', 'gelu', ...). Falls back
    to ``nn.ReLU()`` when no matching activation class exists or the class
    cannot be constructed without arguments.
    """
    wanted = activation_type.lower()
    # nn exposes CamelCase class names (e.g. 'ReLU', 'GELU'), so the original
    # hasattr(nn, activation_type.lower()) check never matched and every
    # caller silently got ReLU. Match names case-insensitively instead.
    for name in dir(nn):
        if name.lower() != wanted:
            continue
        attr = getattr(nn, name)
        if isinstance(attr, type) and issubclass(attr, nn.Module):
            try:
                return attr()
            except TypeError:
                # Activation class requiring constructor args: fall back.
                break
    return nn.ReLU()
16
+
17
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Build a Sequential of ``nb_Conv`` ConvBatchNorm layers.

    Only the first layer changes the channel count; the rest keep it.
    """
    blocks = [ConvBatchNorm(in_channels, out_channels, activation)]
    blocks.extend(
        ConvBatchNorm(out_channels, out_channels, activation)
        for _ in range(nb_Conv - 1)
    )
    return nn.Sequential(*blocks)
24
+
25
class ConvBatchNorm(nn.Module):
    """3x3 convolution => BatchNorm => activation (ReLU by default)."""

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))
39
+
40
class DownBlock(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then run a conv stack."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        return self.nConvs(self.maxpool(x))
50
+
51
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
54
+
55
class CCA(nn.Module):
    """Channel-wise cross attention (CCA) block.

    Gates the skip-connection features ``x`` with a sigmoid channel mask
    computed from globally pooled statistics of both ``g`` and ``x``.
    """

    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(Flatten(), nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(Flatten(), nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # Global average pooling over the full spatial extent of each input.
        pooled_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        pooled_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
        # Average the two channel-attention logits, squash into (0, 1).
        att_sum = (self.mlp_x(pooled_x) + self.mlp_g(pooled_g)) / 2.0
        scale = torch.sigmoid(att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return self.relu(x * scale)
80
+
81
class UpBlock_attention(nn.Module):
    """Upsample decoder features, gate the skip features with CCA, then fuse."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.coatt = CCA(F_g=in_channels // 2, F_x=in_channels // 2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        upsampled = self.up(x)
        gated_skip = self.coatt(g=upsampled, x=skip_x)
        # Concatenate along the channel dimension before the conv stack.
        merged = torch.cat([gated_skip, upsampled], dim=1)
        return self.nConvs(merged)
93
+
94
class UCTransNet(nn.Module):
    """U-Net whose skip connections are fused by a channel transformer (CTrans).

    Args:
        config: CTrans configuration (base_channel, patch_sizes, ...).
        n_channels: input image channels.
        n_classes: output channels; with 1 class a Sigmoid head is applied.
        img_size: input spatial size expected by the transformer embeddings.
        vis: when True, forward also returns the attention weights.
    """

    def __init__(self, config, n_channels=3, n_classes=1, img_size=224, vis=False):
        super().__init__()
        self.vis = vis
        self.n_channels = n_channels
        self.n_classes = n_classes
        base = config.base_channel
        self.inc = ConvBatchNorm(n_channels, base)
        self.down1 = DownBlock(base, base * 2, nb_Conv=2)
        self.down2 = DownBlock(base * 2, base * 4, nb_Conv=2)
        self.down3 = DownBlock(base * 4, base * 8, nb_Conv=2)
        self.down4 = DownBlock(base * 8, base * 8, nb_Conv=2)
        self.mtc = ChannelTransformer(
            config, vis, img_size,
            channel_num=[base, base * 2, base * 4, base * 8],
            patchSize=config.patch_sizes)
        self.up4 = UpBlock_attention(base * 16, base * 4, nb_Conv=2)
        self.up3 = UpBlock_attention(base * 8, base * 2, nb_Conv=2)
        self.up2 = UpBlock_attention(base * 4, base, nb_Conv=2)
        self.up1 = UpBlock_attention(base * 2, base, nb_Conv=2)
        self.outc = nn.Conv2d(base, n_classes, kernel_size=(1, 1), stride=(1, 1))
        self.last_activation = nn.Sigmoid()  # paired with BCELoss for binary output

    def forward(self, x):
        x = x.float()
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)
        # Refine the four encoder maps jointly with the channel transformer;
        # the bottleneck is taken before refinement, as in the reference model.
        enc1, enc2, enc3, enc4, att_weights = self.mtc(enc1, enc2, enc3, enc4)
        dec = self.up4(bottleneck, enc4)
        dec = self.up3(dec, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up1(dec, enc1)
        if self.n_classes == 1:
            logits = self.last_activation(self.outc(dec))
        else:
            logits = self.outc(dec)  # raw logits for BCEWithLogitsLoss / multi-class
        if self.vis:  # optionally expose the attention maps
            return logits, att_weights
        return logits
136
+
137
+
138
+
139
+
models/_uctransnet/UNet.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
def get_activation(activation_type):
    """Return an instance of the ``torch.nn`` activation whose class name
    matches ``activation_type`` case-insensitively; default to ``nn.ReLU()``.
    """
    wanted = activation_type.lower()
    # The original lowercased the name before hasattr(nn, ...), but nn uses
    # CamelCase class names ('ReLU', 'GELU'), so the lookup never matched and
    # the function always returned ReLU. Compare lowercased names instead.
    for name in dir(nn):
        if name.lower() != wanted:
            continue
        candidate = getattr(nn, name)
        if isinstance(candidate, type) and issubclass(candidate, nn.Module):
            try:
                return candidate()
            except TypeError:
                # Needs constructor arguments; use the default fallback.
                break
    return nn.ReLU()
10
+
11
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Chain ``nb_Conv`` ConvBatchNorm blocks into a Sequential.

    The first block maps in_channels -> out_channels; later blocks keep
    out_channels.
    """
    convs = [ConvBatchNorm(in_channels, out_channels, activation)]
    convs += [
        ConvBatchNorm(out_channels, out_channels, activation)
        for _ in range(nb_Conv - 1)
    ]
    return nn.Sequential(*convs)
18
+
19
class ConvBatchNorm(nn.Module):
    """Convolution (3x3, padded) => BatchNorm => activation."""

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        features = self.conv(x)
        features = self.norm(features)
        return self.activation(features)
33
+
34
class DownBlock(nn.Module):
    """Encoder step: 2x2 max-pool downsampling followed by a conv stack."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        pooled = self.maxpool(x)
        return self.nConvs(pooled)
45
+
46
class UpBlock(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concatenation, conv stack."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(UpBlock, self).__init__()
        # Transposed conv doubles the spatial size of the decoder features.
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, (2, 2), 2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        upsampled = self.up(x)
        merged = torch.cat([upsampled, skip_x], dim=1)  # channel-dim concat
        return self.nConvs(merged)
60
+
61
class UNet(nn.Module):
    """Classic U-Net with transposed-convolution upsampling.

    Args:
        n_channels: channels of the input image (3 for RGB by default).
        n_classes: channels of the output map; with a single class a
            Sigmoid head is applied, otherwise raw logits are returned.
    """

    def __init__(self, n_channels=3, n_classes=9):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        in_channels = 64  # width of the first encoder stage
        self.inc = ConvBatchNorm(n_channels, in_channels)
        self.down1 = DownBlock(in_channels, in_channels * 2, nb_Conv=2)
        self.down2 = DownBlock(in_channels * 2, in_channels * 4, nb_Conv=2)
        self.down3 = DownBlock(in_channels * 4, in_channels * 8, nb_Conv=2)
        self.down4 = DownBlock(in_channels * 8, in_channels * 8, nb_Conv=2)
        self.up4 = UpBlock(in_channels * 16, in_channels * 4, nb_Conv=2)
        self.up3 = UpBlock(in_channels * 8, in_channels * 2, nb_Conv=2)
        self.up2 = UpBlock(in_channels * 4, in_channels, nb_Conv=2)
        self.up1 = UpBlock(in_channels * 2, in_channels, nb_Conv=2)
        self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1))
        # Sigmoid head only for binary segmentation (BCELoss); else raw logits.
        if n_classes == 1:
            self.last_activation = nn.Sigmoid()
        else:
            self.last_activation = None

    def forward(self, x):
        x = x.float()
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)
        out = self.up4(bottom, skip4)
        out = self.up3(out, skip3)
        out = self.up2(out, skip2)
        out = self.up1(out, skip1)
        out = self.outc(out)
        if self.last_activation is not None:
            return self.last_activation(out)
        return out
110
+
111
+
models/attunet.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/LeeJunHyun/Image_Segmentation
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+
8
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize Conv/Linear/BatchNorm2d parameters of ``net`` in place.

    Supported schemes: 'normal', 'xavier', 'kaiming', 'orthogonal'.
    Raises NotImplementedError for an unknown scheme (when a Conv/Linear
    module is encountered).
    """
    schemes = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            scheme = schemes.get(init_type)
            if scheme is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            scheme(m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm: weights around 1, biases at 0.
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
30
+
31
class conv_block(nn.Module):
    """Two (Conv3x3 -> BN -> ReLU) stages; the first changes the channel count."""

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        stages = []
        for c_in in (ch_in, ch_out):
            stages += [
                nn.Conv2d(c_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
47
+
48
class up_conv(nn.Module):
    """2x upsampling followed by Conv3x3 -> BatchNorm -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)
61
+
62
class Recurrent_block(nn.Module):
    """Recurrent conv block: re-applies one Conv3x3->BN->ReLU on ``x + out``
    for ``t`` steps.
    """

    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # NOTE(review): on the first step conv(x) is computed and then
        # immediately recomputed as conv(x + conv(x)); this mirrors the
        # reference implementation exactly.
        out = None
        for step in range(self.t):
            if step == 0:
                out = self.conv(x)
            out = self.conv(x + out)
        return out
81
+
82
class RRCNN_block(nn.Module):
    """Residual recurrent block: 1x1 channel projection plus two recurrent
    conv stages, with a residual connection around the recurrent stack.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)
95
+
96
+
97
class single_conv(nn.Module):
    """Single Conv3x3 -> BatchNorm -> ReLU stage."""

    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
109
+
110
class Attention_block(nn.Module):
    """Additive attention gate: weights skip features ``x`` by their
    compatibility with the gating signal ``g`` (Attention U-Net).
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # Per-pixel gating coefficient in (0, 1).
        combined = self.relu(self.W_g(g) + self.W_x(x))
        gate = self.psi(combined)
        return x * gate
138
+
139
+
140
class U_Net(nn.Module):
    """Plain U-Net: 5-level encoder, symmetric decoder with skip concatenation."""

    def __init__(self, img_ch=3, output_ch=1):
        super(U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages (channels double each level).
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder stages (upsample then fuse with the matching skip).
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: conv at full resolution, then downsample + conv per level.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))

        # Decoder: upsample, concat the skip, fuse.
        d = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d = self.Up_conv4(torch.cat((e3, self.Up4(d)), dim=1))
        d = self.Up_conv3(torch.cat((e2, self.Up3(d)), dim=1))
        d = self.Up_conv2(torch.cat((e1, self.Up2(d)), dim=1))

        return self.Conv_1x1(d)
204
+
205
+
206
class R2U_Net(nn.Module):
    """Recurrent-residual U-Net (R2U-Net): RRCNN blocks replace plain convs.

    ``t`` is the number of recurrent steps inside each Recurrent_block.
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder.
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder path.
        e1 = self.RRCNN1(x)
        e2 = self.RRCNN2(self.Maxpool(e1))
        e3 = self.RRCNN3(self.Maxpool(e2))
        e4 = self.RRCNN4(self.Maxpool(e3))
        e5 = self.RRCNN5(self.Maxpool(e4))

        # Decoder path with skip concatenation.
        d = self.Up_RRCNN5(torch.cat((e4, self.Up5(e5)), dim=1))
        d = self.Up_RRCNN4(torch.cat((e3, self.Up4(d)), dim=1))
        d = self.Up_RRCNN3(torch.cat((e2, self.Up3(d)), dim=1))
        d = self.Up_RRCNN2(torch.cat((e1, self.Up2(d)), dim=1))

        return self.Conv_1x1(d)
275
+
276
+
277
+
278
class AttU_Net(nn.Module):
    """Attention U-Net: plain conv encoder/decoder with attention-gated skips."""

    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder with an attention gate per level.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))

        # Decoder path: gate each skip with the upsampled decoder features.
        d5 = self.Up5(e5)
        d5 = self.Up_conv5(torch.cat((self.Att5(g=d5, x=e4), d5), dim=1))
        d4 = self.Up4(d5)
        d4 = self.Up_conv4(torch.cat((self.Att4(g=d4, x=e3), d4), dim=1))
        d3 = self.Up3(d4)
        d3 = self.Up_conv3(torch.cat((self.Att3(g=d3, x=e2), d3), dim=1))
        d2 = self.Up2(d3)
        d2 = self.Up_conv2(torch.cat((self.Att2(g=d2, x=e1), d2), dim=1))

        return self.Conv_1x1(d2)
349
+
350
+
351
class R2AttU_Net(nn.Module):
    """R2-Attention U-Net: RRCNN blocks plus attention-gated skip connections."""

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder.
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder with an attention gate per level.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder path.
        e1 = self.RRCNN1(x)
        e2 = self.RRCNN2(self.Maxpool(e1))
        e3 = self.RRCNN3(self.Maxpool(e2))
        e4 = self.RRCNN4(self.Maxpool(e3))
        e5 = self.RRCNN5(self.Maxpool(e4))

        # Decoder path: gate each skip with the upsampled decoder features.
        d5 = self.Up5(e5)
        d5 = self.Up_RRCNN5(torch.cat((self.Att5(g=d5, x=e4), d5), dim=1))
        d4 = self.Up4(d5)
        d4 = self.Up_RRCNN4(torch.cat((self.Att4(g=d4, x=e3), d4), dim=1))
        d3 = self.Up3(d4)
        d3 = self.Up_RRCNN3(torch.cat((self.Att3(g=d3, x=e2), d3), dim=1))
        d2 = self.Up2(d3)
        d2 = self.Up_RRCNN2(torch.cat((self.Att2(g=d2, x=e1), d2), dim=1))

        return self.Conv_1x1(d2)
models/multiresunet.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py
2
+
3
+ from typing import Tuple, Dict
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch
7
+
8
+
9
class Multiresblock(nn.Module):
    def __init__(self,input_features : int, corresponding_unet_filters : int ,alpha : float =1.67)->None:
        """
        MultiRes block: three chained 3x3 convolutions (whose chain gives
        effective 3x3 / 5x5 / 7x7 receptive fields) concatenated together
        and summed with a 1x1 shortcut projection of the input.

        Arguments:
            input_features - number of input channels
            corresponding_unet_filters - U-Net filter count for the same stage
            alpha - factor (1.67 in the paper) used to derive the number of
                    MultiResUNet filters from the U-Net filters
        Returns - None
        """
        super().__init__()
        self.corresponding_unet_filters = corresponding_unet_filters
        self.alpha = alpha
        # W is the per-stage channel budget, split roughly 1/6 : 1/3 : 1/2
        # across the three conv stages below.
        self.W = corresponding_unet_filters * alpha
        # 1x1 shortcut projecting the input to the full concatenated width.
        self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5),
        kernel_size = (1,1),activation='None',padding = 0)

        # Despite the 5x5/7x7 names, all three are 3x3 convs applied in a chain.
        self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167),
        kernel_size = (3,3),activation='relu',padding = 1)
        self.conv2d_bn_5x5 = Conv2d_batchnorm(input_features=int(self.W*0.167),num_of_filters = int(self.W*0.333),
        kernel_size = (3,3),activation='relu',padding = 1)
        self.conv2d_bn_7x7 = Conv2d_batchnorm(input_features=int(self.W*0.333),num_of_filters = int(self.W*0.5),
        kernel_size = (3,3),activation='relu',padding = 1)
        # NOTE(review): this single BatchNorm2d instance is applied twice in
        # forward(); in training mode its running statistics are updated twice
        # per pass on different tensors — confirm this is intentional.
        self.batch_norm1 = nn.BatchNorm2d(int(self.W*0.5)+int(self.W*0.167)+int(self.W*0.333) ,affine=False)

    def forward(self,x: torch.Tensor)->torch.Tensor:

        temp = self.conv2d_bn_1x1(x)
        a = self.conv2d_bn_3x3(x)
        b = self.conv2d_bn_5x5(a)
        c = self.conv2d_bn_7x7(b)
        # Concatenate the three scales, normalise, add the shortcut,
        # then normalise again with the SAME BatchNorm module (see NOTE above).
        x = torch.cat([a,b,c],axis=1)
        x = self.batch_norm1(x)
        x = x + temp
        x = self.batch_norm1(x)
        return x
45
+
46
class Conv2d_batchnorm(nn.Module):
    """Conv2d followed by non-affine batch normalisation and an optional ReLU.

    Arguments:
        input_features: number of input channels
        num_of_filters: number of output channels
        kernel_size: convolution kernel shape
        stride: convolution stride
        activation: 'relu' applies a ReLU after the norm; any other value
            leaves the normalised output untouched
        padding: zero-padding added to the input

    Returns: None
    """

    def __init__(self, input_features: int, num_of_filters: int, kernel_size: Tuple = (2, 2),
                 stride: Tuple = (1, 1), activation: str = 'relu', padding: int = 0) -> None:
        super().__init__()
        # Submodule names (conv1 / batchnorm) are kept stable so that any
        # existing state_dict checkpoints keep loading.
        self.activation = activation
        self.conv1 = nn.Conv2d(
            in_channels=input_features,
            out_channels=num_of_filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.batchnorm = nn.BatchNorm2d(num_of_filters, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.batchnorm(self.conv1(x))
        if self.activation != 'relu':
            return normed
        return F.relu(normed)
69
+
70
+
71
class Respath(nn.Module):
    def __init__(self,input_features : int,filters : int,respath_length : int)->None:
        """
        ResPath: a chain of residual conv stages (3x3 conv plus a 1x1
        shortcut) applied to the encoder skip connection before it is
        concatenated in the decoder.

        Arguments:
            input_features - input channel count
            filters - output channel count
            respath_length - length of the Respath

        Returns - None
        """
        super().__init__()
        self.filters = filters
        self.respath_length = respath_length
        # First stage: maps input_features -> filters.
        self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters,
        kernel_size = (1,1),activation='None',padding = 0)
        self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters,
        kernel_size = (3,3),activation='relu',padding = 1)
        # "common" stages: filters -> filters; the SAME two modules are
        # reused (weight sharing) for every loop iteration in forward().
        self.conv2d_bn_1x1_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters,
        kernel_size = (1,1),activation='None',padding = 0)
        self.conv2d_bn_3x3_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters,
        kernel_size = (3,3),activation='relu',padding = 1)
        # Single BatchNorm instance, also shared across all stages.
        self.batch_norm1 = nn.BatchNorm2d(filters,affine=False)

    def forward(self,x : torch.Tensor)->torch.Tensor:
        # Stage 1: residual add of 1x1 shortcut and 3x3 branch.
        shortcut = self.conv2d_bn_1x1(x)
        x = self.conv2d_bn_3x3(x)
        x = x + shortcut
        x = F.relu(x)
        x = self.batch_norm1(x)
        # NOTE(review): when respath_length > 1 this runs respath_length
        # additional shared-weight stages (so respath_length + 1 in total);
        # when respath_length == 1 only the stage above runs. Confirm the
        # intended total stage count matches the paper.
        if self.respath_length>1:
            for i in range(self.respath_length):
                shortcut = self.conv2d_bn_1x1_common(x)
                x = self.conv2d_bn_3x3_common(x)
                x = x + shortcut
                x = F.relu(x)
                x = self.batch_norm1(x)
            return x
        else:
            return x
110
+
111
class MultiResUnet(nn.Module):
    def __init__(self,channels : int,filters : int =32,nclasses : int =1)->None:

        """
        MultiResUNet: a U-Net variant built from Multiresblock stages with
        Respath-processed skip connections.

        Arguments:
            channels - input image channels
            filters - base U-Net filter count for the first stage
            nclasses - number of output classes
        Returns - None
        """
        super().__init__()
        self.alpha = 1.67
        self.filters = filters
        self.nclasses = nclasses
        # Encoder. After each Multiresblock, in_filtersN re-computes that
        # block's output width: int(W*0.5)+int(W*0.167)+int(W*0.333) where
        # W = stage_filters * alpha (must match Multiresblock's concat width).
        self.multiresblock1 = Multiresblock(input_features=channels,corresponding_unet_filters=self.filters)
        self.pool1 = nn.MaxPool2d(2,stride= 2)
        self.in_filters1 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333)
        # Respath length shrinks with depth (4, 3, 2, 1).
        self.respath1 = Respath(input_features=self.in_filters1 ,filters=self.filters,respath_length=4)
        self.multiresblock2 = Multiresblock(input_features= self.in_filters1,corresponding_unet_filters=self.filters*2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.in_filters2 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333)
        self.respath2 = Respath(input_features=self.in_filters2,filters=self.filters*2,respath_length=3)
        self.multiresblock3 = Multiresblock(input_features= self.in_filters2,corresponding_unet_filters=self.filters*4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.in_filters3 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333)
        self.respath3 = Respath(input_features=self.in_filters3,filters=self.filters*4,respath_length=2)
        self.multiresblock4 = Multiresblock(input_features= self.in_filters3,corresponding_unet_filters=self.filters*8)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.in_filters4 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333)
        self.respath4 = Respath(input_features=self.in_filters4,filters=self.filters*8,respath_length=1)
        self.multiresblock5 = Multiresblock(input_features= self.in_filters4,corresponding_unet_filters=self.filters*16)
        self.in_filters5 = int(self.filters*16*self.alpha* 0.5)+int(self.filters*16*self.alpha*0.167)+int(self.filters*16*self.alpha*0.333)

        #Decoder path
        # Each level: transposed-conv upsample, concat the Respath output of
        # the matching encoder stage, then another Multiresblock.
        self.upsample6 = nn.ConvTranspose2d(in_channels=self.in_filters5,out_channels=self.filters*8,kernel_size=(2,2),stride=(2,2),padding = 0)
        self.concat_filters1 = self.filters*8+self.filters*8
        self.multiresblock6 = Multiresblock(input_features=self.concat_filters1,corresponding_unet_filters=self.filters*8)
        self.in_filters6 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333)
        self.upsample7 = nn.ConvTranspose2d(in_channels=self.in_filters6,out_channels=self.filters*4,kernel_size=(2,2),stride=(2,2),padding = 0)
        self.concat_filters2 = self.filters*4+self.filters*4
        self.multiresblock7 = Multiresblock(input_features=self.concat_filters2,corresponding_unet_filters=self.filters*4)
        self.in_filters7 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333)
        self.upsample8 = nn.ConvTranspose2d(in_channels=self.in_filters7,out_channels=self.filters*2,kernel_size=(2,2),stride=(2,2),padding = 0)
        self.concat_filters3 = self.filters*2+self.filters*2
        self.multiresblock8 = Multiresblock(input_features=self.concat_filters3,corresponding_unet_filters=self.filters*2)
        self.in_filters8 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333)
        self.upsample9 = nn.ConvTranspose2d(in_channels=self.in_filters8,out_channels=self.filters,kernel_size=(2,2),stride=(2,2),padding = 0)
        self.concat_filters4 = self.filters+self.filters
        self.multiresblock9 = Multiresblock(input_features=self.concat_filters4,corresponding_unet_filters=self.filters)
        self.in_filters9 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333)
        # 1x1 head mapping to nclasses channels (no activation inside).
        self.conv_final = Conv2d_batchnorm(input_features=self.in_filters9,num_of_filters = self.nclasses,
        kernel_size = (1,1),activation='None')

    def forward(self,x : torch.Tensor)->torch.Tensor:
        # Encoder: pool the raw block output; push the skip through Respath.
        x_multires1 = self.multiresblock1(x)
        x_pool1 = self.pool1(x_multires1)
        x_multires1 = self.respath1(x_multires1)
        x_multires2 = self.multiresblock2(x_pool1)
        x_pool2 = self.pool2(x_multires2)
        x_multires2 = self.respath2(x_multires2)
        x_multires3 = self.multiresblock3(x_pool2)
        x_pool3 = self.pool3(x_multires3)
        x_multires3 = self.respath3(x_multires3)
        x_multires4 = self.multiresblock4(x_pool3)
        x_pool4 = self.pool4(x_multires4)
        x_multires4 = self.respath4(x_multires4)
        x_multires5 = self.multiresblock5(x_pool4)
        # Decoder: upsample, concat skip, Multiresblock.
        up6 = torch.cat([self.upsample6(x_multires5),x_multires4],axis=1)
        x_multires6 = self.multiresblock6(up6)
        up7 = torch.cat([self.upsample7(x_multires6),x_multires3],axis=1)
        x_multires7 = self.multiresblock7(up7)
        up8 = torch.cat([self.upsample8(x_multires7),x_multires2],axis=1)
        x_multires8 = self.multiresblock8(up8)
        up9 = torch.cat([self.upsample9(x_multires8),x_multires1],axis=1)
        x_multires9 = self.multiresblock9(up9)
        # Multi-class: raw logits (caller applies softmax);
        # binary (nclasses == 1): sigmoid probabilities.
        if self.nclasses > 1:
            conv_final_layer = self.conv_final(x_multires9)
        else:
            conv_final_layer = torch.sigmoid(self.conv_final(x_multires9))
        return conv_final_layer
models/unet.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+
6
class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by ReLU.

    When `with_bn` is True a BatchNorm2d layer is inserted between each
    convolution and its ReLU. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        # Build the layer list conditionally; the resulting nn.Sequential
        # has exactly the same layer order (and state_dict indices) as a
        # hand-written Conv[/BN]/ReLU x2 stack.
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if with_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        if with_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        return self.step(x)
28
+
29
+
30
class UNet(nn.Module):
    """Plain 4-level U-Net built from DoubleConv blocks.

    The encoder halves the spatial size three times with max-pooling
    (32 -> 64 -> 128 -> 256 channels); the decoder upsamples bilinearly
    and concatenates the matching encoder feature map before each
    DoubleConv. Returns raw logits with `out_channels` channels at the
    input resolution (no sigmoid/softmax applied here).
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        init_channels = 32
        self.out_channels = out_channels  # kept for callers that inspect it

        # Encoder stages. Attribute names/creation order are preserved so
        # the shipped checkpoints keep loading.
        self.en_1 = DoubleConv(in_channels, init_channels, with_bn)
        self.en_2 = DoubleConv(1 * init_channels, 2 * init_channels, with_bn)
        self.en_3 = DoubleConv(2 * init_channels, 4 * init_channels, with_bn)
        self.en_4 = DoubleConv(4 * init_channels, 8 * init_channels, with_bn)

        # Decoder stages: each consumes upsampled features + a skip.
        self.de_1 = DoubleConv((4 + 8) * init_channels, 4 * init_channels, with_bn)
        self.de_2 = DoubleConv((2 + 4) * init_channels, 2 * init_channels, with_bn)
        self.de_3 = DoubleConv((1 + 2) * init_channels, 1 * init_channels, with_bn)
        self.de_4 = nn.Conv2d(init_channels, out_channels, 1)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        skip1 = self.en_1(x)
        skip2 = self.en_2(self.maxpool(skip1))
        skip3 = self.en_3(self.maxpool(skip2))
        bottom = self.en_4(self.maxpool(skip3))

        up = self.de_1(torch.cat([self.upsample(bottom), skip3], dim=1))
        up = self.de_2(torch.cat([self.upsample(up), skip2], dim=1))
        up = self.de_3(torch.cat([self.upsample(up), skip1], dim=1))
        return self.de_4(up)
models/unetpp.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py (unetpp)
2
+
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.functional import softmax, sigmoid
7
+
8
+
9
+ __all__ = ['UNet', 'NestedUNet']
10
+
11
+
12
class VGGBlock(nn.Module):
    """VGG-style double conv block: (Conv3x3 -> BN -> ReLU) applied twice.

    Spatial size is preserved (padding=1); channels go
    in_channels -> middle_channels -> out_channels.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # Submodule names and creation order are kept stable so existing
        # checkpoints keyed on conv1/bn1/conv2/bn2 still load.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))
31
+
32
+
33
class UNet(nn.Module):
    """Standard 5-level U-Net assembled from VGGBlocks.

    The encoder downsamples four times with max-pooling
    (32 -> 64 -> 128 -> 256 -> 512 channels); the decoder upsamples
    bilinearly (align_corners=True) and concatenates the matching
    encoder feature map before each VGGBlock. Returns raw logits with
    `num_classes` channels at the input resolution.
    """

    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (x_i_0 in the nested-UNet naming scheme).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder blocks: input = skip channels + upsampled channels.
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        enc0 = self.conv0_0(input)
        enc1 = self.conv1_0(self.pool(enc0))
        enc2 = self.conv2_0(self.pool(enc1))
        enc3 = self.conv3_0(self.pool(enc2))
        enc4 = self.conv4_0(self.pool(enc3))

        dec3 = self.conv3_1(torch.cat([enc3, self.up(enc4)], 1))
        dec2 = self.conv2_2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.conv1_3(torch.cat([enc1, self.up(dec2)], 1))
        dec0 = self.conv0_4(torch.cat([enc0, self.up(dec1)], 1))

        return self.final(dec0)
70
+
71
+
72
class NestedUNet(nn.Module):
    """UNet++ (nested U-Net).

    Dense skip pathways: node x_{i,j} receives all same-level nodes
    x_{i,0..j-1} plus the upsampled node x_{i+1,j-1}. With
    `deep_supervision=True` the forward pass returns a list of four
    outputs, one per nested decoder column; otherwise a single output
    tensor of raw logits with `num_classes` channels.
    """
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone / encoder column (j = 0).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested decoder nodes; conv{i}_{j} input width is
        # j * nb_filter[i] (dense same-level skips) + nb_filter[i+1] (upsampled).
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # One 1x1 head per supervised column, or a single head.
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # Column j=1 nodes are computed as soon as their inputs exist.
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(x0_4)
            return output
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ numpy
5
+ Pillow
6
+ pyyaml
saved_models/isic2018_unet/best_model_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8559876e4b85070ffbe84748e02926a9e8690b09bf101a12f7e5c5e590decbf0
3
+ size 7799041
saved_models/segpc2021_unet/best_model_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a182b4f7a415d056ae7e5293aed483804494580eb3f1a3b27d04e77c55468e76
3
+ size 7800193