Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- README.md +202 -3
- dist.py +112 -0
- img/TOKI.png +3 -0
- img/masking_consistency.png +3 -0
- main.py +192 -0
- mambamim_mask75.pth +3 -0
- models/MambaMIM.py +240 -0
- models/__init__.py +51 -0
- models/decoder.py +87 -0
- models/encoder.py +258 -0
- models/mamba/bi_vision_mamba.py +416 -0
- models/network/hymamba.py +681 -0
- utils/arg_util.py +125 -0
- utils/lamb.py +151 -0
- utils/lr_control.py +47 -0
- utils/med_dataset.py +275 -0
- utils/misc.py +338 -0
- utils/sampler.py +68 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
img/TOKI.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
img/masking_consistency.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,202 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
## [MIA'25] MambaMIM: Pre-training Mamba with State Space Token Interpolation and its Application to Medical Image Segmentation
|
| 4 |
+
|
| 5 |
+
<p align="center" width="100%">
|
| 6 |
+
<!---->
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
<div align="center">
|
| 14 |
+
<span class="author-block">
|
| 15 |
+
<a href="https://scholar.google.com/citations?user=x1pODsMAAAAJ&hl=en" target="_blank">Fenghe Tang</a><sup>1,2</sup>,</span>
|
| 16 |
+
<span class="author-block">
|
| 17 |
+
<a target="_blank">Bingkun Nian</a><sup>3</sup>,</span>
|
| 18 |
+
<span class="author-block">
|
| 19 |
+
<a href="https://scholar.google.com/citations?user=ocAtNkkAAAAJ&hl=en" target="_blank">Yingtai Li</a><sup>1,2</sup>,</span>
|
| 20 |
+
<span class="author-block">
|
| 21 |
+
<a href="https://scholar.google.com/citations?user=Wo8tMSMAAAAJ&hl=en" target="_blank">Zihang Jiang</a><sup>1,2</sup>,</span>
|
| 22 |
+
<span class="author-block">
|
| 23 |
+
<a href="https://scholar.google.com/citations?user=tmx7tu8AAAAJ&hl=en" target="_blank">Jie Yang</a><sup>3</sup>,</span>
|
| 24 |
+
<span class="author-block">
|
| 25 |
+
<a href="https://scholar.google.com/citations?user=Vbb5EGIAAAAJ&hl=en" target="_blank"> Liu Wei</a><sup>3</sup>,</span>
|
| 26 |
+
<span class="author-block">
|
| 27 |
+
<a href="https://scholar.google.com/citations?user=8eNm2GMAAAAJ&hl=en" target="_blank">S.Kevin Zhou</a><sup>1,2</sup>
|
| 28 |
+
</span>
|
| 29 |
+
</div>
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
<br>
|
| 33 |
+
|
| 34 |
+
<div align="center">
|
| 35 |
+
<sup>1</sup>
|
| 36 |
+
<a href='https://en.ustc.edu.cn/' target='_blank'>School of Biomedical Engineering, University of Science and Technology of China</a> 
|
| 37 |
+
<br>
|
| 38 |
+
<sup>2</sup> <a href='http://english.ict.cas.cn/' target='_blank'>Suzhou Institute for Advanced Research, University of Science and Technology of China</a> 
|
| 39 |
+
<br>
|
| 40 |
+
<sup>3</sup> <a href='http://www.pami.sjtu.edu.cn/En/Home' target='_blank'>Department of Automation, Institute of Image Processing and Pattern Recognition, Shanghai Jiao Tong University</a>
|
| 41 |
+
<br>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
<br>
|
| 45 |
+
<br>
|
| 46 |
+
|
| 47 |
+
[](https://arxiv.org/pdf/2408.08070.pdf) [](https://github.com/FengheTan9/MambaMIM) <a href="#LICENSE--citation"><img alt="License: Apache2.0" src="https://img.shields.io/badge/LICENSE-Apache%202.0-blue.svg"/></a>
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
## News
|
| 52 |
+
|
| 53 |
+
- **MambaMIM accepted by Medical Image Analysis (MIA'25) ! 🥰**
|
| 54 |
+
- **Weights released ! 😎**
|
| 55 |
+
- **Code released !** 😘
|
| 56 |
+
- **Code and weights will be released soon !** 😘
|
| 57 |
+
- **[2024/08/16] Paper released !**
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
## TODOs
|
| 62 |
+
|
| 63 |
+
- [x] Paper released
|
| 64 |
+
- [x] Code released
|
| 65 |
+
- [x] Weight released
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## Getting Started
|
| 70 |
+
|
| 71 |
+
### Download weights
|
| 72 |
+
|
| 73 |
+
| Name | Resolution | Intensities | Spacing | Weights |
|
| 74 |
+
| :------: | :----------: | :-----------: | :----------------: | :----------------------------------------------------------: |
|
| 75 |
+
| MambaMIM | 96 x 96 x 96 | [-175, 250] | 1.5 x 1.5 x 1.5 mm | [Google Drive (87MB)](https://drive.google.com/file/d/1B3j5aRPxkDJqf8UPGKDiAjg2X85a3Kwx/view?usp=sharing) |
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
### Prepare Environments
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
conda create -n mambamim python=3.9
|
| 83 |
+
conda activate mambamim
|
| 84 |
+
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
|
| 85 |
+
pip install packaging timm==0.5.4
|
| 86 |
+
pip install transformers==4.34.1 typed-argument-parser
|
| 87 |
+
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64
|
| 88 |
+
pip install 'monai[all]'
|
| 89 |
+
pip install monai==1.2.0
|
| 90 |
+
pip install causal_conv1d-1.2.0.post2+cu118torch1.13cxx11abiTRUE-cp38-cp38-linux_x86_64.whl
|
| 91 |
+
pip install mamba_ssm-1.2.0.post1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
### Prepare Datasets
|
| 97 |
+
|
| 98 |
+
We recommend that you convert the dataset into the [nnUNet](https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md) format.
|
| 99 |
+
|
| 100 |
+
```
|
| 101 |
+
└── MambaMIM
|
| 102 |
+
├── data
|
| 103 |
+
├── Dataset060_TotalSegmentator
|
| 104 |
+
└── imagesTr
|
| 105 |
+
├── xxx_0000.nii.gz
|
| 106 |
+
├── ...
|
| 107 |
+
├── Dataset006_FLARE2022
|
| 108 |
+
└── imagesTr
|
| 109 |
+
├── xxx_0000.nii.gz
|
| 110 |
+
├── ...
|
| 111 |
+
└── Other_dataset
|
| 112 |
+
└── imagesTr
|
| 113 |
+
├── xxx_0000.nii.gz
|
| 114 |
+
├── ...
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
An example ```dataset.json``` will be generated in ```./data```
|
| 118 |
+
|
| 119 |
+
The content should be like below:
|
| 120 |
+
|
| 121 |
+
```json
|
| 122 |
+
{
|
| 123 |
+
"training": [
|
| 124 |
+
{
|
| 125 |
+
"image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
|
| 129 |
+
},
|
| 130 |
+
]
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
## Start Training
|
| 138 |
+
|
| 139 |
+

|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
Run training on multi-GPU :
|
| 144 |
+
|
| 145 |
+
```sh
|
| 146 |
+
# An example of training on 4 GPUs with DDP
|
| 147 |
+
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data --model=mambamim --bs=16 --exp_dir=debug_mambamim_ddp_4
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Run training on the single-GPU :
|
| 151 |
+
|
| 152 |
+
```sh
|
| 153 |
+
# An example of training on the single GPU
|
| 154 |
+
python main.py --exp_name=debug --data_path=./data --model=mambamim --bs=4 --exp_dir=debug_mambamim
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
## Fine-tuning
|
| 160 |
+
|
| 161 |
+
Load pre-training weights :
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
# An example of Fine-tuning on BTCV (num_classes=14)
|
| 165 |
+
from models.network.hymamba import build_hybird
|
| 166 |
+
|
| 167 |
+
model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()
|
| 168 |
+
|
| 169 |
+
model_dict = torch.load("mambamim_mask75.pth")
|
| 170 |
+
|
| 171 |
+
if model.load_state_dict(model_dict, strict=False):
|
| 172 |
+
print("MambaMIM use pretrained weights successfully !")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
Downstream pipeline can be referred to [UNETR]([research-contributions/UNETR/BTCV at main · Project-MONAI/research-contributions (github.com)](https://github.com/Project-MONAI/research-contributions/tree/main/UNETR/BTCV)).
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
## Acknowledgements:
|
| 180 |
+
|
| 181 |
+
This code uses helper functions from [SparK](https://github.com/keyu-tian/SparK) and [HySparK](https://github.com/FengheTan9/HySparK).
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
## Citation
|
| 186 |
+
|
| 187 |
+
If the code, paper and weights help your research, please cite:
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
@article{tang2024mambamim,
|
| 191 |
+
title={MambaMIM: Pre-training Mamba with State Space Token-interpolation},
|
| 192 |
+
author={Tang, Fenghe and Nian, Bingkun and Li, Yingtai and Yang, Jie and Wei, Liu and Zhou, S Kevin},
|
| 193 |
+
journal={arXiv preprint arXiv:2408.08070},
|
| 194 |
+
year={2024}
|
| 195 |
+
}
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
## License
|
| 201 |
+
|
| 202 |
+
This project is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
|
dist.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as tdist
|
| 8 |
+
import torch.multiprocessing as mp
|
| 9 |
+
|
| 10 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
|
| 11 |
+
__initialized = False
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def initialized():
    """Return True once initialize() has successfully set up torch.distributed."""
    return __initialized
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def initialize(backend='nccl'):
    """Initialize torch.distributed from torchrun-style environment variables.

    Degrades gracefully instead of raising:
      * no CUDA available       -> stay on CPU, distributed left uninitialized;
      * ``RANK`` env var unset  -> single-GPU mode, distributed left uninitialized.
    On success, fills the module-level rank / local-rank / world-size / device
    globals and flips ``__initialized`` to True.

    Args:
        backend: process-group backend passed to ``init_process_group``
            (default ``'nccl'``).
    """
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        # Allocating a tiny tensor forces CUDA context creation and yields
        # the current default device.
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
        return

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    # Map the global rank onto this node's GPUs before creating the group.
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_rank():
    """Global (cross-node) rank of this process; 0 before initialize()."""
    return __rank
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_local_rank():
    """Rank of this process within its node (GPU index); 0 before initialize()."""
    return __local_rank
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_world_size():
    """Total number of distributed processes; 1 before initialize()."""
    return __world_size
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_device():
    """Device assigned to this process ('cpu' until initialize() picks a GPU)."""
    return __device
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def is_master():
    """True on the global rank-0 process (also True when uninitialized)."""
    return __rank == 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def is_local_master():
    """True on each node's rank-0 process (also True when uninitialized)."""
    return __local_rank == 0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def barrier():
    """Block until all ranks arrive; silent no-op when dist is uninitialized."""
    if __initialized:
        tdist.barrier()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def parallelize(net, syncbn=False):
    """Move *net* to this process's GPU and wrap it in DistributedDataParallel.

    Args:
        net: the model to parallelize (moved to CUDA here; requires a GPU).
        syncbn: if True, convert BatchNorm layers to SyncBatchNorm first.

    Returns:
        The DDP-wrapped model pinned to this process's local rank.
    """
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    # broadcast_buffers=False: each rank keeps its own (e.g. BN running) buffers.
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def allreduce(t: torch.Tensor) -> None:
    """Sum *t* in place across all ranks; no-op when dist is uninitialized.

    CPU tensors are staged through the GPU (NCCL collectives require CUDA
    tensors) and the reduced values are copied back into *t*.
    """
    if not __initialized:
        return
    if t.is_cuda:
        tdist.all_reduce(t)
    else:
        staging = t.detach().cuda()
        tdist.all_reduce(staging)
        t.copy_(staging.cpu())
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather *t* from every rank.

    Args:
        t: tensor to gather (moved to CUDA when distributed is active).
        cat: when True, concatenate the gathered tensors along dim 0.

    Returns:
        A single concatenated tensor when ``cat`` is True, otherwise the list
        of per-rank tensors. When uninitialized, behaves as a 1-rank gather.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        gathered = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(gathered, t)
    else:
        gathered = [t]
    return torch.cat(gathered, dim=0) if cat else gathered
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast *t* in place from *src_rank* to all ranks; no-op when
    dist is uninitialized.

    CPU tensors are staged through the GPU (NCCL collectives require CUDA
    tensors) and the received values are copied back into *t*.
    """
    if not __initialized:
        return
    if t.is_cuda:
        tdist.broadcast(t, src=src_rank)
    else:
        staging = t.detach().cuda()
        tdist.broadcast(staging, src=src_rank)
        t.copy_(staging.cpu())
|
img/TOKI.png
ADDED
|
Git LFS Details
|
img/masking_consistency.png
ADDED
|
Git LFS Details
|
main.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
import dist
|
| 12 |
+
from models.encoder import SparseEncoder
|
| 13 |
+
from models.decoder import LightDecoder
|
| 14 |
+
from models.MambaMIM import MambaMIM
|
| 15 |
+
from models import build_sparse_encoder
|
| 16 |
+
from utils.sampler import DistInfiniteBatchSampler, worker_init_fn
|
| 17 |
+
from utils import arg_util, misc
|
| 18 |
+
from utils.med_dataset import get_loader
|
| 19 |
+
from utils.lr_control import lr_wd_annealing
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
cpu_num = 1
# Pin every math backend to a single thread so dataloader worker processes
# do not oversubscribe the CPU during pre-training.
for _thread_var in ('OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS',
                    'VECLIB_MAXIMUM_THREADS', 'NUMEXPR_NUM_THREADS'):
    os.environ[_thread_var] = str(cpu_num)
torch.set_num_threads(cpu_num)
# file_system strategy avoids "too many open files" with many workers.
torch.multiprocessing.set_sharing_strategy('file_system')
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LocalDDP(torch.nn.Module):
    """Transparent single-process stand-in for DistributedDataParallel.

    Exposes the same ``.module`` attribute as DDP, so downstream code can
    unwrap the underlying model uniformly whether or not DDP is in use.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Pure pass-through: delegate everything to the wrapped module.
        return self.module(*args, **kwargs)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main_pt():
    """Entry point for MambaMIM pre-training.

    Parses args (initializing torch.distributed as a side effect), builds the
    infinite-sampler dataloader, the sparse encoder + light decoder wrapped in
    a MambaMIM model, then runs the epoch loop with checkpointing, resuming,
    and tensorboard logging.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = get_loader(args.data_path, args.input_size)
    # DistInfiniteBatchSampler shards and (optionally) refills batches per rank,
    # so no DistributedSampler is needed here.
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )

    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # build encoder and decoder
    enc: SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = MambaMIM(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')

    # the model has been randomly initialized in their construction time
    # now try to load some checkpoint as model weight initialization; this ONLY loads the model weights

    # NOTE(review): LocalDDP is a pass-through wrapper; true DDP wrapping is
    # presumably handled elsewhere (or gradients are synchronized implicitly) —
    # confirm before relying on multi-GPU gradient averaging.
    model = LocalDDP(model_without_ddp)

    # build optimizer and lr_scheduler
    optimizer = torch.optim.AdamW(params=model_without_ddp.parameters(), lr=args.lr, weight_decay=1e-5)

    # try to resume the experiment from some checkpoint.pth; this will load model weights, optimizer states, and last epoch (ep_start)
    # if loaded, ep_start will be greater than 0
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)

            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            # Two checkpoints per epoch: full training state (resumable) and
            # encoder-only weights in timm style for downstream fine-tuning.
            misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_ct_pretrained.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_model_weights_only(f'{args.model}_ct_pretrained_mambamim_timm_style.pth', args, model_without_ddp.sparse_encoder.state_dict())

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep-1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one pre-training epoch.

    Args:
        ep: current epoch index.
        args: parsed experiment arguments (lr/wd schedule, clip, device, ...).
        tb_lg: tensorboard logger.
        itrt_train: iterator over the training dataloader.
        iters_train: number of iterations per epoch.
        model: the (wrapped) MambaMIM model; forward returns the recon loss.
        optimizer: optimizer over the model's parameters.

    Returns:
        dict mapping each tracked meter name to its global average
        (includes 'last_loss' and 'max_lr').
    """
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    # LAMB-style optimizers expose `global_grad_norm` and clip internally
    # ("late"); otherwise clip manually before the step ("early").
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]

    for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd according to the warmup/annealing schedule
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)

        # each dataloader item is a list of per-crop dicts; stack the crops
        # along the batch dimension into one tensor
        temp = []
        for crop_per_batch in inp:
            temp.append(crop_per_batch["image"])
        inp = torch.cat(temp, dim=0)

        inp = inp.to(args.device, non_blocking=True)
        # forward and backward
        # (fixed: removed stray `MambaSparK.forward` statement — `MambaSparK`
        # is not imported anywhere and the bare expression raised NameError)
        loss = model(inp, active_b1fff=None, vis=False)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            # (fixed: `force=True` is not a valid builtin-print keyword and
            # raised TypeError on this error path)
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping: grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == '__main__':
|
| 192 |
+
main_pt()
|
mambamim_mask75.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fd92f6cfdd2aff93f8942536f333bca7eb612b4238153c9b5accbacd9e4e1989
|
| 3 |
+
size 90976893
|
models/MambaMIM.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pprint import pformat
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from timm.models.layers import trunc_normal_
|
| 8 |
+
|
| 9 |
+
import models.encoder as encoder
|
| 10 |
+
from models.decoder import LightDecoder
|
| 11 |
+
from itertools import accumulate
|
| 12 |
+
|
| 13 |
+
class MambaMIM(nn.Module):
|
| 14 |
+
    def __init__(
            self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
            mask_ratio=0.6, densify_norm='ln', sbn=True,
    ):
        """Build the MambaMIM pre-training model.

        Args:
            sparse_encoder: hierarchical sparse encoder; provides `input_size`,
                `downsample_raito` (sic) and per-stage channel widths.
            dense_decoder: light decoder reconstructing from densified features.
            mask_ratio: fraction of final-feature-map positions to mask out.
            densify_norm: requested densify-norm kind (lowercased and stored;
                NOTE(review): only nn.Identity is actually instantiated below).
            sbn: whether sync batch-norm is in use (stored for callers).
        """
        super().__init__()
        input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        # the final feature map is cubic: same extent on all three axes
        self.fmap_h, self.fmap_w, self.fmap_d = input_size // downsample_raito, input_size // downsample_raito, input_size // downsample_raito
        self.mask_ratio = mask_ratio
        # number of positions kept (unmasked) per sample on the final map
        self.len_keep = round(self.fmap_h * self.fmap_w * self.fmap_d * (1 - mask_ratio))

        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder

        self.sbn = sbn
        self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
        self.densify_norm_str = densify_norm.lower()
        self.densify_norms = nn.ModuleList()
        self.densify_projs = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # build the `densify` layers
        e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
        e_widths: List[int]
        # learnable state-transition matrix A used for state-space token
        # interpolation; square in the deepest stage's channel width
        self.A_interpolation = nn.Parameter(torch.zeros(1, self.sparse_encoder.enc_feat_map_chs[-1], self.sparse_encoder.enc_feat_map_chs[-1]))
        print("self.A_interpolation: ", self.A_interpolation.shape)
        for i in range(
                self.hierarchy):  # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
            e_width = e_widths.pop()
            # create mask token (one learnable token per hierarchy level)
            p = nn.Parameter(torch.zeros(1, e_width, 1, 1, 1))
            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(p)

            # create densify norm
            densify_norm = nn.Identity()
            self.densify_norms.append(densify_norm)

            # create densify proj: maps encoder width to the decoder width at
            # this level (identity when they already match at the deepest level)
            if i == 0 and e_width == d_width:
                densify_proj = nn.Identity()  # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768
                print(f'[MambaMIM.__init__, densify {i + 1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
            else:
                kernel_size = 1 if i <= 0 else 3
                densify_proj = nn.Conv3d(e_width, d_width, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                                         bias=True)
                print(
                    f'[MambaMIM.__init__, densify {i + 1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
            self.densify_projs.append(densify_proj)

            # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule
            d_width //= 2

        print(f'[MambaMIM.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
|
| 68 |
+
|
| 69 |
+
    def mask(self, B: int, device, generator=None):
        """Sample a random keep-mask over the final feature map.

        Args:
            B: batch size.
            device: device for the returned mask.
            generator: optional torch RNG for reproducible masking.

        Returns:
            (B, 1, h, w, d) bool tensor, True at kept (unmasked) positions;
            exactly ``self.len_keep`` True entries per sample.
        """
        h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
        # argsort of uniform noise gives a random permutation per sample;
        # the first len_keep indices are the positions to keep
        idx = torch.rand(B, h * w * d, generator=generator).argsort(dim=1)
        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)
        return torch.zeros(B, h * w * d, dtype=torch.bool, device=device)\
            .scatter_(dim=1, index=idx, value=True).view(B, 1, h, w, d)
|
| 75 |
+
|
| 76 |
+
    def mask_token_every_batch(self, bcfff, cur_active):
        """State-space token interpolation for one sample.

        Fills masked positions between consecutive visible tokens using powers
        of the learnable transition matrix ``self.A_interpolation`` applied to
        linearly interpolated features (the S6T scheme of the paper).

        Args:
            bcfff: (1, C, H, L, W) feature map of one sample — assumes the
                leading batch dim is 1 (indexing below uses flag[0]) — caller
                chunks per sample in manba_mask.
            cur_active: (1, 1, H, L, W) bool mask, True at visible positions.

        Returns:
            (1, C, H, L, W) tensor of interpolated mask tokens.
        """
        #A_token#
        # Force the first and last sequence positions to count as visible so
        # every masked run is bracketed by two anchors.
        flag = cur_active.flatten(2).clone()
        flag[0][0][0] = True
        flag[0][0][-1] = True

        indices = torch.nonzero(flag.squeeze()).squeeze()
        #A_token#
        B,N,H,L,W = bcfff.shape

        # Powers of A, one per step inside each visible-to-visible segment,
        # ordered so the highest power sits at the segment start.
        A_token =[]

        for i in range(0,len(indices)-1):
            # NOTE: the comprehension variable `i` shadows the loop `i` but is
            # scoped to the comprehension, so the outer index is unaffected.
            A_power = [torch.linalg.matrix_power(self.A_interpolation, i) for i in range(indices[i+1]-indices[i])]
            max_power = indices[i+1]-indices[i]-1
            for j in range(0,indices[i+1]-indices[i]):
                A_token.append(A_power[max_power-j])
        A_token.append(self.A_interpolation)
        A_token = torch.cat(A_token, dim=0)


        # Linear interpolation of the visible features across each segment.
        X_token = []
        X_unmask = bcfff.flatten(2).transpose(1, 2).squeeze().unsqueeze(-1)
        for i in range(0,len(indices)-1):
            alpha = torch.linspace(0, 1, indices[i + 1] - indices[i], dtype=X_unmask.dtype, device=X_unmask.device)
            alpha = alpha.view(-1, 1)  # alpha
            X_interpolation = (1 - alpha) * X_unmask[indices[i]].transpose(0, 1) + alpha * X_unmask[indices[i + 1]].transpose(0, 1)
            X_token.append(X_interpolation.unsqueeze(-1))
        X_last_token = X_unmask[indices[-1]].unsqueeze(0)
        X_token.append(X_last_token)
        X_token = torch.cat(X_token,dim = 0)

        # NOTE(review): `.cuda()` is hard-coded here; this method presumably
        # requires CUDA at pre-training time — confirm if CPU support is needed.
        AX = A_token.cuda() @ X_token

        # NOTE(review): `mask_token` aliases `AX`, so the in-place segment
        # writes below also mutate AX; the final reshape of AX therefore does
        # reflect the accumulated prefix sums — confirm this aliasing is
        # intentional rather than a copy being intended.
        mask_token = AX
        for i in range(0,len(indices)-1):
            current_sum = list(accumulate(AX[indices[i]:indices[i+1]]))
            mask_token[indices[i]:indices[i+1]] = torch.stack(current_sum,dim = 0)
        mask_token = AX.reshape(B,N,H,L,W)

        return mask_token
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def manba_mask(self, bcfff, cur_active):
    '''
    S6T: build SSM-interpolated mask tokens, processed one sample at a time.
    '''
    batch = cur_active.shape[0]
    # split both tensors into per-sample chunks of size 1 along the batch dim
    active_chunks = torch.chunk(cur_active, batch, dim=0)
    feat_chunks = torch.chunk(bcfff, batch, dim=0)
    per_sample = [
        self.mask_token_every_batch(feat, act)
        for feat, act in zip(feat_chunks, active_chunks)
    ]
    return torch.cat(per_sample, dim=0)
|
| 132 |
+
|
| 133 |
+
def forward(self, inp_bchwd: torch.Tensor, active_b1fff=None, vis=False):
    """Masked-image-modeling forward pass.

    Args:
        inp_bchwd: input volume, (B, C, H, W, D).
        active_b1fff: optional boolean keep-mask at feature-map resolution,
            (B, 1, f, f, f); sampled randomly via `self.mask` when None.
        vis: when True, return (input, masked input, reconstruction-or-input)
            for visualization instead of the loss.

    Returns:
        Scalar reconstruction loss over masked patches only, or the
        visualization triple when `vis` is True.
    """
    # step1. Mask
    if active_b1fff is None:  # rand mask
        active_b1fff: torch.BoolTensor = self.mask(inp_bchwd.shape[0], inp_bchwd.device)  # (B, 1, f, f, f)
    # share the mask with the sparse-conv hooks in models.encoder (module global)
    encoder._cur_active = active_b1fff  # (B, 1, f, f)
    # dilate the mask to full input resolution, then zero out masked voxels
    active_b1hwd = active_b1fff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito,
                                                                                             3).repeat_interleave(
        self.downsample_raito, 4)  # (B, 1, H, W, D)
    masked_bchwd = inp_bchwd * active_b1hwd

    # step2. Encode: get hierarchical encoded sparse features (a list containing feature maps at multiple scales)
    fea_bcfffs: List[torch.Tensor] = self.sparse_encoder(masked_bchwd, active_b1fff)
    fea_bcfffs.reverse()  # after reversion: from the smallest feature map to the largest

    # step3. Densify: get hierarchical dense features for decoding
    cur_active = active_b1fff  # (B, 1, f, f, f)
    to_dec = []
    for i, bcfff in enumerate(fea_bcfffs):  # from the smallest feature map to the largest
        if bcfff is not None:
            bcfff = self.densify_norms[i](bcfff)

            # deepest level (i == 0) uses SSM-interpolated tokens; the other
            # levels use the learned per-level [mask] token
            mask_tokens = self.manba_mask(bcfff, cur_active) if i == 0 else self.mask_tokens[i].expand_as(bcfff)

            # mask_tokens = self.mask_tokens[i].expand_as(bcfff)
            bcfff = torch.where(cur_active.expand_as(bcfff), bcfff,
                                mask_tokens)  # fill in empty (non-active) positions with [mask] tokens
            bcfff: torch.Tensor = self.densify_projs[i](bcfff)
            to_dec.append(bcfff)
        # dilate the mask map for the next (2x larger) level
        cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3).repeat_interleave(2,
                                                                                                         dim=4)
    # step4. Decode and reconstruct
    rec_bchwd = self.dense_decoder(to_dec)
    inp, rec = self.patchify(inp_bchwd), self.patchify(
        rec_bchwd)  # inp and rec: (B, L = f*f*f, N = C*downsample_raito**3)
    # per-patch normalization of the reconstruction target (MAE-style)
    mean = inp.mean(dim=-1, keepdim=True)
    var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
    inp = (inp - mean) / var
    l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, C) ==mean==> (B, L)

    non_active = active_b1fff.logical_not().int().view(active_b1fff.shape[0], -1)  # (B, 1, f, f, f) => (B, L)
    recon_loss = l2_loss.mul_(non_active).sum() / (
            non_active.sum() + 1e-8)  # loss only on masked (non-active) patches

    if vis:
        masked_bchwd = inp_bchwd * active_b1hwd
        # un-normalize before un-patchifying so the output is in input space
        rec_bchwd = self.unpatchify(rec * var + mean)
        rec_or_inp = torch.where(active_b1hwd, inp_bchwd, rec_bchwd)
        return inp_bchwd, masked_bchwd, rec_or_inp
    else:
        return recon_loss
|
| 183 |
+
|
| 184 |
+
def patchify(self, bchwd):
    """Rearrange (B, C, H, W, D) into patch tokens (B, h*w*d, C*p**3)."""
    p = self.downsample_raito
    h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
    B, C = bchwd.shape[:2]
    # split each spatial axis into (patch index, within-patch offset)
    split = bchwd.reshape(B, C, h, p, w, p, d, p)
    # bring patch indices forward and channels last
    ordered = torch.einsum('bchpwqds->bhwdpqsc', split)
    return ordered.reshape(B, h * w * d, C * p ** 3)
|
| 192 |
+
|
| 193 |
+
def unpatchify(self, bln):
    """Inverse of `patchify`: (B, h*w*d, C*p**3) back to (B, C, H, W, D)."""
    p = self.downsample_raito
    h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
    B = bln.shape[0]
    C = bln.shape[-1] // p ** 3
    # recover (patch index, within-patch offset, channel) axes
    grid = bln.reshape(B, h, w, d, p, p, p, C)
    # interleave patch indices with offsets, channels first
    interleaved = torch.einsum('bhwdpqsc->bchpwqds', grid)
    return interleaved.reshape(B, C, h * p, w * p, d * p)
|
| 201 |
+
|
| 202 |
+
def __repr__(self):
    """Show the config dict plus the full module tree (class name stripped to avoid duplication)."""
    return (
        f'\n'
        f'[MambaMIM.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
        f'[MambaMIM.structure]: {super(MambaMIM, self).__repr__().replace(MambaMIM.__name__, "")}'
    )
|
| 208 |
+
|
| 209 |
+
def get_config(self):
    """Return the hyper-parameters identifying this pretraining setup (embedded in checkpoints)."""
    return {
        # self
        'mask_ratio': self.mask_ratio,
        'densify_norm_str': self.densify_norm_str,
        'sbn': self.sbn, 'hierarchy': self.hierarchy,

        # enc
        'sparse_encoder.input_size': self.sparse_encoder.input_size,
        # dec
        'dense_decoder.width': self.dense_decoder.width,
    }
|
| 221 |
+
|
| 222 |
+
def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
    """Standard nn.Module state dict; with `with_config=True`, the config dict is stored under key 'config'."""
    state = super(MambaMIM, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
    if with_config:
        state['config'] = self.get_config()
    return state
|
| 227 |
+
|
| 228 |
+
def load_state_dict(self, state_dict, strict=True):
    """Load weights, validating any embedded `config` against this model.

    A config entry saved by `state_dict(with_config=True)` is popped from
    `state_dict` before delegating to the standard loader. For every key in
    this model's config, a mismatch with the checkpoint value raises
    AttributeError when `strict`, otherwise prints a warning to stderr.

    Returns:
        The `incompatible_keys` result of the standard loader.
    """
    config: dict = state_dict.pop('config', None)
    incompatible_keys = super(MambaMIM, self).load_state_dict(state_dict, strict=strict)
    if config is not None:
        for k, v in self.get_config().items():
            ckpt_v = config.get(k, None)
            if ckpt_v != v:
                # fixed: message previously said "SparseMIM", which is not this class's name
                err = f'[MambaMIM.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={ckpt_v})'
                if strict:
                    raise AttributeError(err)
                else:
                    print(err, file=sys.stderr)
    return incompatible_keys
|
models/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from timm.loss import SoftTargetCrossEntropy
|
| 3 |
+
from timm.models.layers import drop
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from models.network.hymamba import Encoder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# log more
def _ex_repr(self):
    """Render a module's scalar/config attributes as 'k=v, ...' for richer repr logs.

    Skips private attributes, the `training` flag, and any nn.Module/Tensor
    members (those are shown by the module tree already).
    """
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )
# Monkey-patch common loss/regularization classes so printing a model shows
# their configuration: prefer overriding `extra_repr` when the class has one,
# otherwise replace `__repr__` wholesale.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if hasattr(clz, 'extra_repr'):
        clz.extra_repr = _ex_repr
    else:
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Default per-model kwargs used when building the pretraining encoder.
pretrain_default_model_kwargs = {
    'mambamim': dict(sparse=True, drop_path_rate=0.1),
}
for kw in pretrain_default_model_kwargs.values():
    # pretraining never needs pretrained weights, a classification head, or pooling
    kw['pretrained'] = False
    kw['num_classes'] = 0
    kw['global_pool'] = ''
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Build the sparse hybrid-Mamba encoder used for MIM pretraining.

    Args:
        name: key into `pretrain_default_model_kwargs` (e.g. 'mambamim').
        input_size: input volume side length, forwarded to SparseEncoder.
        sbn: use SyncBatchNorm in the sparse conversion.
        drop_path_rate: overrides the default drop-path rate when non-zero.
        verbose: forwarded to the dense-to-sparse conversion.
    """
    from models.encoder import SparseEncoder
    # fixed: copy the kwargs — previously this mutated the shared module-level
    # pretrain_default_model_kwargs dict, leaking drop_path_rate into later calls
    kwargs = dict(pretrain_default_model_kwargs[name])
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')
    # NOTE(review): kwargs are printed but not forwarded to Encoder below — confirm intended
    encoder = Encoder(
        in_channel=1,
        channels=(32, 64, 128, 192, 384),
        depths=(1, 2, 2, 2, 1),
        kernels=(3, 3, 3, 3, 3),
        exp_r=(2, 2, 4, 4, 4),
        img_size=96,
        depth=4,
        sparse=True)
    return SparseEncoder(encoder=encoder, input_size=input_size, sbn=sbn, verbose=verbose)
|
models/decoder.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from timm.models.layers import trunc_normal_
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class UNetBlock(nn.Module):
    """A UNet decoder block: 2x transposed-conv upsampling followed by two
    3x3x3 conv + norm + ReLU layers."""

    def __init__(self, cin, cout, bn3d):
        super().__init__()
        self.up_sample = nn.ConvTranspose3d(cin, cin, kernel_size=2, stride=2, padding=0, bias=True)
        stages = []
        for c_in, c_out in ((cin, cout), (cout, cout)):
            stages.extend([
                nn.Conv3d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True),
                bn3d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        # upsample spatially by 2x, then refine with the conv stack
        return self.conv(self.up_sample(x))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class FusionBlock(nn.Module):
    """Fuse concatenated skip features with two 3x3x3 conv + norm + ReLU
    layers; spatial size is unchanged."""

    def __init__(self, cin, cout, bn3d):
        super().__init__()
        stages = []
        for c_in, c_out in ((cin, cout), (cout, cout)):
            stages.extend([
                nn.Conv3d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True),
                bn3d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LightDecoder(nn.Module):
    """UNet-style light decoder: a chain of 2x-upsampling blocks with skip fusion.

    Channel widths follow a simple halving rule starting from `width`; change
    `channels` in `__init__` for any other rule. `sbn` is accepted for
    interface compatibility but unused here (BatchNorm3d is always used).
    """

    def __init__(self, up_sample_ratio, width=768, sbn=True):
        super().__init__()
        self.width = width
        n = round(math.log2(up_sample_ratio))  # number of 2x upsampling stages
        channels = [self.width // 2 ** i for i in range(n + 1)]  # halving rule
        bn3d = nn.BatchNorm3d
        self.dec = nn.ModuleList([UNetBlock(cin, cout, bn3d) for cin, cout in zip(channels[:-1], channels[1:])])
        # fusion consumes [decoder state, skip] concatenated on channels, hence cin * 2
        self.fuse = nn.ModuleList([FusionBlock(cin * 2, cin, bn3d) for cin in channels[:-1]])
        self.proj = nn.Conv3d(channels[-1], 1, kernel_size=1, stride=1, bias=True)

        self.initialize()

    def forward(self, to_dec: List[torch.Tensor]):
        """Decode a coarse-to-fine list of feature maps into a 1-channel volume.

        to_dec[0] (the coarsest map) seeds the state; each later non-None map
        is concatenated with the current state and fused before upsampling.
        Missing entries simply skip the fusion step.
        """
        x = 0
        for i, d in enumerate(self.dec):
            if i < len(to_dec) and to_dec[i] is not None:
                if isinstance(x, int):
                    x = x + to_dec[i]  # first available map: just take it
                else:
                    x = torch.cat((x, to_dec[i]), dim=1)
                    x = self.fuse[i](x)
            x = self.dec[i](x)
        return self.proj(x)

    def extra_repr(self) -> str:
        return f'width={self.width}'

    def initialize(self):
        """Init weights: trunc-normal for linears/convs, kaiming for transposed convs, unit norms."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv3d):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose3d):
                # fixed: was `isinstance(m, (nn.Conv3d, nn.ConvTranspose3d))`,
                # whose Conv3d part was unreachable (already matched above)
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
|
models/encoder.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from timm.models.layers import DropPath
|
| 4 |
+
|
| 5 |
+
_cur_active: torch.Tensor = None # B1fff
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# todo: try to use `gather` for speed?
|
| 9 |
+
def _get_active_ex_or_ii(H, W, D, returning_active_ex=True):
    """Upsample the module-global `_cur_active` mask to (H, W, D) resolution.

    Returns the dense mask itself when `returning_active_ex` is True,
    otherwise the (bi, hi, wi, di) index tuple of active positions.
    """
    rh = H // _cur_active.shape[-3]
    rw = W // _cur_active.shape[-2]
    rd = D // _cur_active.shape[-1]
    active_ex = (_cur_active
                 .repeat_interleave(rh, dim=2)
                 .repeat_interleave(rw, dim=3)
                 .repeat_interleave(rd, dim=4))
    if returning_active_ex:
        return active_ex
    return active_ex.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi, di
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def sp_conv_forward(self, x: torch.Tensor):
    """Dense conv/pool forward followed by re-masking of the output.

    Runs the parent module's forward, then zeroes every position masked in
    the module-global `_cur_active`, keeping the computation "sparse" in
    effect. Installed on the Sparse* classes below via `forward = ...`.
    """
    x = super(type(self), self).forward(x)
    # (B,C,H,W,D) *= (B,1,H,W,D): mask the output of conv at inactive positions
    x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True)
    return x
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def sp_bn_forward(self, x: torch.Tensor):
    """BatchNorm restricted to active (unmasked) positions.

    Gathers features at active positions into a flat (n_active, C) tensor,
    normalizes it with the parent BatchNorm1d, and scatters the result back,
    leaving masked positions at zero.
    """
    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)

    bhwdc = x.permute(0, 2, 3, 4, 1)
    nc = bhwdc[ii]  # select the features on non-masked positions to form a flatten feature `nc`
    nc = super(type(self), self).forward(nc)  # use BN1d to normalize this flatten feature `nc`

    bchwd = torch.zeros_like(bhwdc)
    bchwd[ii] = nc  # scatter normalized features back; masked positions stay zero
    bchwd = bchwd.permute(0, 4, 1, 2, 3)
    return bchwd
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def sp_in_forward(self, x: torch.Tensor):
    """InstanceNorm restricted to active (unmasked) positions.

    Gathers active-position features, reshapes them to (B, C, L) for the
    parent InstanceNorm1d, then scatters the normalized values back.
    """
    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)
    bhwdc = x.permute(0, 2, 3, 4, 1)
    cn = bhwdc[ii].permute(1,
                           0)  # select the features on non-masked positions to form a flatten feature `nc` [17787, 3]
    C, N = cn.shape
    # NOTE(review): this reshape assumes every sample has the same number of
    # active positions (true for a fixed-ratio random mask) — confirm for custom masks
    bcl = cn.reshape(C, -1, x.shape[0]).permute(2, 0, 1)
    bcl = super(type(self), self).forward(bcl)  # normalize the flattened feature with InstanceNorm1d
    nc = bcl.permute(1, 2, 0).reshape(C, -1).permute(1, 0)
    bchwd = torch.zeros_like(bhwdc)
    bchwd[ii] = nc  # scatter back; masked positions stay zero
    bchwd = bchwd.permute(0, 4, 1, 2, 3)
    return bchwd
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SparseConv3d(nn.Conv3d):
    """Conv3d whose output is re-masked by the global `_cur_active` mask."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SparseMaxPooling(nn.MaxPool3d):
    """MaxPool3d whose output is re-masked by the global `_cur_active` mask."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SparseAvgPooling(nn.AvgPool3d):
    """AvgPool3d whose output is re-masked by the global `_cur_active` mask."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SparseBatchNorm3d(nn.BatchNorm1d):
    """BatchNorm applied only over active positions (BN1d on gathered features)."""
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SparseSyncBatchNorm3d(nn.SyncBatchNorm):
    """SyncBatchNorm applied only over active positions (see `sp_bn_forward`)."""
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SparseInstanceNorm3d(nn.InstanceNorm1d):
    """InstanceNorm applied only over active positions (see `sp_in_forward`)."""
    forward = sp_in_forward  # hack: override the forward function; see `sp_in_forward` above for more details
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).

    When `sparse`, normalization is applied only at active positions taken
    from the module-global `_cur_active`; masked positions are reset to zero.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim == 5:  # BHWDC or BCHWD
            if self.data_format == "channels_last":  # BHWDC
                if self.sparse:
                    # normalize only the gathered active features, zero elsewhere
                    ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], D=x.shape[3], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first, BCHWD
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 4, 1)  # to channels-last for gathering
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 4, 1, 2, 3)
                else:
                    # manual channels-first LayerNorm: normalize over the channel dim
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)

    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[
               :-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Depthwise conv kernel size. Default: 7
        exp_r (int): Expansion ratio of the pointwise MLP. Default: 4
        do_res (bool): Add a residual connection. Default: False
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        sparse (bool): Mask the output with the global `_cur_active`. Default: True
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, exp_r=4, do_res=False, drop_path=0.,
                 layer_scale_init_value=1e-6, sparse=True):
        super().__init__()

        self.do_res = do_res
        self.dwconv = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2,
                                groups=in_channels)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(in_channels, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(in_channels,
                                 exp_r * in_channels)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(exp_r * in_channels, out_channels)
        # learned per-channel layer-scale factor (disabled when init value <= 0)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channels)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)  # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # (N, H, W, D, C) -> (N, C, H, W, D)

        if self.sparse:
            # re-mask after the dense pointwise ops so inactive positions stay zero
            x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True)
        if self.do_res:
            x = input + self.drop_path(x)
        return x

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class SparseEncoder(nn.Module):
    """Wrap a dense encoder for masked (sparse) pretraining.

    The encoder's `embeddings` stem is converted so its convs, pools and
    norms zero out masked positions; the `mae` submodule is kept dense and
    receives the mask explicitly in `forward`.
    """

    def __init__(self, encoder, input_size, sbn=False, verbose=False):
        super(SparseEncoder, self).__init__()
        # convert the dense embedding stem into its sparse counterpart
        self.embeddings = SparseEncoder.dense_model_to_sparse(m=encoder.embeddings, verbose=verbose, sbn=sbn)
        self.mae = encoder.mae

        # self.encoder = SparseEncoder.dense_model_to_sparse(m=encoder, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, encoder.get_downsample_ratio(), encoder.get_feature_map_channels()

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        """Recursively replace dense conv/pool/norm modules with their sparse
        equivalents, copying weights and running stats where present."""
        oup = m
        if isinstance(m, nn.Conv3d):
            m: nn.Conv3d
            bias = m.bias is not None
            oup = SparseConv3d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool3d):
            m: nn.MaxPool3d
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation,
                                   return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool3d):
            m: nn.AvgPool3d
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode,
                                   count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm)):
            m: nn.BatchNorm3d
            oup = (SparseSyncBatchNorm3d if sbn else SparseBatchNorm3d)(m.weight.shape[0], eps=m.eps,
                                                                       momentum=m.momentum, affine=m.affine,
                                                                       track_running_stats=m.track_running_stats)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.InstanceNorm3d):
            m: nn.InstanceNorm3d
            oup = SparseInstanceNorm3d(m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                       track_running_stats=m.track_running_stats)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, (nn.Conv1d,)):
            # NOTE(review): this branch rebuilds an identical *dense* Conv1d
            # (no sparse variant exists for 1-D convs) — confirm intentional
            m: nn.Conv1d
            bias = m.bias is not None
            oup = nn.Conv1d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode)
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        # recurse into children so nested submodules are converted too
        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        del m
        return oup

    def forward(self, x, active_b1fff):
        # sparse stem produces the multi-scale pyramid; the deepest map is
        # refined by the (dense) mae submodule, which consumes the mask itself
        x1, x2, x3, x4, x5 = self.embeddings(x)
        _x5 = self.mae(x5, active_b1fff)
        return [x1, x2, x3, x4, _x5]
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == '__main__':
    # Smoke test: upsample a (96//16)^3 mask to a 24^3 feature map.
    x = torch.randn([1, 96, 24, 24, 24])
    # This assignment is at module scope, so it rebinds the module-level
    # global read by _get_active_ex_or_ii.
    # NOTE(review): randn gives a float "mask" here, not bool — fine for the
    # shape check below, but not a realistic mask.
    _cur_active = torch.randn([1, 1, 96 // 16, 96 // 16, 96 // 16])
    print(x.shape)
    print(_get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True).shape)
    print(x.shape)
|
models/mamba/bi_vision_mamba.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
|
| 13 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 17 |
+
except ImportError:
|
| 18 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 22 |
+
except ImportError:
|
| 23 |
+
selective_state_update = None
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 27 |
+
except ImportError:
|
| 28 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Mamba(nn.Module):
    """Bidirectional Mamba (selective state-space) token mixer.

    Projects the input to an expanded inner dimension, applies a depthwise
    causal convolution plus a selective scan (S6), gates with SiLU, and
    projects back to ``d_model``. With ``bimamba_type == "v2"`` a second,
    independent parameter set (the ``*_b`` members) scans the flipped
    sequence and the two directions are summed before the output projection.
    """

    def __init__(
        self,
        d_model,              # input/output feature dimension D
        d_state=16,           # SSM state size N per inner channel
        d_conv=4,             # depthwise causal conv kernel width
        expand=2,             # inner dim = expand * d_model
        dt_rank="auto",       # rank of the dt projection; "auto" -> ceil(d_model / 16)
        dt_min=0.001,         # lower bound for softplus(dt_bias) at init
        dt_max=0.1,           # upper bound for softplus(dt_bias) at init
        dt_init="random",     # "constant" or "random" init for dt_proj.weight
        dt_scale=1.0,         # scale on the dt_proj.weight init std
        dt_init_floor=1e-4,   # clamp floor for sampled dt at init
        conv_bias=True,
        bias=False,           # bias for in_proj / out_proj linears
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,       # required for inference-time state caching (see _get_states_from_cache)
        device=None,
        dtype=None,
        bimamba_type="none"   # "none": unidirectional; "v2": adds a backward scan with separate params
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.bimamba_type = bimamba_type

        # Single projection producing both the x and z (gate) streams: (D -> 2 * d_inner).
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        # Depthwise causal conv; padding=d_conv-1 left-pads so output is then
        # truncated back to seqlen in forward().
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        # Produces the input-dependent dt (rank dt_rank), B and C (each d_state) per token.
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization: A[d, n] = n for n in 1..d_state, stored as log.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        # assert bimamba_type == "v2"

        # Backward-direction parameters (always created, used only when bimamba_type == "v2").
        A_b = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_b_log = torch.log(A_b)  # Keep A_b_log in fp32
        self.A_b_log = nn.Parameter(A_b_log)
        self.A_b_log._no_weight_decay = True

        self.conv1d_b = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.x_proj_b = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D_b._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # Decoding path: single-token step reusing cached conv/ssm states.
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        # xz: (B, 2 * d_inner, L)
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            if self.bimamba_type == "v2":
                A_b = -torch.exp(self.A_b_log.float())
                # NOTE(review): `mamba_inner_fn_no_out_proj` is not imported in this
                # file (only selective_scan_fn / mamba_inner_fn are); this branch
                # raises NameError unless it is provided by a patched mamba_ssm
                # (as in the Vim/bi-Mamba forks) — confirm the installed package.
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                # Backward direction: flip along L, scan with the *_b parameters.
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # Sum forward and (re-flipped) backward scans, then project out.
                # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
                out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight,
                               self.out_proj.bias)
            else:
                # Unidirectional fused path (conv + scan + out_proj in one kernel).
                out = mamba_inner_fn(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
        else:
            # Reference (non-fused) path; also used when caching states for inference.
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                # nn.Conv1d left-pads by d_conv-1, so truncate back to seqlen for causality.
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Single-token decoding step; updates conv_state and ssm_state in place.

        hidden_states: (B, 1, D). Returns (out (B, 1, D), conv_state, ssm_state).
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            # Manual sliding-window update: shift left, append new token, dot with kernel.
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            # Recurrence: h <- h * dA + x * dB, then y = <h, C> + D * x, gated by SiLU(z).
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed conv/ssm states for autoregressive decoding."""
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch (or lazily create) this layer's cached states, keyed by layer_idx."""
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
| 359 |
+
|
| 360 |
+
class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if self.fused_add_norm:
            # The fused path relies on the optional triton kernels imported at
            # module top; fail fast if they are unavailable.
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            # Unfused path: explicit add then norm (residual kept separately).
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # Fused add + norm in a single triton kernel; returns both the
            # normalized activations and the (pre-norm) residual.
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Delegate state allocation to the wrapped mixer (e.g. Mamba).
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
models/network/hymamba.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from models.encoder import SparseConvNeXtLayerNorm, _get_active_ex_or_ii
|
| 4 |
+
from typing import Optional, Sequence, Tuple, Union, List
|
| 5 |
+
import numpy as np
|
| 6 |
+
from models.mamba.bi_vision_mamba import Mamba
|
| 7 |
+
from monai.networks.blocks.unetr_block import UnetrUpBlock
|
| 8 |
+
|
| 9 |
+
def build_3d_sincos_position_embedding(grid_size, embed_dim, num_tokens=0, temperature=10000.):
    """Build a fixed (non-trainable) 3-D sin-cos position embedding.

    Args:
        grid_size: side length of the cubic token grid; the embedding covers
            ``grid_size ** 3`` positions.
        embed_dim: embedding dimension; must be divisible by 6 (sin+cos for
            each of the three axes).
        num_tokens: 0, or 1 to prepend a single zero embedding (e.g. a cls token).
        temperature: base of the frequency schedule.

    Returns:
        ``nn.Parameter`` of shape ``(1, grid_size**3 (+ num_tokens), embed_dim)``
        with ``requires_grad=False``.
    """
    grid_size = (grid_size, grid_size, grid_size)
    h, w, d = grid_size
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_d = torch.arange(d, dtype=torch.float32)

    # 'ij' matches the historical torch.meshgrid default and silences the
    # deprecation warning emitted when no `indexing` argument is passed.
    grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing='ij')
    assert embed_dim % 6 == 0, 'Embed dimension must be divisible by 6 for 3D sin-cos position embedding'
    pos_dim = embed_dim // 6
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)
    # Pass einsum operands positionally: the single-list form is deprecated.
    out_h = torch.einsum('m,d->md', grid_h.flatten(), omega)
    out_w = torch.einsum('m,d->md', grid_w.flatten(), omega)
    out_d = torch.einsum('m,d->md', grid_d.flatten(), omega)
    pos_emb = torch.cat(
        [torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w), torch.sin(out_d), torch.cos(out_d)],
        dim=1)[None, :, :]

    assert num_tokens == 1 or num_tokens == 0, "Number of tokens must be of 0 or 1"
    if num_tokens == 1:
        # Zero embedding for the extra (cls) token, prepended to the grid.
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
    else:
        pos_embed = nn.Parameter(pos_emb)
    pos_embed.requires_grad = False
    return pos_embed
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MlpChannel(nn.Module):
    """Token-wise two-layer MLP: Linear(hidden->mlp) -> GELU -> Linear(mlp->hidden)."""

    def __init__(self, hidden_size, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, mlp_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x):
        """Expand to ``mlp_dim``, apply GELU, and project back to ``hidden_size``."""
        return self.fc2(self.act(self.fc1(x)))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MambaLayer(nn.Module):
    """Pre-norm residual block: bidirectional Mamba mixer followed by a channel MLP."""

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        # Token mixer: bidirectional ("v1") selective state-space model.
        self.mamba = Mamba(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v1",
        )
        self.mlp = MlpChannel(hidden_size=dim, mlp_dim=2 * dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply mixer then MLP, each pre-normed and with a residual connection."""
        x = x + self.mamba(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MaskedAutoencoderMamba(nn.Module):
    """Masked-autoencoder encoder stage built on Mamba blocks.

    Tokenizes a (B, C, H, W, D) feature volume, adds a fixed 3-D sin-cos
    position embedding, optionally drops masked tokens before the Mamba
    blocks (``sparse=True``) and re-inserts a learned mask token afterwards.
    """

    def __init__(self, img_size=96, downsample_rato=16, embed_dim=384, depth=8, norm_layer=nn.LayerNorm, sparse=True):
        super().__init__()
        print("mamba sparse: ", sparse)
        # --------------------------------------------------------------------------
        # MAE encoder specifics
        # Token grid side length after the encoder's downsampling.
        self.grid_size = img_size // downsample_rato
        self.num_patches = (self.grid_size) ** 3
        self.embed_dim = embed_dim
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim),
                                      requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            MambaLayer(dim=embed_dim)
            for i in range(depth)])
        # self.gsc = GSC(in_channels=embed_dim, sparse=sparse)

        self.sparse = sparse
        if self.sparse:
            # Learned embedding substituted for every masked-out token.
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # --------------------------------------------------------------------------
        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = build_3d_sincos_position_embedding(self.grid_size, self.embed_dim)
        self.pos_embed.data.copy_(pos_embed)
        if self.sparse:
            torch.nn.init.normal_(self.mask_token, std=.02)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, enc, active_b1fff):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = enc.shape  # batch, length, dim
        # active_b1fff is a binary keep-mask over the (B, 1, f, f, f) feature
        # grid; flatten to (N, L, 1) so it aligns with the token sequence.
        # NOTE(review): torch.tensor(...) copies its argument; if active_b1fff
        # is already a tensor this re-wraps it (with a warning on newer
        # PyTorch) — confirm the caller's type.
        mask = torch.tensor(active_b1fff, dtype=torch.int).flatten(2).transpose(1, 2)
        # sort noise for each sample
        noise = 1 - mask  # 0 where kept (active), 1 where masked
        # NOTE(review): torch.sum(mask) sums over the WHOLE batch, not per
        # sample, so ids_keep takes N * per-sample-keep indices when N > 1 —
        # confirm this matches the intended batch handling upstream.
        len_keep = torch.sum(mask)
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(enc, dim=1, index=ids_keep.repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        return x_masked, mask, ids_restore

    def unmasking(self, x, ids_restore):
        # Append mask tokens for every dropped position, then unshuffle back to
        # the original token order using ids_restore.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
        x = torch.gather(x_, dim=1, index=ids_restore.repeat(1, 1, x.shape[2]))  # unshuffle
        return x

    def forward_encoder(self, enc, active_b1fff=None):
        # enc = self.gsc(enc)
        B, C, H, W, D = enc.shape
        # (B, C, H, W, D) -> (B, H*W*D, C) token sequence.
        x = enc.flatten(2).transpose(1, 2)
        # add pos embed w/o cls token
        x = x + self.pos_embed
        if self.sparse:
            # masking: length -> length * mask_ratio
            x, mask, ids_restore = self.random_masking(x, active_b1fff)
            # apply Mamba blocks
            for blk in self.blocks:
                x = blk(x)
            x = self.unmasking(x, ids_restore)
        else:
            for blk in self.blocks:
                x = blk(x)
        # Back to the volumetric layout expected by the decoder.
        x = x.transpose(1, 2).reshape(B, C, H, W, D)
        return x

    def forward(self, imgs, active_b1fff=None):
        return self.forward_encoder(imgs, active_b1fff)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class MedNeXtBlock(nn.Module):
    """ConvNeXt-style 3-D block: depthwise conv -> norm -> 1x1x1 expand -> GELU -> 1x1x1 compress.

    When ``sparse`` is set, the output is multiplied by the globally-registered
    active-voxel mask so masked positions stay zero; when ``do_res`` is set,
    the input is added back as a residual.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 exp_r: int = 4,
                 kernel_size: int = 7,
                 do_res: int = True,
                 n_groups: int or None = None,
                 sparse=False):
        super().__init__()
        self.do_res = do_res
        self.sparse = sparse
        conv = nn.Conv3d

        # Depthwise spatial convolution (per-channel groups unless n_groups given).
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=in_channels if n_groups is None else n_groups,
        )

        # Mask-aware LayerNorm in channels-first layout.
        self.norm = SparseConvNeXtLayerNorm(normalized_shape=in_channels, data_format='channels_first', sparse=sparse)

        # Pointwise expansion to exp_r * in_channels.
        self.conv2 = conv(
            in_channels=in_channels,
            out_channels=exp_r * in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

        self.act = nn.GELU()

        # Pointwise compression back to out_channels.
        self.conv3 = conv(
            in_channels=exp_r * in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x, dummy_tensor=None):
        """Run the depthwise/expand/compress pipeline; optionally mask and add residual."""
        out = self.conv1(x)
        out = self.act(self.conv2(self.norm(out)))
        out = self.conv3(out)
        if self.sparse:
            # Zero out inactive voxels using the module-level active mask.
            out = out * _get_active_ex_or_ii(H=out.shape[2], W=out.shape[3], D=out.shape[4], returning_active_ex=True)
        if self.do_res:
            out = x + out
        return out
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class MedNeXtDownBlock(MedNeXtBlock):
    """MedNeXt block whose depthwise conv downsamples by stride 2.

    The parent's in-block residual is disabled; when ``do_res`` is set, a
    stride-2 1x1x1 convolution of the input is added instead so the shapes match.
    """

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7,
                 do_res=False, sparse=False):
        # Always build the parent without its residual; the resampled residual
        # below replaces it.
        super().__init__(in_channels, out_channels, exp_r, kernel_size,
                         do_res=False, sparse=sparse)

        self.resample_do_res = do_res
        if do_res:
            self.res_conv = nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=2
            )

        # Override the parent's depthwise conv with a stride-2 variant.
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        """Downsampling forward: parent pipeline plus optional strided residual."""
        out = super().forward(x)
        if self.resample_do_res:
            out = out + self.res_conv(x)
        return out
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class UnetResBlock(nn.Module):
    """
    A skip-connection based module that can be used for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        sparse: passed through to the mask-aware SparseConvNeXtLayerNorm layers.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.

    """

    def __init__(
        self,
        sparse: bool,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
    ):
        super().__init__()
        # NOTE(review): padding=kernel_size // 2 assumes an int kernel_size;
        # passing a Sequence (as the annotation permits) would fail here — confirm
        # callers only pass ints.
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )
        self.lrelu = nn.LeakyReLU(inplace=True, negative_slope=0.01)
        self.norm1 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
        self.norm2 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
        # A projection on the skip path is needed when channels change or any
        # stride component is != 1 (spatial size changes).
        self.downsample = in_channels != out_channels
        stride_np = np.atleast_1d(stride)
        if not np.all(stride_np == 1):
            self.downsample = True
        if self.downsample:
            self.conv3 = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride)
            self.norm3 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)

    def forward(self, inp):
        # Main path: conv -> norm -> lrelu -> conv -> norm.
        residual = inp
        out = self.conv1(inp)
        out = self.norm1(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # Skip path: projected/normalized only when the downsample branch exists
        # (conv3/norm3 are created together in __init__, hence the hasattr guards).
        if hasattr(self, "conv3"):
            residual = self.conv3(residual)
        if hasattr(self, "norm3"):
            residual = self.norm3(residual)
        out += residual
        out = self.lrelu(out)
        return out
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class MedNeXtUpBlock(MedNeXtBlock):
    """MedNeXt block that also doubles spatial resolution via a stride-2 transposed conv."""

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=3,
                 do_res=True, sparse=False):
        # Residual handling across the resolution change is done locally, so the
        # parent's own residual path is disabled.
        super().__init__(in_channels, out_channels, exp_r, kernel_size,
                         do_res=False, sparse=sparse)

        self.resample_do_res = do_res

        conv = nn.ConvTranspose3d
        if do_res:
            self.res_conv = conv(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=2,
            )

        # Override the parent's first conv with a strided depthwise transposed conv.
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        out = super().forward(x)
        # Asymmetric padding, but necessary to match the skip-connection shape.
        out = torch.nn.functional.pad(out, (1, 0, 1, 0, 1, 0))

        if self.resample_do_res:
            res = torch.nn.functional.pad(self.res_conv(x), (1, 0, 1, 0, 1, 0))
            out = out + res
        return out
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class UnetOutBlock(nn.Module):
    """Final 1x1x1 projection mapping decoder features to per-class logits."""

    def __init__(self, in_channels: int, n_classes: int):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            n_classes,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inp):
        return self.conv(inp)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class Embeddings(nn.Module):
    """Five-stage 3D convolutional pyramid.

    Produces five feature maps at progressively halved resolution; the first
    four are refined by UnetResBlock skip-encoders before being returned.
    """

    def __init__(self,
                 in_channel: int = 3,
                 channels: Tuple = (32, 64, 96, 128, 192),
                 depths: Tuple = (1, 1, 3, 1, 1),
                 kernels: Tuple = (3, 3, 3, 3, 3),
                 exp_r: Tuple = (2, 4, 4, 4, 2),
                 sparse=True):
        super(Embeddings, self).__init__()
        # Channel width of each returned map (stage 5 shares stage 4's width).
        self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]
        self.stem = nn.Conv3d(in_channels=in_channel, out_channels=channels[0], kernel_size=3, stride=1, padding=1)

        def make_stage(idx):
            # depths[idx] identical residual MedNeXt blocks at width channels[idx].
            return nn.Sequential(*[
                MedNeXtBlock(
                    in_channels=channels[idx],
                    out_channels=channels[idx],
                    exp_r=exp_r[idx],
                    kernel_size=kernels[idx],
                    do_res=True,
                    sparse=sparse,
                )
                for _ in range(depths[idx])
            ])

        self.layer2 = make_stage(1)
        self.layer3 = make_stage(2)
        self.layer4 = make_stage(3)
        self.layer5 = make_stage(4)

        self.down = nn.MaxPool3d((2, 2, 2))
        # 3x3x3 convs widening the channel count between stages.
        self.expend1 = nn.Conv3d(in_channels=channels[0], out_channels=channels[1], kernel_size=3, stride=1, padding=1)
        self.expend2 = nn.Conv3d(in_channels=channels[1], out_channels=channels[2], kernel_size=3, stride=1, padding=1)
        self.expend3 = nn.Conv3d(in_channels=channels[2], out_channels=channels[3], kernel_size=3, stride=1, padding=1)
        self.expend4 = nn.Conv3d(in_channels=channels[3], out_channels=channels[4], kernel_size=3, stride=1, padding=1)

        # Residual refiners applied to the four skip-connection outputs.
        self.encoder1 = UnetResBlock(in_channels=channels[1], out_channels=channels[1], kernel_size=3, stride=1, sparse=sparse)
        self.encoder2 = UnetResBlock(in_channels=channels[2], out_channels=channels[2], kernel_size=3, stride=1, sparse=sparse)
        self.encoder3 = UnetResBlock(in_channels=channels[3], out_channels=channels[3], kernel_size=3, stride=1, sparse=sparse)
        self.encoder4 = UnetResBlock(in_channels=channels[4], out_channels=channels[4], kernel_size=3, stride=1, sparse=sparse)

    def forward(self, x):
        x = self.stem(x)

        x1 = self.expend1(x)

        x = self.layer2(self.down(x1))
        x2 = self.expend2(x)

        x = self.layer3(self.down(x2))
        x3 = self.expend3(x)

        x = self.layer4(self.down(x3))
        x4 = self.expend4(x)

        x5 = self.layer5(self.down(x4))

        return self.encoder1(x1), self.encoder2(x2), self.encoder3(x3), self.encoder4(x4), x5
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class Encoder(nn.Module):
    """Hierarchical CNN embeddings followed by a (sparse) Mamba MAE bottleneck."""

    def __init__(self,
                 in_channel: int = 1,
                 channels=(32, 64, 128, 192, 384),
                 depths=(1, 2, 2, 2, 1),
                 kernels=(3, 3, 3, 3, 3),
                 exp_r=(2, 2, 4, 4, 4),
                 img_size=96,
                 depth=4,
                 norm_layer=nn.LayerNorm,
                 sparse=False):
        super(Encoder, self).__init__()
        self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]

        self.embeddings = Embeddings(in_channel=in_channel,
                                     channels=channels,
                                     depths=depths,
                                     kernels=kernels,
                                     exp_r=exp_r,
                                     sparse=sparse)

        self.mae = MaskedAutoencoderMamba(
            img_size=img_size,
            downsample_rato=self.get_downsample_ratio(),
            embed_dim=channels[-1],
            depth=depth,
            norm_layer=norm_layer,
            sparse=sparse)

    def get_downsample_ratio(self) -> int:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: the TOTAL downsample ratio of the ConvNet.
            E.g., for a ResNet-50, this should return 32; this network is 16x.
        """
        return 16

    def get_feature_map_channels(self) -> List[int]:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: a list of the number of channels of each feature map.
            E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
        """
        return self.dim

    def forward(self, x, active_b1fff=None):
        x1, x2, x3, x4, x5 = self.embeddings(x)
        # Only the deepest map goes through the (optionally masked) MAE stage.
        return x1, x2, x3, x4, self.mae(x5, active_b1fff)
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
class Decoder(nn.Module):
    """UNETR-style decoder: four upsampling stages fusing skip connections,
    a final residual block, and a 1x1x1 classification head."""

    def __init__(self,
                 n_classes: int = 3,
                 channels: Tuple = (32, 64, 128, 196, 384),
                 norm_name="instance",
                 res_block: bool = True):
        super(Decoder, self).__init__()

        def up_stage(in_ch, out_ch):
            # One transposed-conv upsampling block that fuses a skip connection.
            return UnetrUpBlock(
                spatial_dims=3,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                upsample_kernel_size=2,
                norm_name=norm_name,
                res_block=res_block,
            )

        self.decoder5 = up_stage(channels[4], channels[4])
        self.decoder4 = up_stage(channels[4], channels[3])
        self.decoder3 = up_stage(channels[3], channels[2])
        self.decoder2 = up_stage(channels[2], channels[1])
        self.decoder1 = UnetResBlock(
            in_channels=channels[1],
            out_channels=channels[0],
            kernel_size=3,
            stride=1,
            sparse=False,
        )
        self.out = UnetOutBlock(in_channels=channels[0], n_classes=n_classes)

    def forward(self, x1, x2, x3, x4, x5):
        y = self.decoder5(x5, x4)
        y = self.decoder4(y, x3)
        y = self.decoder3(y, x2)
        y = self.decoder2(y, x1)
        y = self.decoder1(y)
        return self.out(y)
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
class Hybird(nn.Module):
    """End-to-end segmentation network: CNN embeddings + dense (unmasked)
    Mamba MAE bottleneck + UNETR-style decoder."""

    def __init__(self,
                 in_channel: int = 3,
                 n_classes: int = 3,
                 channels: Tuple = (32, 64, 96, 128, 192),
                 depths: Tuple = (1, 1, 3, 3, 1),
                 kernels: Tuple = (3, 3, 3, 3, 3),
                 exp_r: Tuple = (2, 4, 4, 4, 2),
                 img_size=96,
                 depth=3,
                 norm_layer=nn.LayerNorm, ):
        super().__init__()
        self.embeddings = Embeddings(in_channel=in_channel,
                                     channels=channels,
                                     depths=depths,
                                     kernels=kernels,
                                     exp_r=exp_r,
                                     sparse=False)

        self.mae = MaskedAutoencoderMamba(
            img_size=img_size,
            downsample_rato=16,
            embed_dim=channels[-1],
            depth=depth,
            norm_layer=norm_layer,
            sparse=False)

        self.decoder = Decoder(
            n_classes=n_classes,
            channels=channels,
        )

    def forward(self, x):
        f1, f2, f3, f4, f5 = self.embeddings(x)
        f5 = self.mae(f5, None)  # no mask: dense fine-tuning pass
        return self.decoder(f1, f2, f3, f4, f5)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def build_hybird(in_channel=1, n_classes=14, img_size=96):
    """Construct the Hybird network with the default MambaMIM configuration."""
    return Hybird(
        in_channel=in_channel,
        n_classes=n_classes,
        channels=(32, 64, 128, 192, 384),
        depths=(1, 2, 2, 2, 1),
        kernels=(3, 3, 3, 3, 3),
        exp_r=(2, 2, 4, 4, 4),
        img_size=img_size,
        depth=4,
    )
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
if __name__ == '__main__':
    # Smoke test: one single-channel 96^3 volume through the default model.
    dummy = torch.rand((1, 1, 96, 96, 96))
    net = build_hybird()
    print(net(dummy).shape)
|
| 680 |
+
|
| 681 |
+
|
utils/arg_util.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from tap import Tap
|
| 6 |
+
|
| 7 |
+
import dist
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Args(Tap):
    """Typed command-line arguments for MambaMIM pre-training (typed-argument-parser).

    Fields under "NO NEED TO SPECIFIED" are filled in automatically by
    `init_dist_and_get_args` at runtime and should not be passed on the CLI.
    """
    # environment
    exp_name: str = 'mamba'
    exp_dir: str = ''  # will be created if not exists
    data_path: str = ''
    init_weight: str = ''  # use some checkpoint as model weight initialization; ONLY load model weights
    resume_from: str = ''  # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch

    # MambaMIM hyperparameters
    mask: float = 0.75  # mask ratio, should be in (0, 1)

    # encoder hyperparameters
    model: str = 'mambamim'
    input_size: int = 96
    sbn: bool = True  # presumably sync-BN; forced to False when dist is not initialized

    # data hyperparameters
    bs: int = 1  # global batch size (split across GPUs into batch_size_per_gpu)
    dataloader_workers: int = 8

    # pre-training hyperparameters
    dp: float = 0.0
    base_lr: float = 1e-4
    wd: float = 0.04  # initial weight decay
    wde: float = 0.2  # end weight decay; 0 falls back to wd (see init_dist_and_get_args)
    ep: int = 100  # total epochs
    wp_ep: int = 40  # warm-up epochs
    clip: int = 5.  # NOTE(review): declared int but defaulted to a float — presumably a grad-clip value
    opt: str = 'adamw'  # optimizer name (lower-cased at runtime)
    ada: float = 0.  # 0 falls back to 0.999 (see init_dist_and_get_args)

    # NO NEED TO SPECIFIED; each of these args would be updated in runtime automatically
    lr: float = 1e-4
    batch_size_per_gpu: int = 0
    glb_batch_size: int = 0
    densify_norm: str = ''
    device: str = 'gpu'
    local_rank: int = 0
    cmd: str = ' '.join(sys.argv[1:])  # exact command line, recorded for logging
    # NOTE: evaluated at import time; '[unknown]' when not inside a git repo
    commit_id: str = os.popen(f'git rev-parse HEAD').read().strip() or '[unknown]'
    commit_msg: str = (os.popen(f'git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
    last_loss: float = 0.
    cur_ep: str = ''
    remain_time: str = ''
    finish_time: str = ''
    first_logging: bool = True
    log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'  # replaced with the real path at startup
    tb_lg_dir: str = ''  # tensorboard log directory

    @property
    def is_convnext(self):
        # model-family check ('convnext' or its abbreviation 'cnx')
        return 'convnext' in self.model or 'cnx' in self.model

    @property
    def is_resnet(self):
        return 'resnet' in self.model

    def log_epoch(self):
        """Append one JSON line of training progress to the pretrain log.

        Only the local master process writes; the first call (over)writes the
        file with a run-description header.
        """
        if not dist.is_local_master():
            return

        if self.first_logging:
            self.first_logging = False
            with open(self.log_txt_name, 'w') as fp:
                json.dump({
                    'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
                    'model': self.model,
                }, fp)
                fp.write('\n\n')

        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur_ep': self.cur_ep,
                'last_L': self.last_loss,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
            fp.write('\n')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def init_dist_and_get_args():
    """Parse CLI args, sanitise the experiment directory, initialise the
    distributed environment, and fill in all runtime-derived fields."""
    from utils import misc

    args = Args(explicit_bool=True).parse_args()

    # Sanitise the experiment-dir basename to alnum/'-'/'_' so it is
    # filesystem-safe, then rebuild the absolute path.
    full_path = os.path.abspath(args.exp_dir)
    parent, base = os.path.dirname(full_path), os.path.basename(full_path)
    base = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in base)
    args.exp_dir = os.path.join(parent, base)

    os.makedirs(args.exp_dir, exist_ok=True)
    args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
    args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
    try:
        os.makedirs(args.tb_lg_dir, exist_ok=True)
    except:
        pass  # best-effort: tensorboard dir failure must not abort startup

    misc.init_distributed_environ(exp_dir=args.exp_dir)

    # Runtime-derived fields.
    if not dist.initialized():
        args.sbn = False
    args.first_logging = True
    args.device = dist.get_device()
    args.batch_size_per_gpu = args.bs // dist.get_world_size()
    args.glb_batch_size = args.batch_size_per_gpu * dist.get_world_size()

    args.ada = args.ada or 0.999
    args.densify_norm = 'ln'

    args.opt = args.opt.lower()
    args.lr = args.base_lr
    args.wde = args.wde or args.wd

    return args
|
utils/lamb.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
|
| 2 |
+
This optimizer code was adapted from the following (starting with latest)
|
| 3 |
+
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
|
| 4 |
+
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
| 5 |
+
* https://github.com/cybertronai/pytorch-lamb
|
| 6 |
+
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
|
| 7 |
+
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
|
| 8 |
+
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
|
| 9 |
+
Original copyrights for above sources are below.
|
| 10 |
+
Modifications Copyright 2021 Ross Wightman
|
| 11 |
+
"""
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch.optim.optimizer import Optimizer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TheSameAsTimmLAMB(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): whether to apply Adam-style bias correction. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm (default: 2.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
            weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
            trust_clip=trust_clip, always_adapt=always_adapt)
        super().__init__(params, defaults)
        print(f'[lamb1] max_grad_norm={max_grad_norm}')
        # Exposed for logging: the last observed global gradient L2 norm.
        self.global_grad_norm = 0

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

        # Pass 1: compute the global gradient norm across ALL param groups,
        # used to rescale every gradient before the per-parameter update.
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
                global_grad_norm.add_(grad.pow(2).sum())

        global_grad_norm = torch.sqrt(global_grad_norm)
        self.global_grad_norm = global_grad_norm.item()
        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
        # Rescale factor: 1 when under the cap, else max_grad_norm / ||g||.
        clip_global_grad_norm = 1 / torch.where(
            global_grad_norm > max_grad_norm,
            global_grad_norm / max_grad_norm,
            one_tensor)

        # Pass 2: Adam-style moments + LAMB layer-wise trust-ratio update.
        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.mul_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group['always_adapt']:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group['trust_clip']:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group['lr'])

        return loss
|
utils/lr_control.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from pprint import pformat
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it):
    """Linear warm-up then cosine decay for LR; cosine schedule for weight decay.

    Writes the scheduled values into every param group (honouring optional
    per-group 'lr_scale' / 'weight_decay_scale' factors) and returns
    (min_lr, max_lr, min_wd, max_wd) across groups for logging.
    """
    wp_it = round(wp_it)
    if cur_it < wp_it:
        # warm-up: ramp linearly from 0.5% of peak_lr up to peak_lr
        cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
    else:
        # cosine decay from peak_lr down to 0.1% of peak_lr
        progress = (cur_it - wp_it) / (max_it - 1 - wp_it)
        cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * progress))

    # weight decay anneals over the FULL run (no warm-up phase)
    wd_progress = cur_it / (max_it - 1)
    cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * wd_progress))

    # Seed extrema with the unscaled values, matching the reported range even
    # when every group carries a scale factor.
    seen_lrs, seen_wds = [cur_lr], [cur_wd]
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr * param_group.get('lr_scale', 1)  # 'lr_scale' could be assigned
        seen_lrs.append(param_group['lr'])
        param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1)  # 'weight_decay_scale' could be assigned
        seen_wds.append(param_group['weight_decay'])
    return min(seen_lrs), max(seen_lrs), min(seen_wds), max(seen_wds)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_param_groups(model, nowd_keys=()):
    """Split trainable parameters into 'decay' and 'no_decay' optimizer groups.

    1-D tensors, '.bias' parameters, and any name containing a key from
    `nowd_keys` get weight_decay_scale 0; everything else gets 1.
    """
    groups, groups_dbg = {}, {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        skip_wd = (
            len(param.shape) == 1
            or name.endswith('.bias')
            or any(key in name for key in nowd_keys)
        )
        group_name = 'no_decay' if skip_wd else 'decay'
        wd_scale = 0. if skip_wd else 1.

        if group_name not in groups:
            groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
            groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
        groups[group_name]['params'].append(param)
        groups_dbg[group_name]['params'].append(name)

    # Collapse debug name lists into printable strings.
    for dbg in groups_dbg.values():
        dbg['params'] = pformat(', '.join(dbg['params']), width=200)

    print(f'[get_ft_param_groups] param groups = \n{pformat(groups_dbg, indent=2, width=250)}\n')
    return list(groups.values())
|
utils/med_dataset.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Callable, Optional, Tuple
|
| 4 |
+
from monai import data, transforms as med
|
| 5 |
+
from monai.data import load_decathlon_datalist
|
| 6 |
+
import PIL.Image as PImage
|
| 7 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 8 |
+
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
|
| 9 |
+
from torchvision.transforms import transforms
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cv2
|
| 14 |
+
try:
|
| 15 |
+
from torchvision.transforms import InterpolationMode
|
| 16 |
+
interpolation = InterpolationMode.BICUBIC
|
| 17 |
+
except:
|
| 18 |
+
import PIL
|
| 19 |
+
interpolation = PIL.Image.BICUBIC
|
| 20 |
+
from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform
|
| 21 |
+
import random
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def pil_loader(path):
    """Load the image file at *path* and return it as an RGB PIL image.

    The file is opened explicitly and closed by the `with` block to avoid a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fp:
        return PImage.open(fp).convert('RGB')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ImageNetDataset(DatasetFolder):
    """ImageNet-folder dataset for self-supervised pretraining (no labels).

    Thin wrapper over torchvision's DatasetFolder that keeps only the image
    file paths and returns transformed images without any target.
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # DatasetFolder scans the split subdirectory ('train' or 'val') itself.
        imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val')
        super(ImageNetDataset, self).__init__(
            imagenet_folder,
            loader=pil_loader,
            # DatasetFolder requires exactly one of extensions / is_valid_file
            extensions=IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=None, is_valid_file=is_valid_file
        )

        # keep only the file paths; drop the class labels DatasetFolder collected
        self.samples = tuple(img for (img, label) in self.samples)
        self.targets = None  # this is self-supervised learning so we don't need labels

    def __getitem__(self, index: int) -> Any:
        # Return only the transformed image (no label tuple, unlike DatasetFolder).
        img_file_path = self.samples[index]
        return self.transform(self.loader(img_file_path))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """Build the ImageNet-style dataset used for pretraining.

    You may need to modify this function to return your own dataset:
    define a subclass of `Dataset` to replace ImageNetDataset, use
    *dataset_path* to build your image file list and *input_size* to
    create the image transformation (see `trans_train` below).

    :param dataset_path: the folder of dataset (may point at .../train or .../val)
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    dataset_path = os.path.abspath(dataset_path)
    # Strip a trailing 'train'/'val' PATH COMPONENT, not a mere string suffix.
    # The previous endswith()-based check corrupted paths that merely end with
    # those strings (e.g. '/data/pretrain' -> '/data/pre').
    if os.path.basename(dataset_path) in ('train', 'val'):
        dataset_path = os.path.dirname(dataset_path)

    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def build_meddataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """Build the medical-image dataset used for pretraining.

    Creates the standard augmentation pipeline (random resized crop to
    *input_size*, horizontal flip, tensor conversion, ImageNet-statistics
    normalization) and wraps every image file under *dataset_path* in a
    MedicalDataSets instance.

    :param dataset_path: the folder of dataset
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    pretrain_transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    root = os.path.abspath(dataset_path)
    train_set = MedicalDataSets(base_dir=root, transform=pretrain_transform)
    print_transform(pretrain_transform, '[pre-train]')
    return train_set
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class MedicalDataSets(Dataset):
    """Flat-folder image dataset for self-supervised pretraining.

    Every file directly under *base_dir* is treated as one sample; images
    are loaded as RGB and passed through *transform*.
    """

    def __init__(self, base_dir=None, transform=None):
        self._base_dir = base_dir
        self.sample_list = os.listdir(self._base_dir)
        self.transform = transform
        print("total {}".format(len(self.sample_list)))

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        file_name = self.sample_list[idx]
        image = PImage.open(os.path.join(self._base_dir, file_name)).convert('RGB')
        return self.transform(image)
|
| 130 |
+
|
| 131 |
+
def print_transform(transform, s):
    """Pretty-print each sub-transform of a Compose-like *transform*, tagged with *s*."""
    print(f'Transform {s} = ')
    for sub_transform in transform.transforms:
        print(sub_transform)
    print('---------------------------\n')
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Sampler(torch.utils.data.Sampler):
    """Distributed sampler that partitions a dataset across processes.

    Mirrors torch's DistributedSampler, but can optionally pad the index
    list ('make_even') so every rank receives the same number of samples
    per epoch.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0  # bumped by set_epoch() so each epoch reshuffles differently
        # per-rank sample count, rounded up so total_size covers the dataset
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        # number of indices this rank would receive without padding
        self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])

    def __iter__(self):
        if self.shuffle:
            # seed on the epoch so all ranks draw the same permutation
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            # pad so len(indices) == total_size and every rank gets num_samples items
            if len(indices) < self.total_size:
                if self.total_size - len(indices) < len(indices):
                    # small shortfall: repeat the head of the list
                    indices += indices[: (self.total_size - len(indices))]
                else:
                    # large shortfall: sample the extras at random
                    extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices))
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]  # round-robin shard per rank
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called by the training loop once per epoch to vary the shuffle seed.
        self.epoch = epoch
|
| 183 |
+
|
| 184 |
+
class RandScaleCropdPlusScaleByMidDimSampled(MapTransform):
    """Random scale-crop followed by a resize that pins the MIDDLE spatial dim to `max_size`.

    Returns `num_samples` independently cropped/resized copies of the input
    dict; each copy uses a crop ratio drawn uniformly from
    [min_radio, max_radio] (note: 'radio' here means 'ratio').
    """

    def __init__(self, keys, mode='area', max_size=128,allow_missing_keys=False,num_samples=4,max_radio=0.8,min_radio=0.5):
        self.keys = keys
        self.mode = mode  # interpolation mode forwarded to monai Resized
        self.allow_missing_keys = allow_missing_keys
        self.max_size=max_size  # target length of the median spatial dimension
        self.num_samples = num_samples  # how many crops to return per call
        self.max_radio=max_radio  # upper bound of the random crop ratio
        self.min_radio=min_radio  # lower bound of the random crop ratio

    def __call__(self, data):
        outputs = []
        for i in range(self.num_samples):
            # one crop ratio per sample, shared by all keys of that sample
            random_number = round(random.uniform(self.min_radio, self.max_radio), 2)
            _data = dict(data)
            for key in self.keys:
                cropper= med.RandScaleCropd(keys=[key],roi_scale=random_number)
                _data[key] = cropper(_data)[key]
                ct_tensor = _data[key]
                # assumes channel-first data: shape[1:] are the 3 spatial dims.
                # sorted_numbers[1] is the median spatial extent; scale uniformly
                # so that this median dimension becomes exactly max_size.
                sorted_numbers = sorted(ct_tensor.shape[1:])
                scale_factor = self.max_size / sorted_numbers[1]
                new_size = [int(d * scale_factor)
                            for d in ct_tensor.shape[1:]]

                resizer = med.Resized(keys=[key],
                                      spatial_size=new_size,
                                      mode=self.mode,
                                      allow_missing_keys=self.allow_missing_keys)
                _data[key] = resizer(_data)[key]

            outputs.append(_data)

        return outputs
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_loader(data_dir, size, cache_dir="/fenghetang/3d/pretrain/MM/cache_dataset"):
    """Build the cached MONAI pretraining dataset for 3D CT volumes.

    :param data_dir: dataset root containing a decathlon-format dataset.json
    :param size: cubic patch edge length used for padding and random cropping
    :param cache_dir: directory where CacheNTransDataset stores cached
        transform outputs. Previously hard-coded to a machine-specific path;
        the old path is kept as the default for backward compatibility —
        pass your own location on other machines.
    :return: the training Dataset
    """
    datalist_json = os.path.join(data_dir, "dataset.json")
    train_transform = med.Compose(
        [
            med.LoadImaged(keys=["image"], allow_missing_keys=True),
            med.AddChanneld(keys=["image"], allow_missing_keys=True),
            med.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
            med.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear", allow_missing_keys=True),
            # CT window [-175, 250] HU mapped linearly to [0, 1]
            med.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            med.CropForegroundd(keys=["image"], source_key="image", allow_missing_keys=True),
            med.SpatialPadd(keys=["image"], spatial_size=(size, size, size), mode='constant'),
            # pos=1 / neg=0 with label_key="image": patch centers are sampled
            # only from non-zero (foreground) voxels of the image itself
            med.RandCropByPosNegLabeld(
                spatial_size=(size, size, size),
                keys=["image"],
                label_key="image",
                pos=1,
                neg=0,
                num_samples=4,
            ),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=0),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=1),
            med.RandFlipd(keys=["image"], prob=0.1, spatial_axis=2),
            med.ToTensord(keys=["image"]),
        ])

    datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)
    # Cache the first 6 (deterministic) transforms on disk; the random
    # crop/flip stages still run fresh every iteration.
    train_ds = data.CacheNTransDataset(data=datalist, transform=train_transform,
                                       cache_n_trans=6, cache_dir=cache_dir)
    return train_ds
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
utils/misc.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict, deque
|
| 8 |
+
from typing import Iterator
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pytz
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
|
| 15 |
+
import dist
|
| 16 |
+
|
| 17 |
+
# Thin shell helpers used throughout the training scripts.
os_system = functools.partial(subprocess.call, shell=True)  # run a shell command, return its exit code
os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')  # run a shell command, return its stdout as str
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def os_system_get_stdout_stderr(cmd):
    """Run *cmd* in a shell and return its (stdout, stderr) decoded as UTF-8."""
    proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return proc.stdout.decode('utf-8'), proc.stderr.decode('utf-8')
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def is_pow2n(x):
    """Return True iff *x* is a positive power of two (1, 2, 4, 8, ...)."""
    return x > 0 and x & (x - 1) == 0
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def time_str(for_dirname=False):
    """Current time in the Asia/Shanghai timezone as a short string.

    With for_dirname=True the result is filesystem-safe ('%m-%d_%H-%M-%S');
    otherwise it is a bracketed log prefix ('[%m-%d %H:%M:%S]').
    """
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(
        '%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def init_distributed_environ(exp_dir):
    """Set up the distributed run: init the process group, tune cudnn, and
    restrict/mirror printing on the appropriate processes.
    """
    dist.initialize()
    dist.barrier()

    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True       # let cudnn autotune conv kernels
    cudnn.deterministic = False  # favor speed over exact reproducibility

    # Silence print() on every process except the local master.
    _set_print_only_on_master_proc(is_master=dist.is_local_master())
    if dist.is_local_master() and len(exp_dir):
        # Mirror master stdout/stderr into backup files under exp_dir.
        sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _set_print_only_on_master_proc(is_master):
    """Monkey-patch builtins.print so that only the master process prints.

    The replacement accepts three extra keyword args:
      force:  print even on non-master processes
      clean:  skip the timestamp/call-site prefix
      deeper: attribute the call site to the caller's caller
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def prt(msg, *args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                # Prefix with time and the calling file:line for traceability.
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]  # keep last 24 chars of the path
                msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}'
            builtin_print(msg, *args, **kwargs)

    __builtin__.print = prt
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class _SyncPrintToFile(object):
    """Tee that mirrors everything written to a terminal stream into a log file.

    Used to back up the master process's stdout/stderr under *exp_dir*.
    """

    def __init__(self, exp_dir, stdout=True):
        self.terminal = sys.stdout if stdout else sys.stderr
        log_name = 'stdout_backup.txt' if stdout else 'stderr_backup.txt'
        self.log = open(os.path.join(exp_dir, log_name), 'w')
        self.log.flush()

    def write(self, message):
        # Echo to the real stream first, then persist immediately so the
        # on-disk backup stays complete even if the process dies.
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TensorboardLogger(object):
    """Wrapper around SummaryWriter that only writes from the master process.

    Iteration-wise calls (step=None) are rate-limited to every `log_freq`
    steps; epoch-wise calls (explicit step) always log.
    """

    def __init__(self, log_dir, is_master, prefix='pt'):
        self.is_master = is_master
        # only the master process owns a real writer; other ranks keep None
        self.writer = SummaryWriter(log_dir=log_dir) if self.is_master else None
        self.step = 0
        self.prefix = prefix   # tag namespace prefix, e.g. 'pt_scalar/...'
        self.log_freq = 300    # iteration-wise logging period

    def set_step(self, step=None):
        # Explicit step: jump to it; otherwise advance by one.
        if step is not None:
            self.step = step
        else:
            self.step += 1

    def get_loggable(self, step=None):
        """Return (resolved_step, whether this call should actually write)."""
        if step is None:  # iter wise
            step = self.step
            loggable = step % self.log_freq == 0
        else:  # epoch wise
            loggable = True
        return step, (loggable and self.is_master)

    def update(self, head='scalar', step=None, **kwargs):
        """Log scalar values (float/int or 0-dim tensors), skipping Nones."""
        step, loggable = self.get_loggable(step)
        if loggable:
            head = f'{self.prefix}_{head}'
            for k, v in kwargs.items():
                if v is None:
                    continue
                if isinstance(v, torch.Tensor):
                    v = v.item()
                assert isinstance(v, (float, int))
                self.writer.add_scalar(head + "/" + k, v, step)

    def log_distribution(self, tag, values, step=None):
        """Log a histogram of *values* (converted to a tensor if needed)."""
        step, loggable = self.get_loggable(step)
        if loggable:
            if not isinstance(values, torch.Tensor):
                values = torch.tensor(values)
            self.writer.add_histogram(tag=tag, values=values, global_step=step)

    def log_image(self, tag, img, step=None, dataformats='NCHW'):
        """Log an image batch; *dataformats* describes the layout of *img*."""
        step, loggable = self.get_loggable(step)
        if loggable:
            # img = img.cpu().numpy()
            self.writer.add_image(tag, img, step, dataformats=dataformats)

    def flush(self):
        if self.is_master: self.writer.flush()

    def close(self):
        if self.is_master: self.writer.close()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def save_checkpoint_with_meta_info_and_opt_state(save_to, args, epoch, performance_desc, model_without_ddp_state,
                                                 optimizer_state):
    """Save a full training checkpoint (weights + optimizer state + run metadata).

    Only the local master process writes; all other ranks are no-ops.
    """
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if dist.is_local_master():
        to_save = {
            'args': str(args),
            'input_size': args.input_size,
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,   # model weights live under 'module'
            'optimizer': optimizer_state,
            'is_pretrain': True,
        }
        torch.save(to_save, checkpoint_path)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def save_checkpoint_model_weights_only(save_to, args, sp_cnn_state):
    """Save ONLY the model state_dict (no optimizer/metadata); master rank only."""
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if dist.is_local_master():
        torch.save(sp_cnn_state, checkpoint_path)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def initialize_weight(init_weight: str, model_without_ddp):
    """Use a checkpoint as the model's weight initialization (weights ONLY).

    An empty *init_weight* disables initialization. The checkpoint may be a
    raw state_dict or a dict carrying one under the 'module' key. Loading is
    non-strict; missing/unexpected keys are printed for inspection.
    """
    if not init_weight:
        return
    checkpoint = torch.load(init_weight, 'cpu')
    state_dict = checkpoint.get('module', checkpoint)
    missing, unexpected = model_without_ddp.load_state_dict(state_dict, strict=False)
    print(f'[initialize_weight] missing_keys={missing}')
    print(f'[initialize_weight] unexpected_keys={unexpected}')
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_checkpoint(resume_from: str, model_without_ddp, optimizer):
    """Resume an experiment from a checkpoint: weights, optimizer state, epoch.

    Returns (start_epoch, performance_desc). An empty *resume_from* is a
    no-op returning (0, '[no performance_desc]'). Weight loading is
    non-strict; missing/unexpected keys are printed.
    """
    if not resume_from:
        return 0, '[no performance_desc]'
    print(f'[try to resume from file `{resume_from}`]')
    checkpoint = torch.load(resume_from, map_location='cpu')

    ep_start = checkpoint.get('epoch', -1) + 1
    performance_desc = checkpoint.get('performance_desc', '[no performance_desc]')
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return ep_start, performance_desc
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class SmoothedValue(object):
    """Track a series of values and expose smoothed statistics.

    The last *window_size* values back the windowed stats (median/avg/max),
    while running count/total back the global average.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """All-reduce count/total across workers.

        Warning: does not synchronize the deque!
        """
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.allreduce(t)
        count, total = t.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(median=self.median, avg=self.avg,
                               global_avg=self.global_avg, max=self.max,
                               value=self.value)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class MetricLogger(object):
    """Collect named SmoothedValue meters and print periodic progress lines."""

    def __init__(self, delimiter="\t"):
        # meters auto-create on first update() thanks to defaultdict
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed new scalar values (float/int or 0-dim tensors) into the meters."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss` as a shortcut for `logger.meters['loss']`.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, max_iters, itrt, print_freq, header=None):
        """Yield items from *itrt* while timing data loading and iteration,
        printing `print_freq` evenly spaced progress reports with an ETA.

        Drives plain Iterators via next() for exactly max_iters steps;
        everything else (e.g. DataLoaders, prefetchers exposing `preload`
        or `set_epoch`) is driven via enumerate.
        """
        # choose print_freq evenly spaced iteration indices to report at
        print_iters = set(np.linspace(0, max_iters - 1, print_freq, dtype=int).tolist())
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        self.iter_time = SmoothedValue(fmt='{avg:.4f}')
        self.data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(max_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'iter: {time}s',
            'data: {data}s'
        ]
        log_msg = self.delimiter.join(log_msg)

        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            for i in range(max_iters):
                obj = next(itrt)
                self.data_time.update(time.time() - end)  # time spent fetching the batch
                yield obj
                self.iter_time.update(time.time() - end)  # full fetch+train iteration time
                if i in print_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)))
                end = time.time()
        else:
            for i, obj in enumerate(itrt):
                self.data_time.update(time.time() - end)  # time spent fetching the batch
                yield obj
                self.iter_time.update(time.time() - end)  # full fetch+train iteration time
                if i in print_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)))
                end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.3f} s / it)'.format(
            header, total_time_str, total_time / max_iters))
|
utils/sampler.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data.sampler import Sampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def worker_init_fn(worker_id):
    """Seed numpy/random inside each DataLoader worker from torch's per-worker seed.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    seed = torch.initial_seed() % 2 ** 32
    np.random.seed(seed)
    random.seed(seed)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DistInfiniteBatchSampler(Sampler):
    """Infinite per-rank batch sampler for distributed training.

    Splits a (optionally shuffled, optionally filled-to-even) global index
    permutation into contiguous per-rank segments and yields local batches
    forever, regenerating the permutation at each epoch boundary when
    shuffling is enabled.
    """

    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size

        # ceil-divide: a partial final global batch still counts as one iter
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        # global_max_p % world_size == 0 because glb_batch_size % world_size == 0
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            rng = torch.Generator()
            rng.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=rng)
        else:
            global_indices = torch.arange(self.dataset_len)
        shortage = global_max_p - global_indices.shape[0]
        if shortage > 0 and self.filling:
            # repeat a head slice so the global list divides evenly into batches
            global_indices = torch.cat((global_indices, global_indices[:shortage]))
        global_indices = tuple(global_indices.numpy().tolist())

        # contiguous, near-equal split of the global list across ranks
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices

    def __iter__(self):
        self.epoch = 0
        while True:
            self.epoch += 1
            start = 0
            while start < self.max_p:
                stop = start + self.batch_size
                yield self.indices[start:stop]
                start = stop
            if self.shuffle:
                # reshuffle for the next pass through the data
                self.indices = self.gener_indices()

    def __len__(self):
        return self.iters_per_ep
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
    # Smoke test: verify that a 5024-sample dataset splits across 16 ranks
    # when the global batch size equals the dataset length, printing each
    # rank's local index count.
    W = 16
    for rk in range(W):
        ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
        print(rk, len(ind))
|