Kim Mạnh Hưng
committed on
Commit
·
aa04f76
1
Parent(s):
0e003ab
Add U-Net app and weights
Browse files- .gitignore +3 -0
- app.py +118 -0
- configs/isic/isic2018_attunet.yaml +50 -0
- configs/isic/isic2018_missformer.yaml +52 -0
- configs/isic/isic2018_multiresunet.yaml +51 -0
- configs/isic/isic2018_resunet.yaml +50 -0
- configs/isic/isic2018_transunet.yaml +52 -0
- configs/isic/isic2018_uctransnet.yaml +50 -0
- configs/isic/isic2018_unet.yaml +51 -0
- configs/isic/isic2018_unetpp.yaml +51 -0
- configs/segpc/segpc2021_attunet.yaml +47 -0
- configs/segpc/segpc2021_missformer.yaml +49 -0
- configs/segpc/segpc2021_multiresunet.yaml +53 -0
- configs/segpc/segpc2021_resunet.yaml +47 -0
- configs/segpc/segpc2021_transunet.yaml +52 -0
- configs/segpc/segpc2021_uctransnet.yaml +47 -0
- configs/segpc/segpc2021_unet.yaml +48 -0
- configs/segpc/segpc2021_unetpp.yaml +48 -0
- models/__init__.py +0 -0
- models/_missformer/MISSFormer.py +398 -0
- models/_missformer/__init__.py +0 -0
- models/_missformer/segformer.py +557 -0
- models/_resunet/__init__.py +0 -0
- models/_resunet/modules.py +143 -0
- models/_resunet/res_unet.py +65 -0
- models/_transunet/vit_seg_configs.py +130 -0
- models/_transunet/vit_seg_modeling.py +453 -0
- models/_transunet/vit_seg_modeling_c4.py +453 -0
- models/_transunet/vit_seg_modeling_resnet_skip.py +160 -0
- models/_transunet/vit_seg_modeling_resnet_skip_c4.py +160 -0
- models/_uctransnet/CTrans.py +365 -0
- models/_uctransnet/Config.py +72 -0
- models/_uctransnet/UCTransNet.py +139 -0
- models/_uctransnet/UNet.py +111 -0
- models/attunet.py +427 -0
- models/multiresunet.py +190 -0
- models/unet.py +64 -0
- models/unetpp.py +141 -0
- requirements.txt +6 -0
- saved_models/isic2018_unet/best_model_state_dict.pt +3 -0
- saved_models/segpc2021_unet/best_model_state_dict.pt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
app.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import yaml
|
| 6 |
+
import os
|
| 7 |
+
from models.unet import UNet
|
| 8 |
+
|
| 9 |
+
# Configuration
|
| 10 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 11 |
+
# Map dataset names to config/model paths
|
| 12 |
+
CONFIG_PATHS = {
|
| 13 |
+
'isic': './configs/isic/isic2018_unet.yaml',
|
| 14 |
+
'segpc': './configs/segpc/segpc2021_unet.yaml'
|
| 15 |
+
}
|
| 16 |
+
MODEL_PATHS = {
|
| 17 |
+
'isic': './saved_models/isic2018_unet/best_model_state_dict.pt',
|
| 18 |
+
'segpc': './saved_models/segpc2021_unet/best_model_state_dict.pt'
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
def load_config(config_path):
    """Read a YAML config file and return its parsed contents.

    Uses yaml.safe_load, so only plain YAML data types are constructed.
    """
    with open(config_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
|
| 24 |
+
|
| 25 |
+
def load_model(dataset_name):
    """Build the UNet for `dataset_name` and load its saved weights.

    Args:
        dataset_name: key into CONFIG_PATHS / MODEL_PATHS ('isic' or 'segpc').

    Returns:
        A UNet on DEVICE in eval mode. If the checkpoint file is missing,
        a warning is printed and the model keeps its random init weights.
    """
    config = load_config(CONFIG_PATHS[dataset_name])
    # BUG FIX: the channel counts live under model.params in the YAML
    # configs (model -> params -> in_channels/out_channels), not directly
    # under model. The old config['model']['in_channels'] raised KeyError,
    # which the caller's try/except silently swallowed, leaving the app
    # with no usable models.
    model_params = config['model']['params']
    model = UNet(
        in_channels=model_params['in_channels'],
        out_channels=model_params['out_channels']
    )
    model_path = MODEL_PATHS[dataset_name]
    if os.path.exists(model_path):
        # map_location lets CPU-only hosts load checkpoints saved on GPU.
        state_dict = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        print(f"Loaded model for {dataset_name} from {model_path}")
    else:
        print(f"Warning: Model weights not found for {dataset_name} at {model_path}")

    model.to(DEVICE)
    model.eval()
    return model
|
| 42 |
+
|
| 43 |
+
# Load models once (cache them)
# A failed load is reported but non-fatal, so one broken checkpoint does
# not take down the whole demo; predict() checks membership in `models`.
models = {}
for ds in ['isic', 'segpc']:
    try:
        loaded = load_model(ds)
    except Exception as e:
        print(f"Error loading model {ds}: {e}")
    else:
        models[ds] = loaded
|
| 50 |
+
|
| 51 |
+
def predict(image, dataset_choice):
    """Segment `image` with the model selected by `dataset_choice`.

    Args:
        image: PIL image from the Gradio widget (may be None).
        dataset_choice: 'isic' or 'segpc'; must be a key of `models`.

    Returns:
        A 224x224x3 uint8 numpy array with the predicted mask blended in
        green over the resized input, or None if the input or model is
        unavailable.
    """
    if image is None:
        return None

    if dataset_choice not in models:
        return None

    model = models[dataset_choice]

    # BUG FIX: normalize the PIL mode up front. Grayscale ('L') uploads
    # previously crashed (2-D array has no channel axis for permute), and
    # RGBA uploads crashed the overlay assignment below (4 channels vs a
    # 3-element color). convert('RGB') handles both; RGB inputs pass
    # through unchanged.
    image = image.convert('RGB')

    # Preprocess
    # Resize to 224x224 as per config; scale pixels to [0, 1].
    img_resized = image.resize((224, 224))
    img_np = np.array(img_resized).astype(np.float32) / 255.0

    # Handle channels
    if dataset_choice == 'isic':
        # ISIC: 3 channels (RGB) -- guaranteed by the convert above.
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()
    else:
        # SegPC: the model was trained on 4-channel input; pad a zero 4th
        # channel onto RGB. NOTE(review): the training pipeline's 4th
        # channel is presumably a cell/nucleus mask -- a zero channel is
        # only an approximation; confirm against the dataset loader.
        padding = np.zeros((224, 224, 1), dtype=np.float32)
        img_np = np.concatenate([img_np, padding], axis=-1)
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()

    img_tensor = img_tensor.to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.sigmoid(output)
        # Threshold channel 0 of the output. NOTE(review): the configs set
        # out_channels: 2; verify channel 0 is the foreground class.
        pred_mask = (probs > 0.5).float().cpu().numpy()[0, 0]

    # Post-process for visualization: blend the binary mask in green
    # over the resized input.
    base_img = np.array(img_resized)
    overlay = base_img.copy()

    mask_bool = pred_mask > 0
    overlay[mask_bool] = [0, 255, 0]  # Make Green

    # Blend
    final_img = (0.6 * base_img + 0.4 * overlay).astype(np.uint8)

    return final_img
|
| 100 |
+
|
| 101 |
+
# Interface
# Two inputs: the image to segment and a radio choosing which cached
# model to run; the output is the green-overlay visualization.
_image_input = gr.Image(type="pil", label="Input Image")
_model_input = gr.Radio(["isic", "segpc"], label="Dataset Model", value="isic")

iface = gr.Interface(
    fn=predict,
    inputs=[_image_input, _model_input],
    outputs=gr.Image(type="numpy", label="Prediction Overlay"),
    title="Medical Image Segmentation (Awesome-U-Net)",
    description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
    # No bundled samples yet; add paths here when example images ship,
    # e.g. ["dataset_examples/isic_sample.jpg", "isic"].
    examples=[],
)
|
| 116 |
+
|
| 117 |
+
# Script entry point: start the Gradio web server when run directly
# (not when app.py is imported as a module).
if __name__ == "__main__":
    iface.launch()
|
configs/isic/isic2018_attunet.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 100
|
| 43 |
+
model:
|
| 44 |
+
save_dir: '../../saved_models/isic2018_attunet'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: 'AttU_Net'
|
| 47 |
+
params:
|
| 48 |
+
img_ch: 3
|
| 49 |
+
output_ch: 2
|
| 50 |
+
# preprocess:
|
configs/isic/isic2018_missformer.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'SGD'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
momentum: 0.9
|
| 37 |
+
weight_decay: 0.0001
|
| 38 |
+
criterion:
|
| 39 |
+
name: "DiceLoss"
|
| 40 |
+
params: {}
|
| 41 |
+
scheduler:
|
| 42 |
+
factor: 0.5
|
| 43 |
+
patience: 10
|
| 44 |
+
epochs: 300
|
| 45 |
+
model:
|
| 46 |
+
save_dir: '../../saved_models/isic2018_missformer'
|
| 47 |
+
load_weights: false
|
| 48 |
+
name: "MISSFormer"
|
| 49 |
+
params:
|
| 50 |
+
in_ch: 3
|
| 51 |
+
num_classes: 2
|
| 52 |
+
# preprocess:
|
configs/isic/isic2018_multiresunet.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 2
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 2
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 2
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0005
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 100
|
| 43 |
+
model:
|
| 44 |
+
save_dir: '../../saved_models/isic2018_multiresunet'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: 'MultiResUnet'
|
| 47 |
+
params:
|
| 48 |
+
channels: 3
|
| 49 |
+
filters: 32
|
| 50 |
+
nclasses: 2
|
| 51 |
+
# preprocess:
|
configs/isic/isic2018_resunet.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 100
|
| 43 |
+
model:
|
| 44 |
+
save_dir: '../../saved_models/isic2018_resunet'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: 'ResUnet'
|
| 47 |
+
params:
|
| 48 |
+
in_ch: 3
|
| 49 |
+
out_ch: 2
|
| 50 |
+
# preprocess:
|
configs/isic/isic2018_transunet.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'SGD'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
momentum: 0.9
|
| 37 |
+
weight_decay: 0.0001
|
| 38 |
+
criterion:
|
| 39 |
+
name: "DiceLoss"
|
| 40 |
+
params: {}
|
| 41 |
+
scheduler:
|
| 42 |
+
factor: 0.5
|
| 43 |
+
patience: 10
|
| 44 |
+
epochs: 100
|
| 45 |
+
model:
|
| 46 |
+
save_dir: '../../saved_models/isic2018_transunet'
|
| 47 |
+
load_weights: false
|
| 48 |
+
name: 'VisionTransformer'
|
| 49 |
+
params:
|
| 50 |
+
img_size: 224
|
| 51 |
+
num_classes: 2
|
| 52 |
+
# preprocess:
|
configs/isic/isic2018_uctransnet.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 100
|
| 43 |
+
model:
|
| 44 |
+
save_dir: '../../saved_models/isic2018_uctransnet'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: "UCTransNet"
|
| 47 |
+
params:
|
| 48 |
+
n_channels: 3
|
| 49 |
+
n_classes: 2
|
| 50 |
+
# preprocess:
|
configs/isic/isic2018_unet.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "./datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "./datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 0
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 0
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 0
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 2
|
| 43 |
+
model:
|
| 44 |
+
save_dir: './saved_models/isic2018_unet'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: 'UNet'
|
| 47 |
+
params:
|
| 48 |
+
in_channels: 3
|
| 49 |
+
out_channels: 2
|
| 50 |
+
with_bn: false
|
| 51 |
+
# preprocess:
|
configs/isic/isic2018_unetpp.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "ISIC2018Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
training:
|
| 9 |
+
params:
|
| 10 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 11 |
+
validation:
|
| 12 |
+
params:
|
| 13 |
+
data_dir: "/path/to/datasets/ISIC2018"
|
| 14 |
+
number_classes: 2
|
| 15 |
+
data_loader:
|
| 16 |
+
train:
|
| 17 |
+
batch_size: 16
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 8
|
| 20 |
+
pin_memory: true
|
| 21 |
+
validation:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
shuffle: false
|
| 24 |
+
num_workers: 8
|
| 25 |
+
pin_memory: true
|
| 26 |
+
test:
|
| 27 |
+
batch_size: 16
|
| 28 |
+
shuffle: false
|
| 29 |
+
num_workers: 4
|
| 30 |
+
pin_memory: false
|
| 31 |
+
training:
|
| 32 |
+
optimizer:
|
| 33 |
+
name: 'Adam'
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
criterion:
|
| 37 |
+
name: "DiceLoss"
|
| 38 |
+
params: {}
|
| 39 |
+
scheduler:
|
| 40 |
+
factor: 0.5
|
| 41 |
+
patience: 10
|
| 42 |
+
epochs: 100
|
| 43 |
+
model:
|
| 44 |
+
save_dir: '../../saved_models/isic2018_unetpp'
|
| 45 |
+
load_weights: false
|
| 46 |
+
name: 'NestedUNet'
|
| 47 |
+
params:
|
| 48 |
+
num_classes: 2
|
| 49 |
+
input_channels: 3
|
| 50 |
+
deep_supervision: false
|
| 51 |
+
# preprocess:
|
configs/segpc/segpc2021_attunet.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
criterion:
|
| 34 |
+
name: "DiceLoss"
|
| 35 |
+
params: {}
|
| 36 |
+
scheduler:
|
| 37 |
+
factor: 0.5
|
| 38 |
+
patience: 10
|
| 39 |
+
epochs: 100
|
| 40 |
+
model:
|
| 41 |
+
save_dir: '../../saved_models/segpc2021_attunet'
|
| 42 |
+
load_weights: false
|
| 43 |
+
name: 'AttU_Net'
|
| 44 |
+
params:
|
| 45 |
+
img_ch: 4
|
| 46 |
+
output_ch: 2
|
| 47 |
+
# preprocess:
|
configs/segpc/segpc2021_missformer.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'SGD'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
momentum: 0.9
|
| 34 |
+
weight_decay: 0.0001
|
| 35 |
+
criterion:
|
| 36 |
+
name: "DiceLoss"
|
| 37 |
+
params: {}
|
| 38 |
+
scheduler:
|
| 39 |
+
factor: 0.5
|
| 40 |
+
patience: 10
|
| 41 |
+
epochs: 500
|
| 42 |
+
model:
|
| 43 |
+
save_dir: '../../saved_models/segpc2021_missformer'
|
| 44 |
+
load_weights: false
|
| 45 |
+
name: 'MISSFormer'
|
| 46 |
+
params:
|
| 47 |
+
in_ch: 4
|
| 48 |
+
num_classes: 2
|
| 49 |
+
# preprocess:
|
configs/segpc/segpc2021_multiresunet.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
# name: "SGD"
|
| 34 |
+
# params:
|
| 35 |
+
# lr: 0.0001
|
| 36 |
+
# momentum: 0.9
|
| 37 |
+
# weight_decay: 0.0001
|
| 38 |
+
criterion:
|
| 39 |
+
name: "DiceLoss"
|
| 40 |
+
params: {}
|
| 41 |
+
scheduler:
|
| 42 |
+
factor: 0.5
|
| 43 |
+
patience: 10
|
| 44 |
+
epochs: 100
|
| 45 |
+
model:
|
| 46 |
+
save_dir: '../../saved_models/segpc2021_multiresunet'
|
| 47 |
+
load_weights: false
|
| 48 |
+
name: 'MultiResUnet'
|
| 49 |
+
params:
|
| 50 |
+
channels: 4
|
| 51 |
+
filters: 32
|
| 52 |
+
nclasses: 2
|
| 53 |
+
# preprocess:
|
configs/segpc/segpc2021_resunet.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
criterion:
|
| 34 |
+
name: "DiceLoss"
|
| 35 |
+
params: {}
|
| 36 |
+
scheduler:
|
| 37 |
+
factor: 0.5
|
| 38 |
+
patience: 10
|
| 39 |
+
epochs: 100
|
| 40 |
+
model:
|
| 41 |
+
save_dir: '../../saved_models/segpc2021_resunet'
|
| 42 |
+
load_weights: false
|
| 43 |
+
name: 'ResUnet'
|
| 44 |
+
params:
|
| 45 |
+
in_ch: 4
|
| 46 |
+
out_ch: 2
|
| 47 |
+
# preprocess:
|
configs/segpc/segpc2021_transunet.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
# name: 'Adam'
|
| 31 |
+
# params:
|
| 32 |
+
# lr: 0.0001
|
| 33 |
+
name: "SGD"
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.0001
|
| 36 |
+
momentum: 0.9
|
| 37 |
+
weight_decay: 0.0001
|
| 38 |
+
criterion:
|
| 39 |
+
name: "DiceLoss"
|
| 40 |
+
params: {}
|
| 41 |
+
scheduler:
|
| 42 |
+
factor: 0.5
|
| 43 |
+
patience: 10
|
| 44 |
+
epochs: 100
|
| 45 |
+
model:
|
| 46 |
+
save_dir: '../../saved_models/segpc2021_transunet'
|
| 47 |
+
load_weights: false
|
| 48 |
+
name: 'VisionTransformer'
|
| 49 |
+
params:
|
| 50 |
+
img_size: 224
|
| 51 |
+
num_classes: 2
|
| 52 |
+
# preprocess:
|
configs/segpc/segpc2021_uctransnet.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
criterion:
|
| 34 |
+
name: "DiceLoss"
|
| 35 |
+
params: {}
|
| 36 |
+
scheduler:
|
| 37 |
+
factor: 0.5
|
| 38 |
+
patience: 10
|
| 39 |
+
epochs: 100
|
| 40 |
+
model:
|
| 41 |
+
save_dir: '../../saved_models/segpc2021_uctransnet'
|
| 42 |
+
load_weights: false
|
| 43 |
+
name: 'UCTransNet'
|
| 44 |
+
params:
|
| 45 |
+
n_channels: 4
|
| 46 |
+
n_classes: 2
|
| 47 |
+
# preprocess:
|
configs/segpc/segpc2021_unet.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "./datasets/SegPC2021/np"
|
| 10 |
+
dataset_dir: "./datasets/SegPC2021/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 8
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 0
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 8
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 0
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 8
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 0
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
criterion:
|
| 34 |
+
name: "DiceLoss"
|
| 35 |
+
params: {}
|
| 36 |
+
scheduler:
|
| 37 |
+
factor: 0.5
|
| 38 |
+
patience: 10
|
| 39 |
+
epochs: 2
|
| 40 |
+
model:
|
| 41 |
+
save_dir: './saved_models/segpc2021_unet'
|
| 42 |
+
load_weights: false
|
| 43 |
+
name: 'UNet'
|
| 44 |
+
params:
|
| 45 |
+
in_channels: 4
|
| 46 |
+
out_channels: 2
|
| 47 |
+
with_bn: false
|
| 48 |
+
# preprocess:
|
configs/segpc/segpc2021_unetpp.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
mode: 'train'
|
| 3 |
+
device: 'gpu'
|
| 4 |
+
transforms: none
|
| 5 |
+
dataset:
|
| 6 |
+
class_name: "SegPC2021Dataset"
|
| 7 |
+
input_size: 224
|
| 8 |
+
scale: 2.5
|
| 9 |
+
data_dir: "/path/to/datasets/segpc/np"
|
| 10 |
+
dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
|
| 11 |
+
number_classes: 2
|
| 12 |
+
data_loader:
|
| 13 |
+
train:
|
| 14 |
+
batch_size: 16
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
pin_memory: true
|
| 18 |
+
validation:
|
| 19 |
+
batch_size: 16
|
| 20 |
+
shuffle: false
|
| 21 |
+
num_workers: 4
|
| 22 |
+
pin_memory: true
|
| 23 |
+
test:
|
| 24 |
+
batch_size: 16
|
| 25 |
+
shuffle: false
|
| 26 |
+
num_workers: 4
|
| 27 |
+
pin_memory: false
|
| 28 |
+
training:
|
| 29 |
+
optimizer:
|
| 30 |
+
name: 'Adam'
|
| 31 |
+
params:
|
| 32 |
+
lr: 0.0001
|
| 33 |
+
criterion:
|
| 34 |
+
name: "DiceLoss"
|
| 35 |
+
params: {}
|
| 36 |
+
scheduler:
|
| 37 |
+
factor: 0.5
|
| 38 |
+
patience: 10
|
| 39 |
+
epochs: 100
|
| 40 |
+
model:
|
| 41 |
+
save_dir: '../../saved_models/segpc2021_unetpp'
|
| 42 |
+
load_weights: false
|
| 43 |
+
name: 'NestedUNet'
|
| 44 |
+
params:
|
| 45 |
+
num_classes: 2
|
| 46 |
+
input_channels: 4
|
| 47 |
+
deep_supervision: false
|
| 48 |
+
# preprocess:
|
models/__init__.py
ADDED
|
File without changes
|
models/_missformer/MISSFormer.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .segformer import *
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
class PatchExpand(nn.Module):
|
| 8 |
+
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.input_resolution = input_resolution
|
| 11 |
+
self.dim = dim
|
| 12 |
+
self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
|
| 13 |
+
self.norm = norm_layer(dim // dim_scale)
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
"""
|
| 17 |
+
x: B, H*W, C
|
| 18 |
+
"""
|
| 19 |
+
# print("x_shape-----",x.shape)
|
| 20 |
+
H, W = self.input_resolution
|
| 21 |
+
x = self.expand(x)
|
| 22 |
+
|
| 23 |
+
B, L, C = x.shape
|
| 24 |
+
# print(x.shape)
|
| 25 |
+
assert L == H * W, "input feature has wrong size"
|
| 26 |
+
|
| 27 |
+
x = x.view(B, H, W, C)
|
| 28 |
+
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
|
| 29 |
+
x = x.view(B,-1,C//4)
|
| 30 |
+
x= self.norm(x.clone())
|
| 31 |
+
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
class FinalPatchExpand_X4(nn.Module):
|
| 35 |
+
def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.input_resolution = input_resolution
|
| 38 |
+
self.dim = dim
|
| 39 |
+
self.dim_scale = dim_scale
|
| 40 |
+
self.expand = nn.Linear(dim, 16*dim, bias=False)
|
| 41 |
+
self.output_dim = dim
|
| 42 |
+
self.norm = norm_layer(self.output_dim)
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
"""
|
| 46 |
+
x: B, H*W, C
|
| 47 |
+
"""
|
| 48 |
+
H, W = self.input_resolution
|
| 49 |
+
x = self.expand(x)
|
| 50 |
+
B, L, C = x.shape
|
| 51 |
+
assert L == H * W, "input feature has wrong size"
|
| 52 |
+
|
| 53 |
+
x = x.view(B, H, W, C)
|
| 54 |
+
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
|
| 55 |
+
x = x.view(B,-1,self.output_dim)
|
| 56 |
+
x= self.norm(x.clone())
|
| 57 |
+
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SegU_decoder(nn.Module):
|
| 62 |
+
def __init__(self, input_size, in_out_chan, heads, reduction_ratios, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
|
| 63 |
+
super().__init__()
|
| 64 |
+
dims = in_out_chan[0]
|
| 65 |
+
out_dim = in_out_chan[1]
|
| 66 |
+
if not is_last:
|
| 67 |
+
self.concat_linear = nn.Linear(dims*2, out_dim)
|
| 68 |
+
# transformer decoder
|
| 69 |
+
self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
|
| 70 |
+
self.last_layer = None
|
| 71 |
+
else:
|
| 72 |
+
self.concat_linear = nn.Linear(dims*4, out_dim)
|
| 73 |
+
# transformer decoder
|
| 74 |
+
self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
|
| 75 |
+
# self.last_layer = nn.Linear(out_dim, n_class)
|
| 76 |
+
self.last_layer = nn.Conv2d(out_dim, n_class,1)
|
| 77 |
+
# self.last_layer = None
|
| 78 |
+
|
| 79 |
+
self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios)
|
| 80 |
+
self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def init_weights(self):
|
| 84 |
+
for m in self.modules():
|
| 85 |
+
if isinstance(m, nn.Linear):
|
| 86 |
+
nn.init.xavier_uniform_(m.weight)
|
| 87 |
+
if m.bias is not None:
|
| 88 |
+
nn.init.zeros_(m.bias)
|
| 89 |
+
elif isinstance(m, nn.LayerNorm):
|
| 90 |
+
nn.init.ones_(m.weight)
|
| 91 |
+
nn.init.zeros_(m.bias)
|
| 92 |
+
elif isinstance(m, nn.Conv2d):
|
| 93 |
+
nn.init.xavier_uniform_(m.weight)
|
| 94 |
+
if m.bias is not None:
|
| 95 |
+
nn.init.zeros_(m.bias)
|
| 96 |
+
|
| 97 |
+
init_weights(self)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def forward(self, x1, x2=None):
|
| 102 |
+
if x2 is not None:
|
| 103 |
+
b, h, w, c = x2.shape
|
| 104 |
+
x2 = x2.view(b, -1, c)
|
| 105 |
+
# print("------",x1.shape, x2.shape)
|
| 106 |
+
cat_x = torch.cat([x1, x2], dim=-1)
|
| 107 |
+
# print("-----catx shape", cat_x.shape)
|
| 108 |
+
cat_linear_x = self.concat_linear(cat_x)
|
| 109 |
+
tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
|
| 110 |
+
tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)
|
| 111 |
+
|
| 112 |
+
if self.last_layer:
|
| 113 |
+
out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
|
| 114 |
+
else:
|
| 115 |
+
out = self.layer_up(tran_layer_2)
|
| 116 |
+
else:
|
| 117 |
+
# if len(x1.shape)>3:
|
| 118 |
+
# x1 = x1.permute(0,2,3,1)
|
| 119 |
+
# b, h, w, c = x1.shape
|
| 120 |
+
# x1 = x1.view(b, -1, c)
|
| 121 |
+
out = self.layer_up(x1)
|
| 122 |
+
return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class BridgeLayer_4(nn.Module):
|
| 126 |
+
def __init__(self, dims, head, reduction_ratios):
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
self.norm1 = nn.LayerNorm(dims)
|
| 130 |
+
self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
|
| 131 |
+
self.norm2 = nn.LayerNorm(dims)
|
| 132 |
+
self.mixffn1 = MixFFN_skip(dims,dims*4)
|
| 133 |
+
self.mixffn2 = MixFFN_skip(dims*2,dims*8)
|
| 134 |
+
self.mixffn3 = MixFFN_skip(dims*5,dims*20)
|
| 135 |
+
self.mixffn4 = MixFFN_skip(dims*8,dims*32)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def forward(self, inputs):
|
| 139 |
+
B = inputs[0].shape[0]
|
| 140 |
+
C = 64
|
| 141 |
+
if (type(inputs) == list):
|
| 142 |
+
# print("-----1-----")
|
| 143 |
+
c1, c2, c3, c4 = inputs
|
| 144 |
+
B, C, _, _= c1.shape
|
| 145 |
+
c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
|
| 146 |
+
c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
|
| 147 |
+
c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
|
| 148 |
+
c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64
|
| 149 |
+
|
| 150 |
+
# print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
|
| 151 |
+
inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
|
| 152 |
+
else:
|
| 153 |
+
B,_,C = inputs.shape
|
| 154 |
+
|
| 155 |
+
tx1 = inputs + self.attn(self.norm1(inputs))
|
| 156 |
+
tx = self.norm2(tx1)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
tem1 = tx[:,:3136,:].reshape(B, -1, C)
|
| 160 |
+
tem2 = tx[:,3136:4704,:].reshape(B, -1, C*2)
|
| 161 |
+
tem3 = tx[:,4704:5684,:].reshape(B, -1, C*5)
|
| 162 |
+
tem4 = tx[:,5684:6076,:].reshape(B, -1, C*8)
|
| 163 |
+
|
| 164 |
+
m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
|
| 165 |
+
m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
|
| 166 |
+
m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
|
| 167 |
+
m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)
|
| 168 |
+
|
| 169 |
+
t1 = torch.cat([m1f, m2f, m3f, m4f], -2)
|
| 170 |
+
|
| 171 |
+
tx2 = tx1 + t1
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
return tx2
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class BridgeLayer_3(nn.Module):
|
| 178 |
+
def __init__(self, dims, head, reduction_ratios):
|
| 179 |
+
super().__init__()
|
| 180 |
+
|
| 181 |
+
self.norm1 = nn.LayerNorm(dims)
|
| 182 |
+
self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
|
| 183 |
+
self.norm2 = nn.LayerNorm(dims)
|
| 184 |
+
# self.mixffn1 = MixFFN(dims,dims*4)
|
| 185 |
+
self.mixffn2 = MixFFN(dims*2,dims*8)
|
| 186 |
+
self.mixffn3 = MixFFN(dims*5,dims*20)
|
| 187 |
+
self.mixffn4 = MixFFN(dims*8,dims*32)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
| 191 |
+
B = inputs[0].shape[0]
|
| 192 |
+
C = 64
|
| 193 |
+
if (type(inputs) == list):
|
| 194 |
+
# print("-----1-----")
|
| 195 |
+
c1, c2, c3, c4 = inputs
|
| 196 |
+
B, C, _, _= c1.shape
|
| 197 |
+
c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
|
| 198 |
+
c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
|
| 199 |
+
c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
|
| 200 |
+
c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64
|
| 201 |
+
|
| 202 |
+
# print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
|
| 203 |
+
inputs = torch.cat([c2f, c3f, c4f], -2)
|
| 204 |
+
else:
|
| 205 |
+
B,_,C = inputs.shape
|
| 206 |
+
|
| 207 |
+
tx1 = inputs + self.attn(self.norm1(inputs))
|
| 208 |
+
tx = self.norm2(tx1)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# tem1 = tx[:,:3136,:].reshape(B, -1, C)
|
| 212 |
+
tem2 = tx[:,:1568,:].reshape(B, -1, C*2)
|
| 213 |
+
tem3 = tx[:,1568:2548,:].reshape(B, -1, C*5)
|
| 214 |
+
tem4 = tx[:,2548:2940,:].reshape(B, -1, C*8)
|
| 215 |
+
|
| 216 |
+
# m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
|
| 217 |
+
m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
|
| 218 |
+
m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
|
| 219 |
+
m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)
|
| 220 |
+
|
| 221 |
+
t1 = torch.cat([m2f, m3f, m4f], -2)
|
| 222 |
+
|
| 223 |
+
tx2 = tx1 + t1
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
return tx2
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class BridegeBlock_4(nn.Module):
|
| 231 |
+
def __init__(self, dims, head, reduction_ratios):
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios)
|
| 234 |
+
self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios)
|
| 235 |
+
self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios)
|
| 236 |
+
self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios)
|
| 237 |
+
|
| 238 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
bridge1 = self.bridge_layer1(x)
|
| 240 |
+
bridge2 = self.bridge_layer2(bridge1)
|
| 241 |
+
bridge3 = self.bridge_layer3(bridge2)
|
| 242 |
+
bridge4 = self.bridge_layer4(bridge3)
|
| 243 |
+
|
| 244 |
+
B,_,C = bridge4.shape
|
| 245 |
+
outs = []
|
| 246 |
+
|
| 247 |
+
sk1 = bridge4[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2)
|
| 248 |
+
sk2 = bridge4[:,3136:4704,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
|
| 249 |
+
sk3 = bridge4[:,4704:5684,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
|
| 250 |
+
sk4 = bridge4[:,5684:6076,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)
|
| 251 |
+
|
| 252 |
+
outs.append(sk1)
|
| 253 |
+
outs.append(sk2)
|
| 254 |
+
outs.append(sk3)
|
| 255 |
+
outs.append(sk4)
|
| 256 |
+
|
| 257 |
+
return outs
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class BridegeBlock_3(nn.Module):
|
| 261 |
+
def __init__(self, dims, head, reduction_ratios):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.bridge_layer1 = BridgeLayer_3(dims, head, reduction_ratios)
|
| 264 |
+
self.bridge_layer2 = BridgeLayer_3(dims, head, reduction_ratios)
|
| 265 |
+
self.bridge_layer3 = BridgeLayer_3(dims, head, reduction_ratios)
|
| 266 |
+
self.bridge_layer4 = BridgeLayer_3(dims, head, reduction_ratios)
|
| 267 |
+
|
| 268 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 269 |
+
outs = []
|
| 270 |
+
if (type(x) == list):
|
| 271 |
+
# print("-----1-----")
|
| 272 |
+
outs.append(x[0])
|
| 273 |
+
bridge1 = self.bridge_layer1(x)
|
| 274 |
+
bridge2 = self.bridge_layer2(bridge1)
|
| 275 |
+
bridge3 = self.bridge_layer3(bridge2)
|
| 276 |
+
bridge4 = self.bridge_layer4(bridge3)
|
| 277 |
+
|
| 278 |
+
B,_,C = bridge4.shape
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# sk1 = bridge2[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2)
|
| 282 |
+
sk2 = bridge4[:,:1568,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
|
| 283 |
+
sk3 = bridge4[:,1568:2548,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
|
| 284 |
+
sk4 = bridge4[:,2548:2940,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)
|
| 285 |
+
|
| 286 |
+
# outs.append(sk1)
|
| 287 |
+
outs.append(sk2)
|
| 288 |
+
outs.append(sk3)
|
| 289 |
+
outs.append(sk4)
|
| 290 |
+
|
| 291 |
+
return outs
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class MyDecoderLayer(nn.Module):
|
| 295 |
+
def __init__(self, input_size, in_out_chan, heads, reduction_ratios,token_mlp_mode, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
|
| 296 |
+
super().__init__()
|
| 297 |
+
dims = in_out_chan[0]
|
| 298 |
+
out_dim = in_out_chan[1]
|
| 299 |
+
if not is_last:
|
| 300 |
+
self.concat_linear = nn.Linear(dims*2, out_dim)
|
| 301 |
+
# transformer decoder
|
| 302 |
+
self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
|
| 303 |
+
self.last_layer = None
|
| 304 |
+
else:
|
| 305 |
+
self.concat_linear = nn.Linear(dims*4, out_dim)
|
| 306 |
+
# transformer decoder
|
| 307 |
+
self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
|
| 308 |
+
# self.last_layer = nn.Linear(out_dim, n_class)
|
| 309 |
+
self.last_layer = nn.Conv2d(out_dim, n_class,1)
|
| 310 |
+
# self.last_layer = None
|
| 311 |
+
|
| 312 |
+
self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
|
| 313 |
+
self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def init_weights(self):
|
| 317 |
+
for m in self.modules():
|
| 318 |
+
if isinstance(m, nn.Linear):
|
| 319 |
+
nn.init.xavier_uniform_(m.weight)
|
| 320 |
+
if m.bias is not None:
|
| 321 |
+
nn.init.zeros_(m.bias)
|
| 322 |
+
elif isinstance(m, nn.LayerNorm):
|
| 323 |
+
nn.init.ones_(m.weight)
|
| 324 |
+
nn.init.zeros_(m.bias)
|
| 325 |
+
elif isinstance(m, nn.Conv2d):
|
| 326 |
+
nn.init.xavier_uniform_(m.weight)
|
| 327 |
+
if m.bias is not None:
|
| 328 |
+
nn.init.zeros_(m.bias)
|
| 329 |
+
|
| 330 |
+
init_weights(self)
|
| 331 |
+
|
| 332 |
+
def forward(self, x1, x2=None):
|
| 333 |
+
if x2 is not None:
|
| 334 |
+
b, h, w, c = x2.shape
|
| 335 |
+
x2 = x2.view(b, -1, c)
|
| 336 |
+
# print("------",x1.shape, x2.shape)
|
| 337 |
+
cat_x = torch.cat([x1, x2], dim=-1)
|
| 338 |
+
# print("-----catx shape", cat_x.shape)
|
| 339 |
+
cat_linear_x = self.concat_linear(cat_x)
|
| 340 |
+
tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
|
| 341 |
+
tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)
|
| 342 |
+
|
| 343 |
+
if self.last_layer:
|
| 344 |
+
out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
|
| 345 |
+
else:
|
| 346 |
+
out = self.layer_up(tran_layer_2)
|
| 347 |
+
else:
|
| 348 |
+
# if len(x1.shape)>3:
|
| 349 |
+
# x1 = x1.permute(0,2,3,1)
|
| 350 |
+
# b, h, w, c = x1.shape
|
| 351 |
+
# x1 = x1.view(b, -1, c)
|
| 352 |
+
out = self.layer_up(x1)
|
| 353 |
+
return out
|
| 354 |
+
|
| 355 |
+
class MISSFormer(nn.Module):
|
| 356 |
+
def __init__(self, num_classes=9, in_ch=3, token_mlp_mode="mix_skip", encoder_pretrained=True):
|
| 357 |
+
super().__init__()
|
| 358 |
+
|
| 359 |
+
reduction_ratios = [8, 4, 2, 1]
|
| 360 |
+
heads = [1, 2, 5, 8]
|
| 361 |
+
d_base_feat_size = 7 #16 for 512 inputsize 7for 224
|
| 362 |
+
in_out_chan = [[32, 64],[144, 128],[288, 320],[512, 512]]
|
| 363 |
+
|
| 364 |
+
dims, layers = [[64, 128, 320, 512], [2, 2, 2, 2]]
|
| 365 |
+
self.backbone = MiT(224, dims, layers,in_ch, token_mlp_mode)
|
| 366 |
+
|
| 367 |
+
self.reduction_ratios = [1, 2, 4, 8]
|
| 368 |
+
self.bridge = BridegeBlock_4(64, 1, self.reduction_ratios)
|
| 369 |
+
|
| 370 |
+
self.decoder_3= MyDecoderLayer((d_base_feat_size,d_base_feat_size), in_out_chan[3], heads[3], reduction_ratios[3],token_mlp_mode, n_class=num_classes)
|
| 371 |
+
self.decoder_2= MyDecoderLayer((d_base_feat_size*2,d_base_feat_size*2),in_out_chan[2], heads[2], reduction_ratios[2], token_mlp_mode, n_class=num_classes)
|
| 372 |
+
self.decoder_1= MyDecoderLayer((d_base_feat_size*4,d_base_feat_size*4), in_out_chan[1], heads[1], reduction_ratios[1], token_mlp_mode, n_class=num_classes)
|
| 373 |
+
self.decoder_0= MyDecoderLayer((d_base_feat_size*8,d_base_feat_size*8), in_out_chan[0], heads[0], reduction_ratios[0], token_mlp_mode, n_class=num_classes, is_last=True)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def forward(self, x):
|
| 377 |
+
#---------------Encoder-------------------------
|
| 378 |
+
if x.size()[1] == 1:
|
| 379 |
+
x = x.repeat(1,3,1,1)
|
| 380 |
+
|
| 381 |
+
encoder = self.backbone(x)
|
| 382 |
+
bridge = self.bridge(encoder) #list
|
| 383 |
+
|
| 384 |
+
b,c,_,_ = bridge[3].shape
|
| 385 |
+
# print(bridge[3].shape, bridge[2].shape,bridge[1].shape, bridge[0].shape)
|
| 386 |
+
#---------------Decoder-------------------------
|
| 387 |
+
# print("stage3-----")
|
| 388 |
+
tmp_3 = self.decoder_3(bridge[3].permute(0,2,3,1).view(b,-1,c))
|
| 389 |
+
# print("stage2-----")
|
| 390 |
+
tmp_2 = self.decoder_2(tmp_3, bridge[2].permute(0,2,3,1))
|
| 391 |
+
# print("stage1-----")
|
| 392 |
+
tmp_1 = self.decoder_1(tmp_2, bridge[1].permute(0,2,3,1))
|
| 393 |
+
# print("stage0-----")
|
| 394 |
+
tmp_0 = self.decoder_0(tmp_1, bridge[0].permute(0,2,3,1))
|
| 395 |
+
|
| 396 |
+
return tmp_0
|
| 397 |
+
|
| 398 |
+
|
models/_missformer/__init__.py
ADDED
|
File without changes
|
models/_missformer/segformer.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EfficientSelfAtten(nn.Module):
|
| 8 |
+
def __init__(self, dim, head, reduction_ratio):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.head = head
|
| 11 |
+
self.reduction_ratio = reduction_ratio
|
| 12 |
+
self.scale = (dim // head) ** -0.5
|
| 13 |
+
self.q = nn.Linear(dim, dim, bias=True)
|
| 14 |
+
self.kv = nn.Linear(dim, dim*2, bias=True)
|
| 15 |
+
self.proj = nn.Linear(dim, dim)
|
| 16 |
+
|
| 17 |
+
if reduction_ratio > 1:
|
| 18 |
+
self.sr = nn.Conv2d(dim, dim, reduction_ratio, reduction_ratio)
|
| 19 |
+
self.norm = nn.LayerNorm(dim)
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
|
| 22 |
+
B, N, C = x.shape
|
| 23 |
+
q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
|
| 24 |
+
|
| 25 |
+
if self.reduction_ratio > 1:
|
| 26 |
+
p_x = x.clone().permute(0, 2, 1).reshape(B, C, H, W)
|
| 27 |
+
sp_x = self.sr(p_x).reshape(B, C, -1).permute(0, 2, 1)
|
| 28 |
+
x = self.norm(sp_x)
|
| 29 |
+
|
| 30 |
+
kv = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
|
| 31 |
+
k, v = kv[0], kv[1]
|
| 32 |
+
|
| 33 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 34 |
+
attn_score = attn.softmax(dim=-1)
|
| 35 |
+
|
| 36 |
+
x_atten = (attn_score @ v).transpose(1, 2).reshape(B, N, C)
|
| 37 |
+
out = self.proj(x_atten)
|
| 38 |
+
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SelfAtten(nn.Module):
|
| 43 |
+
def __init__(self, dim, head):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.head = head
|
| 46 |
+
self.scale = (dim // head) ** -0.5
|
| 47 |
+
self.q = nn.Linear(dim, dim, bias=True)
|
| 48 |
+
self.kv = nn.Linear(dim, dim*2, bias=True)
|
| 49 |
+
self.proj = nn.Linear(dim, dim)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
B, N, C = x.shape
|
| 54 |
+
q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
|
| 55 |
+
|
| 56 |
+
kv = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
|
| 57 |
+
k, v = kv[0], kv[1]
|
| 58 |
+
|
| 59 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 60 |
+
attn_score = attn.softmax(dim=-1)
|
| 61 |
+
|
| 62 |
+
x_atten = (attn_score @ v).transpose(1, 2).reshape(B, N, C)
|
| 63 |
+
out = self.proj(x_atten)
|
| 64 |
+
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
class Scale_reduce(nn.Module):
|
| 68 |
+
def __init__(self, dim, reduction_ratio):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.dim = dim
|
| 71 |
+
self.reduction_ratio = reduction_ratio
|
| 72 |
+
if(len(self.reduction_ratio)==4):
|
| 73 |
+
self.sr0 = nn.Conv2d(dim, dim, reduction_ratio[3], reduction_ratio[3])
|
| 74 |
+
self.sr1 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
|
| 75 |
+
self.sr2 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])
|
| 76 |
+
|
| 77 |
+
elif(len(self.reduction_ratio)==3):
|
| 78 |
+
self.sr0 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
|
| 79 |
+
self.sr1 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])
|
| 80 |
+
|
| 81 |
+
self.norm = nn.LayerNorm(dim)
|
| 82 |
+
|
| 83 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
B, N, C = x.shape
|
| 85 |
+
if(len(self.reduction_ratio)==4):
|
| 86 |
+
tem0 = x[:,:3136,:].reshape(B, 56, 56, C).permute(0, 3, 1, 2)
|
| 87 |
+
tem1 = x[:,3136:4704,:].reshape(B, 28, 28, C*2).permute(0, 3, 1, 2)
|
| 88 |
+
tem2 = x[:,4704:5684,:].reshape(B, 14, 14, C*5).permute(0, 3, 1, 2)
|
| 89 |
+
tem3 = x[:,5684:6076,:]
|
| 90 |
+
|
| 91 |
+
sr_0 = self.sr0(tem0).reshape(B, C, -1).permute(0, 2, 1)
|
| 92 |
+
sr_1 = self.sr1(tem1).reshape(B, C, -1).permute(0, 2, 1)
|
| 93 |
+
sr_2 = self.sr2(tem2).reshape(B, C, -1).permute(0, 2, 1)
|
| 94 |
+
|
| 95 |
+
reduce_out = self.norm(torch.cat([sr_0, sr_1, sr_2, tem3], -2))
|
| 96 |
+
|
| 97 |
+
if(len(self.reduction_ratio)==3):
|
| 98 |
+
tem0 = x[:,:1568,:].reshape(B, 28, 28, C*2).permute(0, 3, 1, 2)
|
| 99 |
+
tem1 = x[:,1568:2548,:].reshape(B, 14, 14, C*5).permute(0, 3, 1, 2)
|
| 100 |
+
tem2 = x[:,2548:2940,:]
|
| 101 |
+
|
| 102 |
+
sr_0 = self.sr0(tem0).reshape(B, C, -1).permute(0, 2, 1)
|
| 103 |
+
sr_1 = self.sr1(tem1).reshape(B, C, -1).permute(0, 2, 1)
|
| 104 |
+
|
| 105 |
+
reduce_out = self.norm(torch.cat([sr_0, sr_1, tem2], -2))
|
| 106 |
+
|
| 107 |
+
return reduce_out
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class M_EfficientSelfAtten(nn.Module):
|
| 114 |
+
def __init__(self, dim, head, reduction_ratio):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.head = head
|
| 117 |
+
self.reduction_ratio = reduction_ratio # list[1 2 4 8]
|
| 118 |
+
self.scale = (dim // head) ** -0.5
|
| 119 |
+
self.q = nn.Linear(dim, dim, bias=True)
|
| 120 |
+
self.kv = nn.Linear(dim, dim*2, bias=True)
|
| 121 |
+
self.proj = nn.Linear(dim, dim)
|
| 122 |
+
|
| 123 |
+
if reduction_ratio is not None:
|
| 124 |
+
self.scale_reduce = Scale_reduce(dim,reduction_ratio)
|
| 125 |
+
|
| 126 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 127 |
+
B, N, C = x.shape
|
| 128 |
+
q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
|
| 129 |
+
|
| 130 |
+
if self.reduction_ratio is not None:
|
| 131 |
+
x = self.scale_reduce(x)
|
| 132 |
+
|
| 133 |
+
kv = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
|
| 134 |
+
k, v = kv[0], kv[1]
|
| 135 |
+
|
| 136 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 137 |
+
attn_score = attn.softmax(dim=-1)
|
| 138 |
+
|
| 139 |
+
x_atten = (attn_score @ v).transpose(1, 2).reshape(B, N, C)
|
| 140 |
+
out = self.proj(x_atten)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
return out
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class LocalEnhance_EfficientSelfAtten(nn.Module):
|
| 147 |
+
def __init__(self, dim, head, reduction_ratio):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.head = head
|
| 150 |
+
self.reduction_ratio = reduction_ratio
|
| 151 |
+
self.scale = (dim // head) ** -0.5
|
| 152 |
+
self.q = nn.Linear(dim, dim, bias=True)
|
| 153 |
+
self.kv = nn.Linear(dim, dim*2, bias=True)
|
| 154 |
+
self.proj = nn.Linear(dim, dim)
|
| 155 |
+
self.local_pos = DWConv(dim)
|
| 156 |
+
|
| 157 |
+
if reduction_ratio > 1:
|
| 158 |
+
self.sr = nn.Conv2d(dim, dim, reduction_ratio, reduction_ratio)
|
| 159 |
+
self.norm = nn.LayerNorm(dim)
|
| 160 |
+
|
| 161 |
+
def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
|
| 162 |
+
B, N, C = x.shape
|
| 163 |
+
q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
|
| 164 |
+
|
| 165 |
+
if self.reduction_ratio > 1:
|
| 166 |
+
p_x = x.clone().permute(0, 2, 1).reshape(B, C, H, W)
|
| 167 |
+
sp_x = self.sr(p_x).reshape(B, C, -1).permute(0, 2, 1)
|
| 168 |
+
x = self.norm(sp_x)
|
| 169 |
+
|
| 170 |
+
kv = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
|
| 171 |
+
k, v = kv[0], kv[1]
|
| 172 |
+
|
| 173 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 174 |
+
attn_score = attn.softmax(dim=-1)
|
| 175 |
+
local_v = v.permute(0, 2, 1, 3).reshape(B, N, C)
|
| 176 |
+
local_pos = self.local_pos(local_v).reshape(B, -1, self.head, C//self.head).permute(0, 2, 1, 3)
|
| 177 |
+
x_atten = ((attn_score @ v) + local_pos).transpose(1, 2).reshape(B, N, C)
|
| 178 |
+
out = self.proj(x_atten)
|
| 179 |
+
|
| 180 |
+
return out
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a flattened token sequence."""

    def __init__(self, dim):
        super().__init__()
        # groups=dim makes the 3x3 convolution depth-wise: one filter per channel.
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        """Reshape (B, N, C) tokens to a (B, C, H, W) map, convolve, flatten back."""
        B, N, C = x.shape
        as_map = x.transpose(1, 2).view(B, C, H, W)
        return self.dwconv(as_map).flatten(2).transpose(1, 2)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class MixFFN(nn.Module):
    """SegFormer feed-forward block: Linear -> depth-wise conv -> GELU -> Linear."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = self.dwconv(hidden, H, W)
        return self.fc2(self.act(hidden))
|
| 207 |
+
|
| 208 |
+
class MixFFN_skip(nn.Module):
    """MixFFN variant with a skip connection around the depth-wise conv.

    Computes fc2(GELU(LayerNorm(DWConv(fc1(x)) + fc1(x)))).
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        self.norm1 = nn.LayerNorm(c2)
        # norm2/norm3 are never used in forward; kept so existing checkpoints
        # (whose state_dicts contain their parameters) still load strictly.
        self.norm2 = nn.LayerNorm(c2)
        self.norm3 = nn.LayerNorm(c2)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        # Fix: the original evaluated self.fc1(x) twice per call; hoisting it
        # halves the fc1 work with identical results.
        hidden = self.fc1(x)
        ax = self.act(self.norm1(self.dwconv(hidden, H, W) + hidden))
        return self.fc2(ax)
|
| 222 |
+
|
| 223 |
+
class MLP_FFN(nn.Module):
    """Plain two-layer MLP with a GELU in between (no depth-wise conv)."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
|
| 235 |
+
|
| 236 |
+
class MixD_FFN(nn.Module):
    """MixFFN variant that fuses the conv branch with the linear branch.

    fuse_mode "add" sums the two branches; any other value concatenates them
    on the channel axis (fc2 is sized accordingly in __init__).
    """

    def __init__(self, c1, c2, fuse_mode="add"):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1) if fuse_mode == "add" else nn.Linear(c2*2, c1)
        self.fuse_mode = fuse_mode

    def forward(self, x, H, W):
        # Fixes two bugs in the original:
        # 1) forward(self, x) referenced undefined H/W (NameError on every
        #    call) even though FuseTransformerBlock invokes mlp(x, H, W);
        #    the signature now accepts them.
        # 2) it returned self.fc2(ax), discarding the fused activation it had
        #    just computed; the fused tensor is now fed to fc2.
        hidden = self.fc1(x)  # hoisted: was computed twice per call
        ax = self.dwconv(hidden, H, W)
        if self.fuse_mode == "add":
            fuse = self.act(ax + hidden)
        else:
            fuse = self.act(torch.cat([ax, hidden], 2))
        return self.fc2(fuse)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class OverlapPatchEmbeddings(nn.Module):
|
| 253 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, padding=1, in_ch=3, dim=768):
|
| 254 |
+
super().__init__()
|
| 255 |
+
self.num_patches = (img_size // patch_size) ** 2
|
| 256 |
+
self.proj = nn.Conv2d(in_ch, dim, patch_size, stride, padding)
|
| 257 |
+
self.norm = nn.LayerNorm(dim)
|
| 258 |
+
|
| 259 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 260 |
+
px = self.proj(x)
|
| 261 |
+
_, _, H, W = px.shape
|
| 262 |
+
fx = px.flatten(2).transpose(1, 2)
|
| 263 |
+
nfx = self.norm(fx)
|
| 264 |
+
return nfx, H, W
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: efficient self-attention + token MLP."""

    def __init__(self, dim, head, reduction_ratio=1, token_mlp='mix'):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAtten(dim, head, reduction_ratio)
        self.norm2 = nn.LayerNorm(dim)
        # Select the feed-forward flavour; hidden width is 4x dim.
        if token_mlp == 'mix':
            self.mlp = MixFFN(dim, int(dim*4))
        elif token_mlp == 'mix_skip':
            self.mlp = MixFFN_skip(dim, int(dim*4))
        else:
            self.mlp = MLP_FFN(dim, int(dim*4))

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.mlp(self.norm2(x), H, W)
        return x
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class FuseTransformerBlock(nn.Module):
    """Pre-norm transformer block whose MLP is the fusing MixD_FFN."""

    def __init__(self, dim, head, reduction_ratio=1, fuse_mode="add"):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAtten(dim, head, reduction_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MixD_FFN(dim, int(dim*4), fuse_mode)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.mlp(self.norm2(x), H, W)
        return x
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class MLP(nn.Module):
    """Flattens a (B, C, H, W) map to tokens and projects C -> embed_dim."""

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        return self.proj(tokens)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class ConvModule(nn.Module):
    """Conv2d (no bias) -> BatchNorm -> ReLU."""

    def __init__(self, c1, c2, k):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, bias=False)  # bias is absorbed by BN
        self.bn = nn.BatchNorm2d(c2)
        self.activate = nn.ReLU(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        out = self.bn(out)
        return self.activate(out)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class MiT(nn.Module):
    """Mix Transformer (SegFormer) encoder producing a 4-level feature pyramid.

    Stage i halves the resolution of stage i-1; outputs are feature maps at
    1/4, 1/8, 1/16 and 1/32 of the input resolution.
    """

    def __init__(self, image_size, dims, layers, in_ch=3, token_mlp='mix_skip'):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]

        # Patch embeddings, one per stage.
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], in_ch, dims[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size//4, patch_sizes[1], strides[1], padding_sizes[1], dims[0], dims[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size//8, patch_sizes[2], strides[2], padding_sizes[2], dims[1], dims[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size//16, patch_sizes[3], strides[3], padding_sizes[3], dims[2], dims[3])

        # Transformer stages, each followed by a LayerNorm.
        self.block1 = nn.ModuleList([
            TransformerBlock(dims[0], heads[0], reduction_ratios[0], token_mlp)
            for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(dims[0])

        self.block2 = nn.ModuleList([
            TransformerBlock(dims[1], heads[1], reduction_ratios[1], token_mlp)
            for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(dims[1])

        self.block3 = nn.ModuleList([
            TransformerBlock(dims[2], heads[2], reduction_ratios[2], token_mlp)
            for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(dims[2])

        self.block4 = nn.ModuleList([
            TransformerBlock(dims[3], heads[3], reduction_ratios[3], token_mlp)
            for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(dims[3])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        outs = []

        # Each stage: embed -> transformer blocks -> norm -> back to a map.
        for i in (1, 2, 3, 4):
            x, H, W = getattr(self, f"patch_embed{i}")(x)
            for blk in getattr(self, f"block{i}"):
                x = blk(x, H, W)
            x = getattr(self, f"norm{i}")(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class FuseMiT(nn.Module):
    """MiT variant whose stages use FuseTransformerBlock (MixD_FFN fusion).

    Input is fixed at 3 channels; outputs are the 4-level feature pyramid.
    """

    def __init__(self, image_size, dims, layers, fuse_mode='add'):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]

        # Patch embeddings, one per stage.
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3, dims[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size//4, patch_sizes[1], strides[1], padding_sizes[1], dims[0], dims[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size//8, patch_sizes[2], strides[2], padding_sizes[2], dims[1], dims[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size//16, patch_sizes[3], strides[3], padding_sizes[3], dims[2], dims[3])

        # Transformer stages, each followed by a LayerNorm.
        self.block1 = nn.ModuleList([
            FuseTransformerBlock(dims[0], heads[0], reduction_ratios[0], fuse_mode)
            for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(dims[0])

        self.block2 = nn.ModuleList([
            FuseTransformerBlock(dims[1], heads[1], reduction_ratios[1], fuse_mode)
            for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(dims[1])

        self.block3 = nn.ModuleList([
            FuseTransformerBlock(dims[2], heads[2], reduction_ratios[2], fuse_mode)
            for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(dims[2])

        self.block4 = nn.ModuleList([
            FuseTransformerBlock(dims[3], heads[3], reduction_ratios[3], fuse_mode)
            for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(dims[3])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        outs = []

        # Each stage: embed -> transformer blocks -> norm -> back to a map.
        for i in (1, 2, 3, 4):
            x, H, W = getattr(self, f"patch_embed{i}")(x)
            for blk in getattr(self, f"block{i}"):
                x = blk(x, H, W)
            x = getattr(self, f"norm{i}")(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class Decoder(nn.Module):
    """SegFormer all-MLP decoder: project each scale, upsample, fuse, predict."""

    def __init__(self, dims, embed_dim, num_classes):
        super().__init__()

        # Per-scale linear projections to a common embedding width.
        self.linear_c1 = MLP(dims[0], embed_dim)
        self.linear_c2 = MLP(dims[1], embed_dim)
        self.linear_c3 = MLP(dims[2], embed_dim)
        self.linear_c4 = MLP(dims[3], embed_dim)

        self.linear_fuse = ConvModule(embed_dim*4, embed_dim, 1)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)

        # NOTE(review): conv_seg is never used in forward; kept only for
        # checkpoint/state_dict compatibility.
        self.conv_seg = nn.Conv2d(128, num_classes, 1)

        self.dropout = nn.Dropout2d(0.1)

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        c1, c2, c3, c4 = inputs
        n = c1.shape[0]
        target_hw = c1.shape[2:]  # every scale is upsampled to the c1 grid

        # c1 is already at the target resolution.
        c1f = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        rescaled = [c1f]
        for proj, c in ((self.linear_c2, c2), (self.linear_c3, c3), (self.linear_c4, c4)):
            f = proj(c).permute(0, 2, 1).reshape(n, -1, c.shape[2], c.shape[3])
            rescaled.append(F.interpolate(f, size=target_hw, mode='bilinear', align_corners=False))

        # Concatenate coarsest-first (c4, c3, c2, c1), as in the reference impl.
        fused = self.linear_fuse(torch.cat(rescaled[::-1], dim=1))
        return self.linear_pred(self.dropout(fused))
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
# Per-variant encoder hyper-parameters, keyed by SegFormer model name.
segformer_settings = {
    'B0': [[32, 64, 160, 256], [2, 2, 2, 2], 256],        # [channel dimensions, num encoder layers, embed dim]
    'B1': [[64, 128, 320, 512], [2, 2, 2, 2], 256],
    'B2': [[64, 128, 320, 512], [3, 4, 6, 3], 768],
    'B3': [[64, 128, 320, 512], [3, 4, 18, 3], 768],
    'B4': [[64, 128, 320, 512], [3, 8, 27, 3], 768],
    'B5': [[64, 128, 320, 512], [3, 6, 40, 3], 768]
}
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class SegFormer(nn.Module):
    """MiT encoder + all-MLP decoder for semantic segmentation."""

    def __init__(self, model_name: str = 'B0', num_classes: int = 19, image_size: int = 224) -> None:
        super().__init__()
        assert model_name in segformer_settings.keys(), f"SegFormer model name should be in {list(segformer_settings.keys())}"
        dims, layers, embed_dim = segformer_settings[model_name]

        self.backbone = MiT(image_size, dims, layers)
        self.decode_head = Decoder(dims, embed_dim, num_classes)

    def init_weights(self, pretrained: str = None) -> None:
        """Load backbone weights from `pretrained`; otherwise Xavier-init all layers."""
        if pretrained:
            self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
            return
        for m in self.modules():
            # Linear and Conv2d receive identical treatment, so the two
            # original branches are merged.
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Grayscale inputs are tiled to 3 channels for the RGB backbone.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.decode_head(self.backbone(x))
|
| 557 |
+
|
models/_resunet/__init__.py
ADDED
|
File without changes
|
models/_resunet/modules.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/rishikksh20/ResUnet/blob/master/core/modules.py
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ResidualConv(nn.Module):
    """Pre-activation residual block: (BN-ReLU-Conv)x2 plus a conv shortcut."""

    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # Shortcut uses the same stride so both branches match spatially.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        shortcut = self.conv_skip(x)
        return self.conv_block(x) + shortcut
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Upsample(nn.Module):
    """Learned upsampling via a transposed convolution."""

    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()
        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Squeeze_Excite_Block(nn.Module):
    """Squeeze-and-Excitation: reweight channels by globally pooled statistics."""

    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Bottleneck MLP (channel -> channel/reduction -> channel) + sigmoid gate.
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: parallel dilated 3x3 convs fused by a 1x1."""

    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        def branch(dilation):
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            return nn.Sequential(
                nn.Conv2d(
                    in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            )

        self.aspp_block1 = branch(rate[0])
        self.aspp_block2 = branch(rate[1])
        self.aspp_block3 = branch(rate[2])

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        branches = [self.aspp_block1(x), self.aspp_block2(x), self.aspp_block3(x)]
        return self.output(torch.cat(branches, dim=1))

    def _init_weights(self):
        # Kaiming init for convs; weight=1 / bias=0 for BatchNorm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor."""

    def __init__(self, scale=2):
        super(Upsample_, self).__init__()
        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        return self.upsample(x)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class AttentionBlock(nn.Module):
    """Additive attention gate: scores decoder features using encoder context."""

    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        # Encoder branch is max-pooled so it matches the decoder resolution.
        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        # Collapses the summed features to a single-channel attention map.
        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        combined = self.conv_encoder(x1) + self.conv_decoder(x2)
        attn = self.conv_attn(combined)
        return attn * x2
|
models/_resunet/res_unet.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/rishikksh20/ResUnet
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from .modules import ResidualConv, Upsample
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ResUnet(nn.Module):
    """Residual U-Net: 3-level residual encoder, bridge, 3-level decoder."""

    def __init__(self, in_ch, out_ch, filters=[64, 128, 256, 512]):
        super(ResUnet, self).__init__()

        # Stem: plain double conv plus a conv shortcut.
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1)
        )

        # Encoder: each residual conv halves the resolution (stride 2).
        self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.bridge = ResidualConv(filters[2], filters[3], 2, 1)

        # Decoder: transpose-conv upsample, concat encoder skip, residual conv.
        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], out_ch, 1, 1),
        )

    def forward(self, x):
        # Encoder path.
        enc1 = self.input_layer(x) + self.input_skip(x)
        enc2 = self.residual_conv_1(enc1)
        enc3 = self.residual_conv_2(enc2)

        # Bridge.
        mid = self.bridge(enc3)

        # Decoder path with skip connections at matching resolutions.
        dec = self.up_residual_conv1(torch.cat([self.upsample_1(mid), enc3], dim=1))
        dec = self.up_residual_conv2(torch.cat([self.upsample_2(dec), enc2], dim=1))
        dec = self.up_residual_conv3(torch.cat([self.upsample_3(dec), enc1], dim=1))

        return self.output_layer(dec)
|
models/_transunet/vit_seg_configs.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_collections
|
| 2 |
+
|
| 3 |
+
def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    # Transformer encoder hyper-parameters.
    config.transformer = ml_collections.ConfigDict({
        'mlp_dim': 3072,
        'num_heads': 12,
        'num_layers': 12,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })

    config.classifier = 'seg'
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
    config.patch_size = 16

    # Segmentation decoder settings.
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    # Tiny transformer: one layer, one head, unit widths.
    config.transformer = ml_collections.ConfigDict({
        'mlp_dim': 1,
        'num_heads': 1,
        'num_layers': 1,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })
    config.classifier = 'token'
    config.representation_size = None
    return config
|
| 41 |
+
|
| 42 |
+
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    # Hybrid ResNet stem feeding the transformer.
    config.resnet = ml_collections.ConfigDict({
        'num_layers': (3, 4, 9),
        'width_factor': 1,
    })

    config.classifier = 'seg'
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    # Decoder / skip-connection settings for the segmentation head.
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.n_skip = 3
    config.activation = 'softmax'

    return config
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_b32_config():
    """Returns the ViT-B/32 configuration (B/16 with 32x32 patches)."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
    return config
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    # Transformer encoder hyper-parameters.
    config.transformer = ml_collections.ConfigDict({
        'mlp_dim': 4096,
        'num_heads': 16,
        'num_layers': 24,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })
    config.representation_size = None

    # custom
    config.classifier = 'seg'
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_r50_l16_config():
    """Returns the Resnet50 + ViT-L/16 configuration. customized """
    config = get_l16_config()
    config.patches.grid = (16, 16)
    # Hybrid ResNet stem feeding the transformer.
    config.resnet = ml_collections.ConfigDict({
        'num_layers': (3, 4, 9),
        'width_factor': 1,
    })

    config.classifier = 'seg'
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    # Decoder / skip-connection settings for the segmentation head.
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_l32_config():
    """Returns the ViT-L/32 configuration (L/16 with 32x32 patches)."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    # NOTE: the original docstring said "ViT-L/16", but the values below
    # (14x14 patches, 1280 hidden, 32 layers) are the ViT-Huge/14 settings.
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None

    return config
|
models/_transunet/vit_seg_modeling.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import absolute_import
|
| 3 |
+
from __future__ import division
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
from os.path import join as pjoin
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
|
| 17 |
+
from torch.nn.modules.utils import _pair
|
| 18 |
+
from scipy import ndimage
|
| 19 |
+
from . import vit_seg_configs as configs
|
| 20 |
+
from .vit_seg_modeling_resnet_skip import ResNetV2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
|
| 27 |
+
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
|
| 28 |
+
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
|
| 29 |
+
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
|
| 30 |
+
FC_0 = "MlpBlock_3/Dense_0"
|
| 31 |
+
FC_1 = "MlpBlock_3/Dense_1"
|
| 32 |
+
ATTENTION_NORM = "LayerNorm_0"
|
| 33 |
+
MLP_NORM = "LayerNorm_2"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def np2th(weights, conv=False):
    """Convert a numpy weight array to a torch tensor.

    When ``conv=True`` the kernel layout is remapped from JAX/Flax
    HWIO to PyTorch OIHW before conversion.
    """
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def swish(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x).mul(x)


# Activation registry keyed by the names used in the ViT configs.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Attention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder."""

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis  # when True, forward() also returns the attention maps
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape (B, N, all_head) -> (B, heads, N, head_size)."""
        target_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*target_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Return ``(attention_output, weights)``; ``weights`` is None unless ``vis``."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention over the token axis.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))
        output = self.proj_dropout(self.out(context))
        return output, weights
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> dropout -> Linear -> dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier weights, near-zero biases — matches the reference ViT init.
        # Order is kept identical to the original for RNG reproducibility.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Embeddings(nn.Module):
    """Patch + position embeddings, optionally on top of a hybrid ResNet stem."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid ViT: patches are cut from the 1/16-resolution ResNet feature map.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            # Pure ViT: patches are cut directly from the input image.
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # A strided conv is equivalent to linear projection of flattened patches.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Return ``(embeddings, skip_features)``; features are None without the hybrid stem."""
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        patches = self.patch_embeddings(x)             # (B, hidden, H/ps, W/ps)
        tokens = patches.flatten(2).transpose(-1, -2)  # (B, n_patches, hidden)
        return self.dropout(tokens + self.position_embeddings), features
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Block(nn.Module):
    """One transformer encoder block: pre-norm MHSA and pre-norm MLP, both residual."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        residual = x
        x, weights = self.attn(self.attention_norm(x))
        x = x + residual

        # Feed-forward sub-layer with residual connection.
        residual = x
        x = self.ffn(self.ffn_norm(x)) + residual
        return x, weights

    def load_from(self, weights, n_block):
        """Copy pretrained JAX/Flax weights for encoder block ``n_block`` into this module."""
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            hs = self.hidden_size
            # Attention projections: Flax stores kernels as (in, out); transpose for torch.
            attn_parts = ((self.attn.query, ATTENTION_Q),
                          (self.attn.key, ATTENTION_K),
                          (self.attn.value, ATTENTION_V),
                          (self.attn.out, ATTENTION_OUT))
            for module, prefix in attn_parts:
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "kernel")]).view(hs, hs).t())
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]).view(-1))

            # MLP dense layers (bias .t() is a no-op on 1-D tensors, kept from upstream).
            for module, prefix in ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1)):
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "kernel")]).t())
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]).t())

            # Layer norms ("scale" in Flax maps to the torch weight).
            for module, prefix in ((self.attention_norm, ATTENTION_NORM),
                                   (self.ffn_norm, MLP_NORM)):
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "scale")]))
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Encoder(nn.Module):
    """Stack of transformer blocks followed by a final layer norm."""

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # deepcopy of a freshly built block, kept from upstream for parity.
            self.layer.append(copy.deepcopy(Block(config, vis)))

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``; weights are collected only when vis."""
        attn_weights = []
        for block in self.layer:
            hidden_states, weights = block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        return self.encoder_norm(hidden_states), attn_weights
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Transformer(nn.Module):
    """ViT backbone: patch/position embeddings followed by the transformer encoder."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        tokens, skip_features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)  # (B, n_patch, hidden)
        return encoded, attn_weights, skip_features
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU.

    NOTE(review): BatchNorm is applied even when ``use_batchnorm=False`` —
    the flag only toggles the conv bias. This matches upstream TransUNet
    and is preserved as-is.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 padding=0, stride=1, use_batchnorm=True):
        layers = (
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding,
                      bias=not use_batchnorm),  # BN makes the conv bias redundant
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        super(Conv2dReLU, self).__init__(*layers)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class DecoderBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, optional skip concat, two 3x3 conv-BN-ReLU."""

    def __init__(self, in_channels, out_channels, skip_channels=0, use_batchnorm=True):
        super().__init__()
        self.conv1 = Conv2dReLU(in_channels + skip_channels, out_channels,
                                kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv2 = Conv2dReLU(out_channels, out_channels,
                                kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        """Upsample ``x``; concatenate ``skip`` along channels when provided."""
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        return self.conv2(self.conv1(x))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class SegmentationHead(nn.Sequential):
    """Final conv head mapping decoder features to per-class logits, with optional upsampling."""

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        head = nn.Conv2d(in_channels, out_channels,
                         kernel_size=kernel_size, padding=kernel_size // 2)
        scale = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(head, scale)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class DecoderCup(nn.Module):
    """TransUNet decoder: reshape transformer tokens to a 2-D map, then cascaded upsampling.

    Fix vs. upstream: the skip-channel pruning below used to zero entries of
    ``config.skip_channels`` *in place*, silently corrupting a config object
    shared across model instantiations. We now work on a copy.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        # Projects the transformer hidden dim down to the decoder's head width.
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Copy before zeroing so the shared config is never mutated.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # drop the deepest unused skips
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        self.blocks = nn.ModuleList(
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        )

    def forward(self, hidden_states, features=None):
        """``hidden_states``: (B, n_patch, hidden); n_patch must be a perfect square."""
        B, n_patch, hidden = hidden_states.size()
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            # Only the first n_skip stages receive encoder skip features.
            skip = features[i] if (features is not None and i < self.config.n_skip) else None
            x = decoder_block(x, skip=skip)
        return x
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class VisionTransformer(nn.Module):
    """TransUNet: (hybrid) ViT encoder + cascaded CNN decoder + segmentation head."""

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # Grayscale inputs are tiled to 3 channels to match the pretrained stem.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        encoded, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        decoded = self.decoder(encoded, features)
        return self.segmentation_head(decoded)

    def load_from(self, weights):
        """Load pretrained JAX/Flax ViT (and hybrid ResNet) weights into this model."""
        with torch.no_grad():
            res_weight = weights
            embeddings = self.transformer.embeddings
            embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                # Pretrained table includes a class token; segmentation uses none.
                posemb = posemb[:, 1:]
                embeddings.position_embeddings.copy_(posemb)
            else:
                # Grid sizes differ: bilinearly resize the 2-D position grid.
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder blocks: named_children over the ModuleList yields index names,
            # which double as the pretrained block identifiers.
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if embeddings.hybrid:
                root = embeddings.hybrid_model.root
                root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                root.gn.weight.copy_(np2th(res_weight["gn_root/scale"]).view(-1))
                root.gn.bias.copy_(np2th(res_weight["gn_root/bias"]).view(-1))

                for bname, block in embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
|
| 441 |
+
|
| 442 |
+
# Named ViT configurations; the R50 variants prepend a hybrid ResNet-50 stem.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),          # Base, 16x16 patches
    'ViT-B_32': configs.get_b32_config(),          # Base, 32x32 patches
    'ViT-L_16': configs.get_l16_config(),          # Large, 16x16 patches
    'ViT-L_32': configs.get_l32_config(),          # Large, 32x32 patches
    'ViT-H_14': configs.get_h14_config(),          # Huge, 14x14 patches
    'R50-ViT-B_16': configs.get_r50_b16_config(),  # hybrid ResNet-50 + ViT-B/16
    'R50-ViT-L_16': configs.get_r50_l16_config(),  # hybrid ResNet-50 + ViT-L/16
    'testing': configs.get_testing(),              # tiny config for unit tests
}
|
| 452 |
+
|
| 453 |
+
|
models/_transunet/vit_seg_modeling_c4.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import absolute_import
|
| 3 |
+
from __future__ import division
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
from os.path import join as pjoin
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
|
| 17 |
+
from torch.nn.modules.utils import _pair
|
| 18 |
+
from scipy import ndimage
|
| 19 |
+
from . import vit_seg_configs as configs
|
| 20 |
+
from .vit_seg_modeling_resnet_skip_c4 import ResNetV2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
|
| 27 |
+
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
|
| 28 |
+
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
|
| 29 |
+
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
|
| 30 |
+
FC_0 = "MlpBlock_3/Dense_0"
|
| 31 |
+
FC_1 = "MlpBlock_3/Dense_1"
|
| 32 |
+
ATTENTION_NORM = "LayerNorm_0"
|
| 33 |
+
MLP_NORM = "LayerNorm_2"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def np2th(weights, conv=False):
    """Convert a numpy weight array to a torch tensor.

    When ``conv=True`` the kernel layout is remapped from JAX/Flax
    HWIO to PyTorch OIHW before conversion.
    """
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def swish(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x).mul(x)


# Activation registry keyed by the names used in the ViT configs.
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Attention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder (4-channel variant file)."""

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis  # when True, forward() also returns the attention maps
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape (B, N, all_head) -> (B, heads, N, head_size)."""
        target_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*target_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Return ``(attention_output, weights)``; ``weights`` is None unless ``vis``."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention over the token axis.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))
        output = self.proj_dropout(self.out(context))
        return output, weights
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> dropout -> Linear -> dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier weights, near-zero biases — matches the reference ViT init.
        # Order is kept identical to the original for RNG reproducibility.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Embeddings(nn.Module):
    """Patch + position embeddings; this c4 variant defaults to 4 input channels."""

    def __init__(self, config, img_size, in_channels=4):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid ViT: patches are cut from the 1/16-resolution ResNet feature map.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            # Pure ViT: patches are cut directly from the input image.
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # A strided conv is equivalent to linear projection of flattened patches.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Return ``(embeddings, skip_features)``; features are None without the hybrid stem."""
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        patches = self.patch_embeddings(x)             # (B, hidden, H/ps, W/ps)
        tokens = patches.flatten(2).transpose(-1, -2)  # (B, n_patches, hidden)
        return self.dropout(tokens + self.position_embeddings), features
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Block(nn.Module):
    """One transformer encoder block: pre-norm MHSA and pre-norm MLP, both residual."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        residual = x
        x, weights = self.attn(self.attention_norm(x))
        x = x + residual

        # Feed-forward sub-layer with residual connection.
        residual = x
        x = self.ffn(self.ffn_norm(x)) + residual
        return x, weights

    def load_from(self, weights, n_block):
        """Copy pretrained JAX/Flax weights for encoder block ``n_block`` into this module."""
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            hs = self.hidden_size
            # Attention projections: Flax stores kernels as (in, out); transpose for torch.
            attn_parts = ((self.attn.query, ATTENTION_Q),
                          (self.attn.key, ATTENTION_K),
                          (self.attn.value, ATTENTION_V),
                          (self.attn.out, ATTENTION_OUT))
            for module, prefix in attn_parts:
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "kernel")]).view(hs, hs).t())
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]).view(-1))

            # MLP dense layers (bias .t() is a no-op on 1-D tensors, kept from upstream).
            for module, prefix in ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1)):
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "kernel")]).t())
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]).t())

            # Layer norms ("scale" in Flax maps to the torch weight).
            for module, prefix in ((self.attention_norm, ATTENTION_NORM),
                                   (self.ffn_norm, MLP_NORM)):
                module.weight.copy_(np2th(weights[pjoin(ROOT, prefix, "scale")]))
                module.bias.copy_(np2th(weights[pjoin(ROOT, prefix, "bias")]))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Encoder(nn.Module):
    """Stack of transformer Blocks followed by a final LayerNorm.

    Args:
        config: model config; ``config.transformer["num_layers"]`` gives the
            number of Blocks and ``config.hidden_size`` the embedding width.
        vis: when True, each Block's attention map is collected and returned.
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        # Fix: the original deep-copied each Block right after constructing
        # it (copy.deepcopy of a brand-new module) — redundant work with
        # identical resulting structure, so the copy is dropped.
        for _ in range(config.transformer["num_layers"]):
            self.layer.append(Block(config, vis))

    def forward(self, hidden_states):
        """Run all Blocks in order.

        Returns:
            (encoded, attn_weights): final LayerNorm-ed hidden states of the
            same shape as the input, and the list of per-layer attention
            maps (empty unless ``vis`` is True).
        """
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Transformer(nn.Module):
    """Embedding stage plus transformer encoder.

    Maps an input image to (token features, attention maps, CNN skip
    features from the hybrid embedding stage).
    """

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        tokens, skips = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)  # (B, n_patch, hidden)
        return encoded, attn_weights, skips
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU packaged as one Sequential.

    The conv carries a bias only when ``use_batchnorm`` is False, since a
    following batch norm makes the bias redundant.
    NOTE(review): as in the original, BatchNorm2d is included in the
    Sequential regardless of the flag — the flag only toggles the conv bias.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        stages = (
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        super(Conv2dReLU, self).__init__(*stages)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class DecoderBlock(nn.Module):
    """One decoder stage: upsample x2, optionally concatenate a skip
    feature along channels, then refine with two Conv2dReLU layers."""

    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        self.conv1 = Conv2dReLU(
            fused_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        if skip is not None:
            upsampled = torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(upsampled))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class SegmentationHead(nn.Sequential):
    """Final conv to per-pixel class logits, optionally followed by a
    bilinear upsampling back to full resolution."""

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(conv2d, resize)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class DecoderCup(nn.Module):
    """Cascaded upsampler: projects transformer tokens back to a 2-D map,
    then upsamples through DecoderBlocks, fusing the CNN skip features.

    Args:
        config: model config providing ``hidden_size``, ``decoder_channels``
            (per-stage output widths), ``n_skip`` and ``skip_channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Fix: work on a copy. The original zeroed entries of
            # config.skip_channels *in place*, mutating the shared config
            # object as a side effect of constructing the decoder.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        """hidden_states: (B, n_patch, hidden), n_patch a perfect square.
        features: optional list of skip tensors, highest resolution first;
        only the first ``config.n_skip`` entries are used.
        """
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class VisionTransformer(nn.Module):
    """TransUNet: hybrid CNN/transformer encoder, cascaded-upsampler
    decoder (DecoderCup) and a SegmentationHead producing per-pixel logits.

    Note: ``num_classes``/``zero_head`` follow the upstream ViT signature
    and are stored, but the output width actually comes from
    ``config['n_classes']``.
    """

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # Single-channel inputs are tiled to 4 channels — assumes a
        # 4-channel stem (cf. the *_c4 ResNet variant in this repo);
        # TODO(review): confirm for 3-channel configurations.
        if x.size()[1] == 1:
            x = x.repeat(1,4,1,1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        """Copy a pretrained JAX/TF checkpoint (dict of numpy arrays) into
        this model: patch embedding, encoder norm, position embeddings
        (grid resized by bilinear zoom when shapes differ), every
        transformer Block, and — for hybrid models — the ResNet stem/body.
        """
        with torch.no_grad():

            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                # Exact match: copy as-is.
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1]-1 == posemb_new.size()[1]:
                # Checkpoint has one extra token (class token); drop it.
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                # Grid sizes differ: bilinearly resize the positional grid.
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            # encoder.named_children() yields the Block ModuleList (whose
            # child names "0","1",... become the checkpoint block index).
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                # Hybrid model: also load the ResNet stem and body.
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
|
| 441 |
+
|
| 442 |
+
# Registry: model-name string -> ViT configuration object
# (built by the factory functions in vit_seg_configs).
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}
|
| 452 |
+
|
| 453 |
+
|
models/_transunet/vit_seg_modeling_resnet_skip.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from os.path import join as pjoin
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def np2th(weights, conv=False):
    """Convert a numpy weight array to a torch tensor.

    When ``conv=True`` the kernel is assumed to be in TF's HWIO layout and
    is transposed to PyTorch's OIHW layout first.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: each output filter is normalized
    to zero mean / unit variance (eps 1e-5) before the convolution."""

    def forward(self, x):
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3],
                                   keepdim=True, unbiased=False)
        standardized = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """3x3 weight-standardized convolution with 'same' padding."""
    return StdConv2d(cin, cout, kernel_size=3, padding=1,
                     stride=stride, groups=groups, bias=bias)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def conv1x1(cin, cout, stride=1, bias=False):
    """1x1 (pointwise) weight-standardized convolution."""
    return StdConv2d(cin, cout, kernel_size=1, padding=0,
                     stride=stride, bias=bias)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand, each on the unit
    branch followed by GroupNorm + ReLU; the residual branch is projected
    (1x1 conv + GroupNorm) whenever stride or channel count changes.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        # cout defaults to cin; cmid (bottleneck width) defaults to cout//4.
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            # NOTE(review): GroupNorm(cout, cout) puts each channel in its
            # own group (per-channel norm), unlike the 32-group norms above;
            # this matches the reference TransUNet code, so kept as-is.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):

        # Residual branch: projected when present, identity otherwise.
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        # Copy pretrained TF weights keyed "<block>/<unit>/<param>" into
        # this unit; conv kernels are converted HWIO -> OIHW by np2th.
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        # TF stores GN scale/bias with extra singleton dims; flatten to 1-D.
        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
|
| 111 |
+
|
| 112 |
+
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    Used as the hybrid ViT stem: returns the deepest feature map plus the
    intermediate skip features, ordered deepest-first for the decoder.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        # Fix: the original constructed a fresh (parameter-free) MaxPool2d
        # inside every forward() call; hoisted here. MaxPool2d has no
        # parameters or buffers, so the state_dict is unchanged.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        """x: (B, 3, H, W). Returns (deepest feature, skips deepest-first)."""
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = self.pool(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            # Expected spatial size at this depth; odd input sizes can come
            # up 1-2 px short after strided convs, so zero-pad bottom/right.
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert 0 < pad < 3, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
|
models/_transunet/vit_seg_modeling_resnet_skip_c4.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from os.path import join as pjoin
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def np2th(weights, conv=False):
    """Convert a numpy weight array to a torch tensor.

    When ``conv=True`` the kernel is assumed to be in TF's HWIO layout and
    is transposed to PyTorch's OIHW layout first.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: each output filter is normalized
    to zero mean / unit variance (eps 1e-5) before the convolution."""

    def forward(self, x):
        var, mean = torch.var_mean(self.weight, dim=[1, 2, 3],
                                   keepdim=True, unbiased=False)
        standardized = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """3x3 weight-standardized convolution with 'same' padding."""
    return StdConv2d(cin, cout, kernel_size=3, padding=1,
                     stride=stride, groups=groups, bias=bias)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def conv1x1(cin, cout, stride=1, bias=False):
    """1x1 (pointwise) weight-standardized convolution."""
    return StdConv2d(cin, cout, kernel_size=1, padding=0,
                     stride=stride, bias=bias)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand, each on the unit
    branch followed by GroupNorm + ReLU; the residual branch is projected
    (1x1 conv + GroupNorm) whenever stride or channel count changes.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        # cout defaults to cin; cmid (bottleneck width) defaults to cout//4.
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            # NOTE(review): GroupNorm(cout, cout) puts each channel in its
            # own group (per-channel norm), unlike the 32-group norms above;
            # this matches the reference TransUNet code, so kept as-is.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):

        # Residual branch: projected when present, identity otherwise.
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        # Copy pretrained TF weights keyed "<block>/<unit>/<param>" into
        # this unit; conv kernels are converted HWIO -> OIHW by np2th.
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        # TF stores GN scale/bias with extra singleton dims; flatten to 1-D.
        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
|
| 111 |
+
|
| 112 |
+
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    4-channel-input variant of the hybrid ViT stem: returns the deepest
    feature map plus the intermediate skip features, deepest-first.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(4, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        # Fix: the original constructed a fresh (parameter-free) MaxPool2d
        # inside every forward() call; hoisted here. MaxPool2d has no
        # parameters or buffers, so the state_dict is unchanged.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        """x: (B, 4, H, W). Returns (deepest feature, skips deepest-first)."""
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = self.pool(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            # Expected spatial size at this depth; odd input sizes can come
            # up 1-2 px short after strided convs, so zero-pad bottom/right.
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert 0 < pad < 3, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
|
models/_uctransnet/CTrans.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# @Author : Haonan Wang
|
| 3 |
+
# @File : CTrans.py
|
| 4 |
+
# @Software: PyCharm
|
| 5 |
+
# coding=utf-8
|
| 6 |
+
from __future__ import absolute_import
|
| 7 |
+
from __future__ import division
|
| 8 |
+
from __future__ import print_function
|
| 9 |
+
import copy
|
| 10 |
+
import logging
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
|
| 16 |
+
from torch.nn.modules.utils import _pair
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class Channel_Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.

    The patch projection keeps the channel count (in_channels ->
    in_channels), so each token has width ``in_channels``.
    """

    def __init__(self, config, patchsize, img_size, in_channels):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patchsize)
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=in_channels,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(config.transformer["embeddings_dropout_rate"])

    def forward(self, x):
        """x: (B, C, H, W) or None (stage disabled). Returns (B, n_patches, C)."""
        if x is None:
            return None
        patches = self.patch_embeddings(x)             # (B, C, H/ps, W/ps)
        tokens = patches.flatten(2).transpose(-1, -2)  # (B, n_patches, C)
        return self.dropout(tokens + self.position_embeddings)
|
| 46 |
+
|
| 47 |
+
class Reconstruct(nn.Module):
    """Turn a token sequence back into a 2-D feature map: reshape to a
    square grid, upsample by ``scale_factor``, then conv + BN + ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        padding = 1 if kernel_size == 3 else 0
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        if x is None:
            return None

        # Tokens -> square map: n_patch must be a perfect square.
        B, n_patch, hidden = x.size()
        side = int(np.sqrt(n_patch))
        grid = x.permute(0, 2, 1).contiguous().view(B, hidden, side, side)
        grid = nn.Upsample(scale_factor=self.scale_factor)(grid)

        return self.activation(self.norm(self.conv(grid)))
|
| 73 |
+
|
| 74 |
+
class Attention_org(nn.Module):
    """Multi-head channel-wise cross attention over four encoder scales.

    Each scale i gets its own per-head query projection (query{i}); keys and
    values are projected from the concatenation of all scales (``emb_all``,
    channel size ``config.KV_size``).  Heads are realised as ModuleLists of
    independent Linear layers rather than one batched projection.  Any
    ``emb{i}`` may be None, in which case that scale is skipped and the
    corresponding output is None.
    """
    def __init__(self, config, vis,channel_num):
        super(Attention_org, self).__init__()
        self.vis = vis                      # keep mean-over-head attention maps for visualisation
        self.KV_size = config.KV_size       # total channels of the concatenated K/V source
        self.channel_num = channel_num      # per-scale channel counts
        self.num_attention_heads = config.transformer["num_heads"]

        self.query1 = nn.ModuleList()
        self.query2 = nn.ModuleList()
        self.query3 = nn.ModuleList()
        self.query4 = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        # One independent Linear per head for every projection.
        for _ in range(config.transformer["num_heads"]):
            query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
            query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
            query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
            query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
            key = nn.Linear( self.KV_size, self.KV_size, bias=False)
            value = nn.Linear(self.KV_size, self.KV_size, bias=False)
            self.query1.append(copy.deepcopy(query1))
            self.query2.append(copy.deepcopy(query2))
            self.query3.append(copy.deepcopy(query3))
            self.query4.append(copy.deepcopy(query4))
            self.key.append(copy.deepcopy(key))
            self.value.append(copy.deepcopy(value))
        # InstanceNorm across the head dimension of the score tensor.
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
        self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
        self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
        self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])



    def forward(self, emb1,emb2,emb3,emb4, emb_all):
        """Cross-attend each scale's tokens against the fused ``emb_all``.

        Returns (O1, O2, O3, O4, weights); O_i is None when emb_i is None;
        ``weights`` is a list of mean-over-head attention maps when
        ``self.vis`` is True, else None.
        """
        multi_head_Q1_list = []
        multi_head_Q2_list = []
        multi_head_Q3_list = []
        multi_head_Q4_list = []
        multi_head_K_list = []
        multi_head_V_list = []
        if emb1 is not None:
            for query1 in self.query1:
                Q1 = query1(emb1)
                multi_head_Q1_list.append(Q1)
        if emb2 is not None:
            for query2 in self.query2:
                Q2 = query2(emb2)
                multi_head_Q2_list.append(Q2)
        if emb3 is not None:
            for query3 in self.query3:
                Q3 = query3(emb3)
                multi_head_Q3_list.append(Q3)
        if emb4 is not None:
            for query4 in self.query4:
                Q4 = query4(emb4)
                multi_head_Q4_list.append(Q4)
        # K and V always come from the fused embedding, regardless of which
        # per-scale queries are present.
        for key in self.key:
            K = key(emb_all)
            multi_head_K_list.append(K)
        for value in self.value:
            V = value(emb_all)
            multi_head_V_list.append(V)
        # print(len(multi_head_Q4_list))

        # Stack per-head results into a new head dimension at dim=1.
        multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
        multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
        multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
        multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
        multi_head_K = torch.stack(multi_head_K_list, dim=1)
        multi_head_V = torch.stack(multi_head_V_list, dim=1)

        multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
        multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
        multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
        multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None

        # Channel-attention scores: Q^T x K, scaled by sqrt(KV_size).
        attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
        attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
        attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
        attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None

        attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
        attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
        attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
        attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None

        # Normalise across heads (psi) before the softmax over dim=3.
        attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
        attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
        attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
        attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
        # print(attention_probs4.size())

        if self.vis:
            # NOTE(review): if any emb_i is None while vis is True, the
            # corresponding attention_probs_i is None and .mean(1) would
            # raise — callers appear to always pass all four scales; confirm
            # before relying on partial inputs with vis=True.
            weights = []
            weights.append(attention_probs1.mean(1))
            weights.append(attention_probs2.mean(1))
            weights.append(attention_probs3.mean(1))
            weights.append(attention_probs4.mean(1))
        else: weights=None

        attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
        attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
        attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
        attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None

        multi_head_V = multi_head_V.transpose(-1, -2)
        context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
        context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
        context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
        context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None

        # Reorder axes, then average over what is now dim=3 to merge heads.
        context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
        context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
        context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
        context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
        context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
        context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
        context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
        context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None

        # Per-scale output projections with dropout.
        O1 = self.out1(context_layer1) if emb1 is not None else None
        O2 = self.out2(context_layer2) if emb2 is not None else None
        O3 = self.out3(context_layer3) if emb3 is not None else None
        O4 = self.out4(context_layer4) if emb4 is not None else None
        O1 = self.proj_dropout(O1) if emb1 is not None else None
        O2 = self.proj_dropout(O2) if emb2 is not None else None
        O3 = self.proj_dropout(O3) if emb3 is not None else None
        O4 = self.proj_dropout(O4) if emb4 is not None else None
        return O1,O2,O3,O4, weights
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Mlp(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``config.transformer["dropout_rate"]`` sets both dropout layers; the
    output channel count equals the input channel count.
    """

    def __init__(self,config, in_channel, mlp_channel):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(in_channel, mlp_channel)
        self.fc2 = nn.Linear(mlp_channel, in_channel)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights first (same RNG draw order as before),
        # then near-zero normal biases for both projections.
        for fc in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight)
        for fc in (self.fc1, self.fc2):
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
|
| 235 |
+
|
| 236 |
+
class Block_ViT(nn.Module):
    """One channel-wise cross-attention transformer layer.

    Pre-norm multi-scale channel attention with a residual connection,
    followed by a per-scale MLP with a second residual.  Any of the four
    embeddings may be None; that scale is then skipped and its output is
    None.

    Fix: the original built ``emb_all`` by looking up ``locals()["emb"+i]``,
    which is fragile (depends on CPython's locals() snapshot semantics and
    breaks under renaming); replaced with an explicit tuple filter.
    """
    def __init__(self, config, vis, channel_num):
        super(Block_ViT, self).__init__()
        expand_ratio = config.expand_ratio  # hidden-width multiplier of each Mlp
        self.attn_norm1 = LayerNorm(channel_num[0],eps=1e-6)
        self.attn_norm2 = LayerNorm(channel_num[1],eps=1e-6)
        self.attn_norm3 = LayerNorm(channel_num[2],eps=1e-6)
        self.attn_norm4 = LayerNorm(channel_num[3],eps=1e-6)
        # Norm for the concatenated K/V token stream (KV_size channels).
        self.attn_norm = LayerNorm(config.KV_size,eps=1e-6)
        self.channel_attn = Attention_org(config, vis, channel_num)

        self.ffn_norm1 = LayerNorm(channel_num[0],eps=1e-6)
        self.ffn_norm2 = LayerNorm(channel_num[1],eps=1e-6)
        self.ffn_norm3 = LayerNorm(channel_num[2],eps=1e-6)
        self.ffn_norm4 = LayerNorm(channel_num[3],eps=1e-6)
        self.ffn1 = Mlp(config,channel_num[0],channel_num[0]*expand_ratio)
        self.ffn2 = Mlp(config,channel_num[1],channel_num[1]*expand_ratio)
        self.ffn3 = Mlp(config,channel_num[2],channel_num[2]*expand_ratio)
        self.ffn4 = Mlp(config,channel_num[3],channel_num[3]*expand_ratio)


    def forward(self, emb1,emb2,emb3,emb4):
        """Return (x1, x2, x3, x4, weights); x_i is None when emb_i is None."""
        org1 = emb1
        org2 = emb2
        org3 = emb3
        org4 = emb4
        # Concatenate the available embeddings along the channel axis (dim=2)
        # to form the shared key/value source.
        embcat = [emb for emb in (emb1, emb2, emb3, emb4) if emb is not None]

        emb_all = torch.cat(embcat,dim=2)
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        emb_all = self.attn_norm(emb_all)
        cx1,cx2,cx3,cx4, weights = self.channel_attn(cx1,cx2,cx3,cx4,emb_all)
        # Residual around the attention sub-layer.
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1 = cx1
        org2 = cx2
        org3 = cx3
        org4 = cx4
        x1 = self.ffn_norm1(cx1) if emb1 is not None else None
        x2 = self.ffn_norm2(cx2) if emb2 is not None else None
        x3 = self.ffn_norm3(cx3) if emb3 is not None else None
        x4 = self.ffn_norm4(cx4) if emb4 is not None else None
        x1 = self.ffn1(x1) if emb1 is not None else None
        x2 = self.ffn2(x2) if emb2 is not None else None
        x3 = self.ffn3(x3) if emb3 is not None else None
        x4 = self.ffn4(x4) if emb4 is not None else None
        # Residual around the feed-forward sub-layer.
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4, weights
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class Encoder(nn.Module):
    """Stack of ``num_layers`` Block_ViT layers with a final per-scale LayerNorm.

    Embeddings that arrive as None pass through as None.  When ``vis`` is
    True, per-layer attention weights are collected and returned.
    """
    def __init__(self, config, vis, channel_num):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm(channel_num[0],eps=1e-6)
        self.encoder_norm2 = LayerNorm(channel_num[1],eps=1e-6)
        self.encoder_norm3 = LayerNorm(channel_num[2],eps=1e-6)
        self.encoder_norm4 = LayerNorm(channel_num[3],eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            self.layer.append(copy.deepcopy(Block_ViT(config, vis, channel_num)))

    def forward(self, emb1,emb2,emb3,emb4):
        attn_weights = []
        for layer_block in self.layer:
            emb1, emb2, emb3, emb4, weights = layer_block(emb1, emb2, emb3, emb4)
            if self.vis:
                attn_weights.append(weights)
        # Final normalisation of whichever scales are present.
        norms = (self.encoder_norm1, self.encoder_norm2,
                 self.encoder_norm3, self.encoder_norm4)
        emb1, emb2, emb3, emb4 = (
            norm(emb) if emb is not None else None
            for norm, emb in zip(norms, (emb1, emb2, emb3, emb4))
        )
        return emb1,emb2,emb3,emb4, attn_weights
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class ChannelTransformer(nn.Module):
    """Channel-wise transformer bridging the four U-Net encoder scales.

    Each scale is patch-embedded (coarser patches at finer scales), jointly
    encoded by ``Encoder``, reconstructed back to a feature map, and added
    to its original encoder feature as a residual.
    """
    def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        # Patch embedding per scale; spatial size halves at every scale.
        self.embeddings_1 = Channel_Embeddings(config,self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(config,self.patchSize_2, img_size=img_size//2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(config,self.patchSize_3, img_size=img_size//4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(config,self.patchSize_4, img_size=img_size//8, in_channels=channel_num[3])
        self.encoder = Encoder(config, vis, channel_num)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1,scale_factor=(self.patchSize_1,self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1,scale_factor=(self.patchSize_2,self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1,scale_factor=(self.patchSize_3,self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1,scale_factor=(self.patchSize_4,self.patchSize_4))

    def forward(self,en1,en2,en3,en4):
        # Tokenise each encoder scale into patch embeddings.
        emb1 = self.embeddings_1(en1)
        emb2 = self.embeddings_2(en2)
        emb3 = self.embeddings_3(en3)
        emb4 = self.embeddings_4(en4)

        enc1, enc2, enc3, enc4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)

        # Fold tokens back to feature maps and add the encoder feature as a
        # residual; a scale whose input was None stays None.
        outputs = []
        recs = (self.reconstruct_1, self.reconstruct_2,
                self.reconstruct_3, self.reconstruct_4)
        for rec, enc, skip in zip(recs, (enc1, enc2, enc3, enc4), (en1, en2, en3, en4)):
            outputs.append(rec(enc) + skip if skip is not None else None)
        x1, x2, x3, x4 = outputs

        return x1, x2, x3, x4, attn_weights
|
| 365 |
+
|
models/_uctransnet/Config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
# @Time : 2021/6/19 2:44 PM
# @Author : Haonan Wang
# @File : Config.py
# @Software: PyCharm
import os
import torch
import time
import ml_collections

## PARAMETERS OF THE MODEL
save_model = True        # persist model checkpoints during training
tensorboard = True       # write TensorBoard logs
# NOTE: both environment variables are set at import time, as a side effect
# of importing this module.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
use_cuda = torch.cuda.is_available()
seed = 666
os.environ['PYTHONHASHSEED'] = str(seed)

cosineLR = True # whether use cosineLR or not
n_channels = 3           # input image channels (RGB)
n_labels = 1             # single-channel (binary) segmentation output
epochs = 2000
img_size = 224
print_frequency = 1      # log every N iterations
save_frequency = 5000    # checkpoint every N iterations
vis_frequency = 10       # visualise validation output every N epochs
early_stopping_patience = 50

pretrain = False
task_name = 'MoNuSeg' # GlaS MoNuSeg
# task_name = 'GlaS'
learning_rate = 1e-3
batch_size = 4


# model_name = 'UCTransNet'
model_name = 'UCTransNet_pretrain'

# Dataset and output paths are all derived from task_name / model_name.
train_dataset = './datasets/'+ task_name+ '/Train_Folder/'
val_dataset = './datasets/'+ task_name+ '/Val_Folder/'
test_dataset = './datasets/'+ task_name+ '/Test_Folder/'
session_name = 'Test_session' + '_' + time.strftime('%m.%d_%Hh%M')
save_path = task_name +'/'+ model_name +'/' + session_name + '/'
model_path = save_path + 'models/'
tensorboard_folder = save_path + 'tensorboard_logs/'
logger_path = save_path + session_name + ".log"
visualize_path = save_path + 'visualize_val/'


##########################################################################
# CTrans configs
##########################################################################
def get_CTranS_config():
    """Build the ml_collections ConfigDict consumed by CTrans/UCTransNet."""
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.KV_size = 960 # KV_size = Q1 + Q2 + Q3 + Q4
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.expand_ratio = 4 # MLP channel dimension expand ratio
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    # Per-scale patch sizes, coarsest first (matches ChannelTransformer).
    config.patch_sizes = [16,8,4,2]
    config.base_channel = 64 # base channel of U-Net
    config.n_classes = 1
    return config




# used in testing phase, copy the session name in training phase
test_session = "Test_session_07.03_20h39"
|
models/_uctransnet/UCTransNet.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# @Time : 2021/7/8 8:59 上午
|
| 3 |
+
# @File : UCTransNet.py
|
| 4 |
+
# @Software: PyCharm
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from .CTrans import ChannelTransformer
|
| 9 |
+
|
| 10 |
+
def get_activation(activation_type):
    """Return an activation module looked up on torch.nn by lower-cased name.

    Falls back to ReLU when no attribute of that (lower-cased) name exists
    on torch.nn.
    """
    name = activation_type.lower()
    return getattr(nn, name)() if hasattr(nn, name) else nn.ReLU()
|
| 16 |
+
|
| 17 |
+
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Build ``nb_Conv`` chained ConvBatchNorm blocks as one nn.Sequential.

    The first block maps in_channels -> out_channels; the remaining
    ``nb_Conv - 1`` blocks keep out_channels.
    """
    blocks = [ConvBatchNorm(in_channels, out_channels, activation)]
    blocks.extend(
        ConvBatchNorm(out_channels, out_channels, activation)
        for _ in range(nb_Conv - 1)
    )
    return nn.Sequential(*blocks)
|
| 24 |
+
|
| 25 |
+
class ConvBatchNorm(nn.Module):
    """(convolution => [BN] => activation), 3x3 same-padding conv."""

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        # conv -> batch norm -> activation, fused into one expression.
        return self.activation(self.norm(self.conv(x)))
|
| 39 |
+
|
| 40 |
+
class DownBlock(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by nb_Conv ConvBatchNorms."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        return self.nConvs(self.maxpool(x))
|
| 50 |
+
|
| 51 |
+
class Flatten(nn.Module):
    """Collapse every dimension except the batch dimension."""

    def forward(self, inputs):
        batch = inputs.size(0)
        return inputs.view(batch, -1)
|
| 54 |
+
|
| 55 |
+
class CCA(nn.Module):
    """Channel-wise Cross Attention gate.

    Gates the skip feature ``x`` with a sigmoid of the averaged channel
    descriptors of ``x`` and the decoder feature ``g`` (global average pool
    followed by a Linear each), then applies ReLU.
    """
    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # Global-average-pool each input down to one value per channel.
        descriptor_x = self.mlp_x(
            F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        descriptor_g = self.mlp_g(
            F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3))))
        # Average the two descriptors and squash into a (0, 1) channel gate.
        gate = torch.sigmoid((descriptor_x + descriptor_g) / 2.0)
        scale = gate.unsqueeze(2).unsqueeze(3).expand_as(x)
        return self.relu(x * scale)
|
| 80 |
+
|
| 81 |
+
class UpBlock_attention(nn.Module):
    """Decoder stage: upsample x2, CCA-gate the skip, concat, then convs."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        upsampled = self.up(x)
        gated_skip = self.coatt(g=upsampled, x=skip_x)
        # dim 1 is the channel dimension
        merged = torch.cat([gated_skip, upsampled], dim=1)
        return self.nConvs(merged)
|
| 93 |
+
|
| 94 |
+
class UCTransNet(nn.Module):
    """U-Net with a ChannelTransformer (CTrans) bridging the skip connections.

    For n_classes == 1 a Sigmoid head is applied (BCELoss setup); otherwise
    raw logits are returned.  When ``vis`` is True, forward also returns the
    transformer attention weights.
    """
    def __init__(self, config,n_channels=3, n_classes=1,img_size=224,vis=False):
        super().__init__()
        self.vis = vis
        self.n_channels = n_channels
        self.n_classes = n_classes
        base = config.base_channel
        self.inc = ConvBatchNorm(n_channels, base)
        self.down1 = DownBlock(base, base*2, nb_Conv=2)
        self.down2 = DownBlock(base*2, base*4, nb_Conv=2)
        self.down3 = DownBlock(base*4, base*8, nb_Conv=2)
        self.down4 = DownBlock(base*8, base*8, nb_Conv=2)
        self.mtc = ChannelTransformer(config, vis, img_size,
                                      channel_num=[base, base*2, base*4, base*8],
                                      patchSize=config.patch_sizes)
        self.up4 = UpBlock_attention(base*16, base*4, nb_Conv=2)
        self.up3 = UpBlock_attention(base*8, base*2, nb_Conv=2)
        self.up2 = UpBlock_attention(base*4, base, nb_Conv=2)
        self.up1 = UpBlock_attention(base*2, base, nb_Conv=2)
        self.outc = nn.Conv2d(base, n_classes, kernel_size=(1,1), stride=(1,1))
        self.last_activation = nn.Sigmoid() # if using BCELoss

    def forward(self, x):
        x = x.float()  # guard against non-float inputs (e.g. uint8 images)
        # Encoder path.
        f1 = self.inc(x)
        f2 = self.down1(f1)
        f3 = self.down2(f2)
        f4 = self.down3(f3)
        bottleneck = self.down4(f4)
        # Transformer refines the four skip features.
        f1, f2, f3, f4, att_weights = self.mtc(f1, f2, f3, f4)
        # Decoder path with attention-gated skips.
        d = self.up4(bottleneck, f4)
        d = self.up3(d, f3)
        d = self.up2(d, f2)
        d = self.up1(d, f1)
        head = self.outc(d)
        logits = self.last_activation(head) if self.n_classes == 1 else head
        if self.vis: # visualize the attention maps
            return logits, att_weights
        return logits
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
models/_uctransnet/UNet.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def get_activation(activation_type):
    """Instantiate an activation from torch.nn by lower-cased attribute name.

    Unknown names fall back to nn.ReLU().
    """
    lookup = activation_type.lower()
    if not hasattr(nn, lookup):
        return nn.ReLU()
    return getattr(nn, lookup)()
|
| 10 |
+
|
| 11 |
+
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Chain ``nb_Conv`` ConvBatchNorm blocks; only the first changes width."""
    convs = [ConvBatchNorm(in_channels, out_channels, activation)]
    convs += [
        ConvBatchNorm(out_channels, out_channels, activation)
        for _ in range(nb_Conv - 1)
    ]
    return nn.Sequential(*convs)
|
| 18 |
+
|
| 19 |
+
class ConvBatchNorm(nn.Module):
    """(convolution => [BN] => activation) with a 3x3 same-padding conv."""

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))
|
| 33 |
+
|
| 34 |
+
class DownBlock(nn.Module):
    """Downscaling with maxpool followed by nb_Conv convolution blocks."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        pooled = self.maxpool(x)
        return self.nConvs(pooled)
|
| 45 |
+
|
| 46 |
+
class UpBlock(nn.Module):
    """Upscaling (transposed conv x2) then skip-concat and conv blocks."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(UpBlock, self).__init__()

        # self.up = nn.Upsample(scale_factor=2)
        self.up = nn.ConvTranspose2d(in_channels//2,in_channels//2,(2,2),2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        upsampled = self.up(x)
        # dim 1 is the channel dimension
        merged = torch.cat([upsampled, skip_x], dim=1)
        return self.nConvs(merged)
|
| 60 |
+
|
| 61 |
+
class UNet(nn.Module):
    """Plain U-Net with a 64-channel stem and four down/up stages.

    n_channels: channels of the input (3 for RGB by default).
    n_classes : output channels; a Sigmoid head is used when n_classes == 1
                (BCELoss setup), otherwise raw logits are returned.
    """
    def __init__(self, n_channels=3, n_classes=9):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        width = 64  # stem width; doubled at each downsampling stage
        self.inc = ConvBatchNorm(n_channels, width)
        self.down1 = DownBlock(width, width*2, nb_Conv=2)
        self.down2 = DownBlock(width*2, width*4, nb_Conv=2)
        self.down3 = DownBlock(width*4, width*8, nb_Conv=2)
        self.down4 = DownBlock(width*8, width*8, nb_Conv=2)
        self.up4 = UpBlock(width*16, width*4, nb_Conv=2)
        self.up3 = UpBlock(width*8, width*2, nb_Conv=2)
        self.up2 = UpBlock(width*4, width, nb_Conv=2)
        self.up1 = UpBlock(width*2, width, nb_Conv=2)
        self.outc = nn.Conv2d(width, n_classes, kernel_size=(1,1))
        self.last_activation = nn.Sigmoid() if n_classes == 1 else None

    def forward(self, x):
        x = x.float()  # tolerate non-float inputs
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)
        out = self.up4(bottom, skip4)
        out = self.up3(out, skip3)
        out = self.up2(out, skip2)
        out = self.up1(out, skip1)
        head = self.outc(out)
        if self.last_activation is not None:
            return self.last_activation(head)
        return head
|
| 110 |
+
|
| 111 |
+
|
models/attunet.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/LeeJunHyun/Image_Segmentation
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.nn import init
|
| 7 |
+
|
| 8 |
+
def init_weights(net, init_type='normal', gain=0.02):
    """Initialise the conv/linear/batch-norm weights of ``net`` in place.

    Args:
        net: module whose sub-modules are initialised via ``net.apply``.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'``,
            ``'orthogonal'``.
        gain: std for ``'normal'``, gain for ``'xavier'``/``'orthogonal'``;
            unused by ``'kaiming'`` (matching the reference implementation).

    Raises:
        NotImplementedError: if ``init_type`` is not recognised.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # Fix: a BatchNorm2d created with affine=False has weight/bias set
            # to None; the original code crashed on ``None.data`` here.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|
| 30 |
+
|
| 31 |
+
class conv_block(nn.Module):
    """Two stacked 3x3 Conv-BN-ReLU stages (the standard U-Net double conv).

    Spatial size is preserved (stride 1, padding 1); channels go
    ``ch_in -> ch_out -> ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        stages = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages and return the result."""
        return self.conv(x)
|
| 47 |
+
|
| 48 |
+
class up_conv(nn.Module):
    """2x nearest-neighbour upsample followed by a 3x3 Conv-BN-ReLU.

    Doubles H and W and maps channels ``ch_in -> ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled and convolved feature map."""
        return self.up(x)
|
| 61 |
+
|
| 62 |
+
class Recurrent_block(nn.Module):
    """Recurrent conv unit from R2U-Net: one shared Conv-BN-ReLU applied repeatedly.

    The first loop step convolves the raw input once and then ``x + x1``; every
    later step convolves ``x + x1`` again with the same weights, so ``t``
    iterations perform ``t + 1`` convolutions in total (as in the reference code).
    """

    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the recurrent feature map; channels and spatial size are unchanged."""
        for step in range(self.t):
            if step == 0:
                out = self.conv(x)  # seed the recurrence with a plain conv
            out = self.conv(x + out)  # recurrent refinement with shared weights
        return out
|
| 81 |
+
|
| 82 |
+
class RRCNN_block(nn.Module):
    """Recurrent-residual block: 1x1 channel projection, two Recurrent_blocks,
    and a residual skip around the recurrent pair."""

    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Project to ``ch_out`` channels, then add the recurrent refinement."""
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class single_conv(nn.Module):
    """A single 3x3 Conv-BN-ReLU stage; spatial size is preserved."""

    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        stages = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv stage and return the result."""
        return self.conv(x)
|
| 109 |
+
|
| 110 |
+
class Attention_block(nn.Module):
    """Additive attention gate (Attention U-Net).

    Combines a gating signal ``g`` (decoder feature) with a skip feature ``x``
    (encoder feature) to produce a single-channel mask in [0, 1] that
    re-weights ``x`` spatially.

    Args:
        F_g: channels of the gating signal.
        F_l: channels of the skip feature.
        F_int: intermediate channel count for the additive attention.
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` re-weighted by the attention mask (same shape as ``x``)."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(combined)  # one channel, broadcast over x's channels
        return x * mask
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class U_Net(nn.Module):
    """Plain 5-level U-Net built from ``conv_block``/``up_conv``.

    Attribute names follow the reference implementation so that existing
    state-dict checkpoints keep loading.

    Args:
        img_ch: input image channels.
        output_ch: output (class) channels; raw logits are returned.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super(U_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channels double at every level.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder: upsample, concatenate the encoder skip, fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return raw logits with ``output_ch`` channels at input resolution."""
        # Encoder
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))

        # Decoder with skip connections (skip first in the concat, as upstream)
        d = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d = self.Up_conv4(torch.cat((e3, self.Up4(d)), dim=1))
        d = self.Up_conv3(torch.cat((e2, self.Up3(d)), dim=1))
        d = self.Up_conv2(torch.cat((e1, self.Up2(d)), dim=1))
        return self.Conv_1x1(d)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class R2U_Net(nn.Module):
    """Recurrent-Residual U-Net (R2U-Net): U-Net with ``RRCNN_block`` stages.

    Attribute names follow the reference implementation so checkpoints load.

    Args:
        img_ch: input image channels.
        output_ch: output (class) channels; raw logits are returned.
        t: recurrence steps inside each ``Recurrent_block``.
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2U_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # NOTE(review): unused in forward (up_conv upsamples internally);
        # kept for parity with the reference implementation.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return raw logits with ``output_ch`` channels at input resolution."""
        # Encoder
        e1 = self.RRCNN1(x)
        e2 = self.RRCNN2(self.Maxpool(e1))
        e3 = self.RRCNN3(self.Maxpool(e2))
        e4 = self.RRCNN4(self.Maxpool(e3))
        e5 = self.RRCNN5(self.Maxpool(e4))

        # Decoder with skip connections
        d = self.Up_RRCNN5(torch.cat((e4, self.Up5(e5)), dim=1))
        d = self.Up_RRCNN4(torch.cat((e3, self.Up4(d)), dim=1))
        d = self.Up_RRCNN3(torch.cat((e2, self.Up3(d)), dim=1))
        d = self.Up_RRCNN2(torch.cat((e1, self.Up2(d)), dim=1))
        return self.Conv_1x1(d)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class AttU_Net(nn.Module):
    """Attention U-Net: plain U-Net with an attention gate on every skip.

    Attribute names follow the reference implementation so checkpoints load.

    Args:
        img_ch: input image channels.
        output_ch: output (class) channels; raw logits are returned.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder: upsample, gate the skip with attention, concat, fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return raw logits with ``output_ch`` channels at input resolution."""
        # Encoder
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))

        # Decoder: the upsampled decoder feature gates the encoder skip.
        d = self.Up5(e5)
        d = self.Up_conv5(torch.cat((self.Att5(g=d, x=e4), d), dim=1))
        d = self.Up4(d)
        d = self.Up_conv4(torch.cat((self.Att4(g=d, x=e3), d), dim=1))
        d = self.Up3(d)
        d = self.Up_conv3(torch.cat((self.Att3(g=d, x=e2), d), dim=1))
        d = self.Up2(d)
        d = self.Up_conv2(torch.cat((self.Att2(g=d, x=e1), d), dim=1))
        return self.Conv_1x1(d)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class R2AttU_Net(nn.Module):
    """R2U-Net with attention gates: RRCNN stages plus gated skip connections.

    Attribute names follow the reference implementation so checkpoints load.

    Args:
        img_ch: input image channels.
        output_ch: output (class) channels; raw logits are returned.
        t: recurrence steps inside each ``Recurrent_block``.
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # NOTE(review): unused in forward (up_conv upsamples internally);
        # kept for parity with the reference implementation.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder: upsample, gate the skip with attention, concat, fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return raw logits with ``output_ch`` channels at input resolution."""
        # Encoder
        e1 = self.RRCNN1(x)
        e2 = self.RRCNN2(self.Maxpool(e1))
        e3 = self.RRCNN3(self.Maxpool(e2))
        e4 = self.RRCNN4(self.Maxpool(e3))
        e5 = self.RRCNN5(self.Maxpool(e4))

        # Decoder: the upsampled decoder feature gates the encoder skip.
        d = self.Up5(e5)
        d = self.Up_RRCNN5(torch.cat((self.Att5(g=d, x=e4), d), dim=1))
        d = self.Up4(d)
        d = self.Up_RRCNN4(torch.cat((self.Att4(g=d, x=e3), d), dim=1))
        d = self.Up3(d)
        d = self.Up_RRCNN3(torch.cat((self.Att3(g=d, x=e2), d), dim=1))
        d = self.Up2(d)
        d = self.Up_RRCNN2(torch.cat((self.Att2(g=d, x=e1), d), dim=1))
        return self.Conv_1x1(d)
|
models/multiresunet.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py
|
| 2 |
+
|
| 3 |
+
from typing import Tuple, Dict
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Multiresblock(nn.Module):
    """MultiRes block: three chained 3x3 convs whose outputs are concatenated
    and fused with a 1x1 shortcut (approximating 3x3/5x5/7x7 receptive fields).

    Args:
        input_features: input channel count.
        corresponding_unet_filters: U-Net filter count for the same stage.
        alpha: width factor (1.67 in the paper) used to derive branch widths.
    """

    def __init__(self, input_features: int, corresponding_unet_filters: int, alpha: float = 1.67) -> None:
        super().__init__()
        self.corresponding_unet_filters = corresponding_unet_filters
        self.alpha = alpha
        self.W = corresponding_unet_filters * alpha
        # Total channels produced by the three branches below.
        total = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
        self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features, num_of_filters=total,
                                              kernel_size=(1, 1), activation='None', padding=0)
        self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features, num_of_filters=int(self.W * 0.167),
                                              kernel_size=(3, 3), activation='relu', padding=1)
        self.conv2d_bn_5x5 = Conv2d_batchnorm(input_features=int(self.W * 0.167), num_of_filters=int(self.W * 0.333),
                                              kernel_size=(3, 3), activation='relu', padding=1)
        self.conv2d_bn_7x7 = Conv2d_batchnorm(input_features=int(self.W * 0.333), num_of_filters=int(self.W * 0.5),
                                              kernel_size=(3, 3), activation='relu', padding=1)
        self.batch_norm1 = nn.BatchNorm2d(int(self.W * 0.5) + int(self.W * 0.167) + int(self.W * 0.333), affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused multi-resolution feature map."""
        shortcut = self.conv2d_bn_1x1(x)
        branch3 = self.conv2d_bn_3x3(x)
        branch5 = self.conv2d_bn_5x5(branch3)
        branch7 = self.conv2d_bn_7x7(branch5)
        out = self.batch_norm1(torch.cat([branch3, branch5, branch7], axis=1))
        # NOTE(review): the same (non-affine) batch norm is deliberately applied
        # a second time after the shortcut add, matching the reference code.
        return self.batch_norm1(out + shortcut)
|
| 45 |
+
|
| 46 |
+
class Conv2d_batchnorm(nn.Module):
    """Conv2d followed by a non-affine BatchNorm2d, with optional ReLU.

    Args:
        input_features: input channel count.
        num_of_filters: output channel count.
        kernel_size: convolution kernel shape.
        stride: convolution stride.
        activation: ``'relu'`` applies F.relu after the norm; any other value
            means no activation.
        padding: convolution padding.
    """

    def __init__(self, input_features: int, num_of_filters: int, kernel_size: Tuple = (2, 2), stride: Tuple = (1, 1), activation: str = 'relu', padding: int = 0) -> None:
        super().__init__()
        self.activation = activation
        self.conv1 = nn.Conv2d(in_channels=input_features, out_channels=num_of_filters,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        self.batchnorm = nn.BatchNorm2d(num_of_filters, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve, normalise, then (optionally) apply ReLU."""
        out = self.batchnorm(self.conv1(x))
        return F.relu(out) if self.activation == 'relu' else out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Respath(nn.Module):
    """Residual path on a skip connection: a 3x3 conv with a 1x1 shortcut,
    optionally followed by extra shared-weight stages.

    Args:
        input_features: input channel count.
        filters: output channel count of every stage.
        respath_length: when > 1, that many additional stages run after the
            first; the "common" convs are re-used (shared weights) each time,
            matching the reference implementation.
    """

    def __init__(self, input_features: int, filters: int, respath_length: int) -> None:
        super().__init__()
        self.filters = filters
        self.respath_length = respath_length
        self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features, num_of_filters=filters,
                                              kernel_size=(1, 1), activation='None', padding=0)
        self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features, num_of_filters=filters,
                                              kernel_size=(3, 3), activation='relu', padding=1)
        self.conv2d_bn_1x1_common = Conv2d_batchnorm(input_features=filters, num_of_filters=filters,
                                                     kernel_size=(1, 1), activation='None', padding=0)
        self.conv2d_bn_3x3_common = Conv2d_batchnorm(input_features=filters, num_of_filters=filters,
                                                     kernel_size=(3, 3), activation='relu', padding=1)
        self.batch_norm1 = nn.BatchNorm2d(filters, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the residual-path feature map with ``filters`` channels."""
        out = self.conv2d_bn_3x3(x) + self.conv2d_bn_1x1(x)
        out = self.batch_norm1(F.relu(out))
        if self.respath_length > 1:
            for _ in range(self.respath_length):
                out = self.conv2d_bn_3x3_common(out) + self.conv2d_bn_1x1_common(out)
                out = self.batch_norm1(F.relu(out))
        return out
|
| 110 |
+
|
| 111 |
+
class MultiResUnet(nn.Module):
    """MultiResUNet: U-Net variant built from Multiresblocks and Respaths.

    Attribute names follow the reference implementation so checkpoints load.

    Args:
        channels: input image channels.
        filters: base U-Net filter count (doubled at every encoder level).
        nclasses: output classes; for ``nclasses == 1`` a sigmoid is applied,
            otherwise raw logits are returned.
    """

    @staticmethod
    def _block_out_channels(unet_filters, alpha):
        # Channels a Multiresblock emits for a given U-Net filter count:
        # the sum of its three conv branches (0.5 / 0.167 / 0.333 of W).
        w = unet_filters * alpha
        return int(w * 0.5) + int(w * 0.167) + int(w * 0.333)

    def __init__(self, channels: int, filters: int = 32, nclasses: int = 1) -> None:
        super().__init__()
        self.alpha = 1.67
        self.filters = filters
        self.nclasses = nclasses

        # --- Encoder ---
        self.multiresblock1 = Multiresblock(input_features=channels, corresponding_unet_filters=filters)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.in_filters1 = self._block_out_channels(filters, self.alpha)
        self.respath1 = Respath(input_features=self.in_filters1, filters=filters, respath_length=4)
        self.multiresblock2 = Multiresblock(input_features=self.in_filters1, corresponding_unet_filters=filters * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.in_filters2 = self._block_out_channels(filters * 2, self.alpha)
        self.respath2 = Respath(input_features=self.in_filters2, filters=filters * 2, respath_length=3)
        self.multiresblock3 = Multiresblock(input_features=self.in_filters2, corresponding_unet_filters=filters * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.in_filters3 = self._block_out_channels(filters * 4, self.alpha)
        self.respath3 = Respath(input_features=self.in_filters3, filters=filters * 4, respath_length=2)
        self.multiresblock4 = Multiresblock(input_features=self.in_filters3, corresponding_unet_filters=filters * 8)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.in_filters4 = self._block_out_channels(filters * 8, self.alpha)
        self.respath4 = Respath(input_features=self.in_filters4, filters=filters * 8, respath_length=1)
        self.multiresblock5 = Multiresblock(input_features=self.in_filters4, corresponding_unet_filters=filters * 16)
        self.in_filters5 = self._block_out_channels(filters * 16, self.alpha)

        # --- Decoder ---
        self.upsample6 = nn.ConvTranspose2d(in_channels=self.in_filters5, out_channels=filters * 8,
                                            kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.concat_filters1 = filters * 8 + filters * 8
        self.multiresblock6 = Multiresblock(input_features=self.concat_filters1, corresponding_unet_filters=filters * 8)
        self.in_filters6 = self._block_out_channels(filters * 8, self.alpha)
        self.upsample7 = nn.ConvTranspose2d(in_channels=self.in_filters6, out_channels=filters * 4,
                                            kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.concat_filters2 = filters * 4 + filters * 4
        self.multiresblock7 = Multiresblock(input_features=self.concat_filters2, corresponding_unet_filters=filters * 4)
        self.in_filters7 = self._block_out_channels(filters * 4, self.alpha)
        self.upsample8 = nn.ConvTranspose2d(in_channels=self.in_filters7, out_channels=filters * 2,
                                            kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.concat_filters3 = filters * 2 + filters * 2
        self.multiresblock8 = Multiresblock(input_features=self.concat_filters3, corresponding_unet_filters=filters * 2)
        self.in_filters8 = self._block_out_channels(filters * 2, self.alpha)
        self.upsample9 = nn.ConvTranspose2d(in_channels=self.in_filters8, out_channels=filters,
                                            kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.concat_filters4 = filters + filters
        self.multiresblock9 = Multiresblock(input_features=self.concat_filters4, corresponding_unet_filters=filters)
        self.in_filters9 = self._block_out_channels(filters, self.alpha)
        self.conv_final = Conv2d_batchnorm(input_features=self.in_filters9, num_of_filters=self.nclasses,
                                           kernel_size=(1, 1), activation='None')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Segment ``x``; sigmoid output for one class, raw logits otherwise."""
        # Encoder: multires block, then pool; skips go through Respaths.
        m1 = self.multiresblock1(x)
        skip1 = self.respath1(m1)
        m2 = self.multiresblock2(self.pool1(m1))
        skip2 = self.respath2(m2)
        m3 = self.multiresblock3(self.pool2(m2))
        skip3 = self.respath3(m3)
        m4 = self.multiresblock4(self.pool3(m3))
        skip4 = self.respath4(m4)
        m5 = self.multiresblock5(self.pool4(m4))

        # Decoder: transpose-conv upsample, concat the matching skip, fuse.
        d6 = self.multiresblock6(torch.cat([self.upsample6(m5), skip4], axis=1))
        d7 = self.multiresblock7(torch.cat([self.upsample7(d6), skip3], axis=1))
        d8 = self.multiresblock8(torch.cat([self.upsample8(d7), skip2], axis=1))
        d9 = self.multiresblock9(torch.cat([self.upsample9(d8), skip1], axis=1))

        logits = self.conv_final(d9)
        return logits if self.nclasses > 1 else torch.sigmoid(logits)
|
models/unet.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DoubleConv(nn.Module):
    """Two 3x3 conv + ReLU stages, optionally with BatchNorm after each conv."""

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        if with_bn:
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        else:
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
            ]
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages; spatial size is preserved."""
        return self.step(x)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class UNet(nn.Module):
    """Plain 4-level U-Net with bilinear upsampling and a 1x1 output head.

    Returns raw logits (no sigmoid/softmax); callers apply their own activation.

    Args:
        in_channels: input image channels.
        out_channels: output (class) channels.
        with_bn: whether ``DoubleConv`` blocks include batch norm.
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        init_channels = 32
        self.out_channels = out_channels

        # Encoder: channel count doubles at each level.
        self.en_1 = DoubleConv(in_channels, init_channels, with_bn)
        self.en_2 = DoubleConv(1 * init_channels, 2 * init_channels, with_bn)
        self.en_3 = DoubleConv(2 * init_channels, 4 * init_channels, with_bn)
        self.en_4 = DoubleConv(4 * init_channels, 8 * init_channels, with_bn)

        # Decoder: each stage consumes upsampled features plus the skip.
        self.de_1 = DoubleConv((4 + 8) * init_channels, 4 * init_channels, with_bn)
        self.de_2 = DoubleConv((2 + 4) * init_channels, 2 * init_channels, with_bn)
        self.de_3 = DoubleConv((1 + 2) * init_channels, 1 * init_channels, with_bn)
        self.de_4 = nn.Conv2d(init_channels, out_channels, 1)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        """Return raw logits with ``out_channels`` channels at input resolution."""
        e1 = self.en_1(x)
        e2 = self.en_2(self.maxpool(e1))
        e3 = self.en_3(self.maxpool(e2))
        e4 = self.en_4(self.maxpool(e3))

        d = self.de_1(torch.cat([self.upsample(e4), e3], dim=1))
        d = self.de_2(torch.cat([self.upsample(d), e2], dim=1))
        d = self.de_3(torch.cat([self.upsample(d), e1], dim=1))
        return self.de_4(d)
|
models/unetpp.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py (unetpp)
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.functional import softmax, sigmoid
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ['UNet', 'NestedUNet']
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VGGBlock(nn.Module):
    """Conv-BN-ReLU x2 block used at every node of (Nested) U-Net."""

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the two conv stages; spatial size is preserved."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class UNet(nn.Module):
    """Baseline U-Net (from the nested-unet reference repo) built from VGGBlocks.

    Attribute names follow the reference implementation so checkpoints load.

    Args:
        num_classes: output (class) channels; raw logits are returned.
        input_channels: input image channels.
    """

    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (x{i}_0 in U-Net++ notation).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder path: each node fuses the skip with the upsampled feature.
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        """Return raw logits with ``num_classes`` channels at input resolution."""
        # Encoder
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # Decoder along the outermost path
        up = self.up
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, up(x1_3)], 1))

        return self.final(x0_4)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class NestedUNet(nn.Module):
    """UNet++ (nested U-Net) with dense skip pathways.

    Node (i, j) lives at encoder depth i and skip-pathway column j.  Each
    intermediate node fuses every same-depth predecessor (j' < j) with the
    upsampled node from one level below, densely re-using features along
    each skip connection.

    With ``deep_supervision=True`` the forward pass returns a list of four
    logit maps (one per nested decoder head); otherwise it returns the
    single full-depth output.  Input H and W should be divisible by 16.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone encoder column (j = 0). Attribute names are kept identical
        # to the original so saved state_dicts still load.
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested columns: node (i, j) takes j copies of nb_filter[i] channels
        # (the dense same-depth skips) plus nb_filter[i + 1] upsampled channels.
        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0])

        # One 1x1 head per supervised column, or a single head at full depth.
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        """Return logits (B, num_classes, H, W), or a list of four such maps
        when deep supervision is enabled."""

        def fuse(block, skips, below):
            # Concatenate same-depth skip features with the upsampled feature
            # from one level deeper, then convolve.
            return block(torch.cat(skips + [self.up(below)], dim=1))

        # x[i, j] is the feature map of grid node (depth i, column j).
        x = {(0, 0): self.conv0_0(input)}
        x[1, 0] = self.conv1_0(self.pool(x[0, 0]))
        x[0, 1] = fuse(self.conv0_1, [x[0, 0]], x[1, 0])

        x[2, 0] = self.conv2_0(self.pool(x[1, 0]))
        x[1, 1] = fuse(self.conv1_1, [x[1, 0]], x[2, 0])
        x[0, 2] = fuse(self.conv0_2, [x[0, 0], x[0, 1]], x[1, 1])

        x[3, 0] = self.conv3_0(self.pool(x[2, 0]))
        x[2, 1] = fuse(self.conv2_1, [x[2, 0]], x[3, 0])
        x[1, 2] = fuse(self.conv1_2, [x[1, 0], x[1, 1]], x[2, 1])
        x[0, 3] = fuse(self.conv0_3, [x[0, 0], x[0, 1], x[0, 2]], x[1, 2])

        x[4, 0] = self.conv4_0(self.pool(x[3, 0]))
        x[3, 1] = fuse(self.conv3_1, [x[3, 0]], x[4, 0])
        x[2, 2] = fuse(self.conv2_2, [x[2, 0], x[2, 1]], x[3, 1])
        x[1, 3] = fuse(self.conv1_3, [x[1, 0], x[1, 1], x[1, 2]], x[2, 2])
        x[0, 4] = fuse(self.conv0_4, [x[0, 0], x[0, 1], x[0, 2], x[0, 3]], x[1, 3])

        if self.deep_supervision:
            return [
                self.final1(x[0, 1]),
                self.final2(x[0, 2]),
                self.final3(x[0, 3]),
                self.final4(x[0, 4]),
            ]
        return self.final(x[0, 4])
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
numpy
|
| 5 |
+
Pillow
|
| 6 |
+
pyyaml
|
saved_models/isic2018_unet/best_model_state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8559876e4b85070ffbe84748e02926a9e8690b09bf101a12f7e5c5e590decbf0
|
| 3 |
+
size 7799041
|
saved_models/segpc2021_unet/best_model_state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a182b4f7a415d056ae7e5293aed483804494580eb3f1a3b27d04e77c55468e76
|
| 3 |
+
size 7800193
|