Spaces · Sleeping
Yao Zhang committed · commit 59ecb50
Parent(s): 45ecef6
init
Browse files:
- README.md +3 -3
- __init__.py +0 -0
- app.py +144 -0
- best_Dice_model.pth +3 -0
- packages.txt +0 -0
- requirements.txt +5 -0
- unetr2d.py +202 -0
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: NeurIPS CellSeg
-emoji:
-colorFrom:
-colorTo:
+emoji: 🔥
+colorFrom: yellow
+colorTo: red
 sdk: gradio
 sdk_version: 3.4.1
 app_file: app.py
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Author: Yao
+# Mail: zhangyao215@mails.ucas.ac.cn
+
+import gradio as gr
+
+import os
+join = os.path.join
+import time
+import numpy as np
+# from skimage.filters import threshold_otsu
+# from skimage.measure import label
+import torch
+import monai
+from monai.inferers import sliding_window_inference
+from unetr2d import UNETR2D
+import random  # fixed: was a duplicate "import time"; random is needed by visualize_instance_seg_mask
+from skimage import io, segmentation, morphology, measure, exposure
+
+
+def visualize_instance_seg_mask(mask):
+    image = np.zeros((mask.shape[0], mask.shape[1], 3))
+    labels = np.unique(mask)
+    label2color = {label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for label in labels}  # fixed: red channel was randint(0, 1)
+    for i in range(image.shape[0]):
+        for j in range(image.shape[1]):
+            image[i, j, :] = label2color[mask[i, j]]
+    image = image / 255
+    return image
+
+
+def load_model(model_name, custom_model_path):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    if model_name == 'unet':
+        model = monai.networks.nets.UNet(
+            spatial_dims=2,
+            in_channels=3,
+            out_channels=3,
+            channels=(16, 32, 64, 128, 256),
+            strides=(2, 2, 2, 2),
+            num_res_units=2,
+        )
+
+    elif model_name == 'unetr':
+        model = UNETR2D(
+            in_channels=3,
+            out_channels=3,
+            img_size=(256, 256),
+            feature_size=16,
+            hidden_size=768,
+            mlp_dim=3072,
+            num_heads=12,
+            pos_embed="perceptron",
+            norm_name="instance",
+            res_block=True,
+            dropout_rate=0.0,
+        )
+    elif model_name == 'swinunetr':
+        model = monai.networks.nets.SwinUNETR(
+            img_size=(256, 256),
+            in_channels=3,
+            out_channels=3,
+            feature_size=24,  # should be divisible by 12
+            spatial_dims=2
+        )
+    if os.path.isfile(custom_model_path):
+        checkpoint = torch.load(custom_model_path, map_location=torch.device(device))  # fixed: custom_model_path is a str and has no .resolve()
+    elif os.path.isfile(join(os.path.dirname(__file__), 'best_Dice_model.pth')):
+        checkpoint = torch.load(join(os.path.dirname(__file__), 'best_Dice_model.pth'), map_location=torch.device(device))
+    else:
+        torch.hub.download_url_to_file('https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1', join(os.path.dirname(__file__), 'best_Dice_model.pth'))  # fixed: download to the same path that is loaded below (was work_dir/swinunetr/best_Dice_model.pth)
+        checkpoint = torch.load(join(os.path.dirname(__file__), 'best_Dice_model.pth'), map_location=torch.device(device))
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    model = model.to(device)
+    model.eval()
+
+    return model
+
+
+def normalize_channel(img, lower=1, upper=99):
+    non_zero_vals = img[np.nonzero(img)]
+    percentiles = np.percentile(non_zero_vals, [lower, upper])
+    if percentiles[1] - percentiles[0] > 0.001:
+        img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
+    else:
+        img_norm = img
+    return img_norm.astype(np.uint8)
+
+
+def preprocess(img_data):
+    if len(img_data.shape) == 2:
+        img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
+    elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
+        img_data = img_data[:, :, :3]
+    else:
+        pass
+    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
+    for i in range(3):
+        img_channel_i = img_data[:, :, i]
+        if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
+            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
+    return pre_img_data
+
+
+def get_seg(pre_img_data, model_name, custom_model_path, threshold):
+    model = load_model(model_name, custom_model_path)
+    #%%
+    roi_size = (256, 256)
+    sw_batch_size = 4
+    with torch.no_grad():
+        t0 = time.time()
+        test_npy01 = pre_img_data / np.max(pre_img_data)
+        # test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
+        test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor)
+        test_pred_out = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
+        test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1)  # (B, C, H, W)
+        test_pred_npy = test_pred_out[0, 1].cpu().numpy()
+        # convert probability map to binary mask and apply morphological postprocessing
+        test_pred_mask = measure.label(morphology.remove_small_objects(morphology.remove_small_holes(test_pred_npy > threshold), 16))
+        # tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), test_pred_mask, compression='zlib')
+        t1 = time.time()
+        # print(f'Prediction finished: {img_layer.name}; img size = {pre_img_data.shape}; costing: {t1-t0:.2f}s')
+    return test_pred_mask
+
+
+def predict(img):
+    seg_labels = get_seg(preprocess(img), 'swinunetr', './best_Dice_model.pth', 0.5)
+    return visualize_instance_seg_mask(seg_labels)  # fixed: render the labeled mask as RGB for display (was "return seg_labels", which shows near-black)
+
+
+demo = gr.Interface(
+    predict,
+    inputs=[gr.Image()],
+    outputs="image",
+    title="NeurIPS CellSeg Demo",
+    examples=[["cell_00225.png"]]
+)
+
+demo.launch()
+
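Note: `predict` now returns the float RGB array produced by `visualize_instance_seg_mask` (values in [0, 1]), which `gr.Interface` renders directly. The same colorization can be done in one vectorized call with scikit-image, which is already listed in requirements.txt; a minimal sketch (the helper name `colorize_mask` is illustrative, not part of the commit):

    import numpy as np
    from skimage.color import label2rgb

    def colorize_mask(mask: np.ndarray) -> np.ndarray:
        # Color each instance label distinctly; label 0 (background) stays black.
        # Returns a float RGB image in [0, 1], the same convention as visualize_instance_seg_mask.
        return label2rgb(mask, bg_label=0)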
best_Dice_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:764db0c53184da5bf743db84d8837b1f9b2c7f3b0236b43c75ff747e47c75e5a
+size 75949863
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,5 @@
+numpy
+scikit-image
+torch
+monai
+einops
unetr2d.py ADDED
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Mar 20 14:23:19 2022
+Author: MONAI
+"""
+
+from typing import Tuple, Union
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.dynunet_block import UnetOutBlock
+from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
+from monai.networks.nets import ViT
+
+class UNETR2D(nn.Module):
+    """
+    UNETR based on: "Hatamizadeh et al.,
+    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        img_size: Tuple[int, int],
+        feature_size: int = 16,
+        hidden_size: int = 768,
+        mlp_dim: int = 3072,
+        num_heads: int = 12,
+        pos_embed: str = "perceptron",
+        norm_name: Union[Tuple, str] = "instance",
+        conv_block: bool = False,
+        res_block: bool = True,
+        dropout_rate: float = 0.0,
+        debug: bool = False
+    ) -> None:
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise AssertionError("dropout_rate should be between 0 and 1.")
+
+        if hidden_size % num_heads != 0:
+            raise AssertionError("hidden size should be divisible by num_heads.")
+
+        if pos_embed not in ["conv", "perceptron"]:
+            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
+
+        self.num_layers = 12
+        self.patch_size = (16, 16)
+        self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))
+        self.hidden_size = hidden_size
+        self.classification = False
+        self.debug = debug
+        self.vit = ViT(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=self.patch_size,
+            hidden_size=hidden_size,
+            mlp_dim=mlp_dim,
+            num_layers=self.num_layers,
+            num_heads=num_heads,
+            pos_embed=pos_embed,
+            classification=self.classification,
+            dropout_rate=dropout_rate,
+            spatial_dims=2
+        )
+        self.encoder1 = UnetrBasicBlock(
+            spatial_dims=2,
+            in_channels=in_channels,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name=norm_name,
+            res_block=res_block,
+        )
+        self.encoder2 = UnetrPrUpBlock(
+            spatial_dims=2,
+            in_channels=hidden_size,
+            out_channels=feature_size * 2,
+            num_layer=2,
+            kernel_size=3,
+            stride=1,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            conv_block=conv_block,
+            res_block=res_block,
+        )
+        self.encoder3 = UnetrPrUpBlock(
+            spatial_dims=2,
+            in_channels=hidden_size,
+            out_channels=feature_size * 4,
+            num_layer=1,
+            kernel_size=3,
+            stride=1,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            conv_block=conv_block,
+            res_block=res_block,
+        )
+        self.encoder4 = UnetrPrUpBlock(
+            spatial_dims=2,
+            in_channels=hidden_size,
+            out_channels=feature_size * 8,
+            num_layer=0,
+            kernel_size=3,
+            stride=1,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            conv_block=conv_block,
+            res_block=res_block,
+        )
+        self.decoder5 = UnetrUpBlock(
+            spatial_dims=2,
+            in_channels=hidden_size,
+            out_channels=feature_size * 8,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            res_block=res_block,
+        )
+        self.decoder4 = UnetrUpBlock(
+            spatial_dims=2,
+            in_channels=feature_size * 8,
+            out_channels=feature_size * 4,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            res_block=res_block,
+        )
+        self.decoder3 = UnetrUpBlock(
+            spatial_dims=2,
+            in_channels=feature_size * 4,
+            out_channels=feature_size * 2,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            res_block=res_block,
+        )
+        self.decoder2 = UnetrUpBlock(
+            spatial_dims=2,
+            in_channels=feature_size * 2,
+            out_channels=feature_size,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name=norm_name,
+            res_block=res_block,
+        )
+        self.out = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels)  # type: ignore
+
+    def proj_feat(self, x, hidden_size, feat_size):  # x: (B, 256, 768)
+        x = x.view(x.size(0), feat_size[0], feat_size[1], hidden_size)  # (B, 16, 16, 768)
+        x = x.permute(0, 3, 1, 2).contiguous()  # (B, 768, 16, 16)
+        return x
+
+    def forward(self, x_in):
+        x, hidden_states_out = self.vit(x_in)  # x: (B, 256, 768); hidden_states_out: list of 12 tensors, each (B, 256, 768)
+        enc1 = self.encoder1(x_in)  # (B, 16, 256, 256)
+        x2 = hidden_states_out[3]  # (B, 256, 768)
+        # self.proj_feat(x2, self.hidden_size, self.feat_size): (B, 768, 16, 16) -> enc2: (B, 32, 128, 128)
+        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))  # hidden_size=768, self.feat_size=(16, 16)
+        x3 = hidden_states_out[6]  # (B, 256, 768)
+        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))  # (B, 768, 16, 16) -> (B, 64, 64, 64)
+        x4 = hidden_states_out[9]  # (B, 256, 768)
+        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))  # (B, 768, 16, 16) -> (B, 128, 32, 32)
+        dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)  # (B, 768, 16, 16)
+        dec3 = self.decoder5(dec4, enc4)  # up -> cat -> ResConv; (B, 128, 32, 32)
+        dec2 = self.decoder4(dec3, enc3)  # (B, 64, 64, 64)
+        dec1 = self.decoder3(dec2, enc2)  # (B, 32, 128, 128)
+        out = self.decoder2(dec1, enc1)  # (B, 16, 256, 256)
+        logits = self.out(out)
+
+        if self.debug:
+            return x, x2, x3, x4, hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
+        else:
+            return logits
+
+
+# model = UNETR2D(
+#     in_channels=3,  # 3 channels: R, G, B
+#     out_channels=3,
+#     img_size=(256, 256),
+#     feature_size=16,
+#     hidden_size=768,
+#     mlp_dim=3072,
+#     num_heads=12,
+#     pos_embed="perceptron",
+#     norm_name="instance",
+#     res_block=True,
+#     dropout_rate=0.0,
+#     debug=True
+# ).cuda()
+
+# from torchinfo import summary
+
+# batch_size = 1
+# summary(model, input_size=(batch_size, 3, 256, 256))
+
+# x = torch.rand((1, 3, 256, 256)).cuda()
+# x, x2, x3, x4, hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits = model(x)
+# print(logits.shape)  # torch.Size([1, 3, 256, 256])