FengheTan9 committed on commit 6da2a44 (verified · 1 parent: f94d70c)

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/TOKI.png filter=lfs diff=lfs merge=lfs -text
+img/masking_consistency.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,202 @@
----
-license: apache-2.0
----
## [MIA'25] MambaMIM: Pre-training Mamba with State Space Token Interpolation and its Application to Medical Image Segmentation

![MambaMIM](img/TOKI.png)

<div align="center">
<span class="author-block">
<a href="https://scholar.google.com/citations?user=x1pODsMAAAAJ&hl=en" target="_blank">Fenghe Tang</a><sup>1,2</sup>,</span>
<span class="author-block">
<a target="_blank">Bingkun Nian</a><sup>3</sup>,</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=ocAtNkkAAAAJ&hl=en" target="_blank">Yingtai Li</a><sup>1,2</sup>,</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=Wo8tMSMAAAAJ&hl=en" target="_blank">Zihang Jiang</a><sup>1,2</sup>,</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=tmx7tu8AAAAJ&hl=en" target="_blank">Jie Yang</a><sup>3</sup>,</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=Vbb5EGIAAAAJ&hl=en" target="_blank">Liu Wei</a><sup>3</sup>,</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=8eNm2GMAAAAJ&hl=en" target="_blank">S. Kevin Zhou</a><sup>1,2</sup>
</span>
</div>

<br>

<div align="center">
<sup>1</sup> <a href='https://en.ustc.edu.cn/' target='_blank'>School of Biomedical Engineering, University of Science and Technology of China</a>&emsp;
<br>
<sup>2</sup> <a href='http://english.ict.cas.cn/' target='_blank'>Suzhou Institute for Advanced Research, University of Science and Technology of China</a>&emsp;
<br>
<sup>3</sup> <a href='http://www.pami.sjtu.edu.cn/En/Home' target='_blank'>Department of Automation, Institute of Image Processing and Pattern Recognition, Shanghai Jiao Tong University</a>
<br>
</div>

<br>

[![arXiv](https://img.shields.io/badge/arxiv-2408.08070-b31b1b)](https://arxiv.org/pdf/2408.08070.pdf) [![github](https://img.shields.io/badge/github-MambaMIM-purple)](https://github.com/FengheTan9/MambaMIM) <a href="#LICENSE--citation"><img alt="License: Apache2.0" src="https://img.shields.io/badge/LICENSE-Apache%202.0-blue.svg"/></a>

## News

- **MambaMIM accepted by Medical Image Analysis (MIA'25)! 🥰**
- **Weights released! 😎**
- **Code released!** 😘
- **Code and weights will be released soon!** 😘
- **[2024/08/16] Paper released!**

## TODOs

- [x] Paper released
- [x] Code released
- [x] Weights released

## Getting Started

### Download weights

| Name | Resolution | Intensities | Spacing | Weights |
| :------: | :----------: | :-----------: | :----------------: | :----------------------------------------------------------: |
| MambaMIM | 96 x 96 x 96 | [-175, 250] | 1.5 x 1.5 x 1.5 mm | [Google Drive (87MB)](https://drive.google.com/file/d/1B3j5aRPxkDJqf8UPGKDiAjg2X85a3Kwx/view?usp=sharing) |

### Prepare Environments

```sh
# NOTE: the prebuilt causal_conv1d / mamba_ssm wheels below are cp38 builds (CPython 3.8);
# use python=3.8 here if you install those wheels as-is
conda create -n mambamim python=3.9
conda activate mambamim
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging timm==0.5.4
pip install transformers==4.34.1 typed-argument-parser
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64
pip install 'monai[all]'
pip install monai==1.2.0
pip install causal_conv1d-1.2.0.post2+cu118torch1.13cxx11abiTRUE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm-1.2.0.post1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
```

### Prepare Datasets

We recommend that you convert your datasets into the [nnUNet](https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md) format:

```
└── MambaMIM
    ├── data
        ├── Dataset060_TotalSegmentator
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
        ├── Dataset006_FLARE2022
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
        └── Other_dataset
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
```

An example `dataset.json` will be generated in `./data`. Its content should look like this:

```json
{
    "training": [
        {
            "image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
        },
        {
            "image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
        }
    ]
}
```
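To assemble a `dataset.json` like the one above from the folder layout, a small script can glob each dataset's `imagesTr` directory. This is a sketch under the stated layout, not the repository's own generator; the `build_dataset_json` helper name is ours.

```python
import json
import os
from glob import glob


def build_dataset_json(data_root):
    """Collect every imagesTr/*.nii.gz under data_root into an nnUNet-style training list."""
    entries = []
    for dataset_dir in sorted(os.listdir(data_root)):
        images_tr = os.path.join(data_root, dataset_dir, "imagesTr")
        if not os.path.isdir(images_tr):
            continue
        for img in sorted(glob(os.path.join(images_tr, "*.nii.gz"))):
            # paths are stored relative to data_root, matching the example above
            entries.append({"image": "./" + os.path.relpath(img, data_root)})
    return {"training": entries}


if __name__ == "__main__":
    with open("./data/dataset.json", "w") as f:
        json.dump(build_dataset_json("./data"), f, indent=2)
```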
134
+
135
+
136
+
137
+ ## Start Training
138
+
139
+ ![MambaMIM](img/masking_consistency.png)
140
+
141
+
142
+
143
+ Run training on multi-GPU :
144
+
145
+ ```sh
146
+ # An example of training on 4 GPUs with DDP
147
+ torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data --model=mambamim --bs=16 --exp_dir=debug_mambamim_ddp_4
148
+ ```
149
+
150
+ Run training on the single-GPU :
151
+
152
+ ```sh
153
+ # An example of training on the single GPU
154
+ python main.py --exp_name=debug --data_path=./data --model=mambamim --bs=4 --exp_dir=debug_mambamim
155
+ ```
156
+
157
+
158
+
159
+ ## Fine-tuning
160
+
161
+ Load pre-training weights :
162
+
163
+ ```python
164
+ # An example of Fine-tuning on BTCV (num_classes=14)
165
+ from models.network.hymamba import build_hybird
166
+
167
+ model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()
168
+
169
+ model_dict = torch.load("mambamim_mask75.pth")
170
+
171
+ if model.load_state_dict(model_dict, strict=False):
172
+ print("MambaMIM use pretrained weights successfully !")
173
+ ```
The downstream fine-tuning pipeline can follow [UNETR](https://github.com/Project-MONAI/research-contributions/tree/main/UNETR/BTCV).

## Acknowledgements

This code uses helper functions from [SparK](https://github.com/keyu-tian/SparK) and [HySparK](https://github.com/FengheTan9/HySparK).

## Citation

If the code, paper, or weights help your research, please cite:

```bibtex
@article{tang2024mambamim,
  title={MambaMIM: Pre-training Mamba with State Space Token-interpolation},
  author={Tang, Fenghe and Nian, Bingkun and Li, Yingtai and Yang, Jie and Wei, Liu and Zhou, S Kevin},
  journal={arXiv preprint arXiv:2408.08070},
  year={2024}
}
```

## License

This project is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
dist.py ADDED
@@ -0,0 +1,112 @@
import os
from typing import List
from typing import Union

import sys
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(backend='nccl'):
    global __device
    if not torch.cuda.is_available():
        print('[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        __device = torch.empty(1).cuda().device
        print('[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
        return

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def barrier():
    if __initialized:
        tdist.barrier()


def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def allreduce(t: torch.Tensor) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.all_reduce(cu)
            t.copy_(cu.cpu())
        else:
            tdist.all_reduce(t)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
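`dist.py` degrades gracefully: without CUDA or a `RANK` environment variable it stays single-process (rank 0, world size 1). The guard pattern can be sketched torch-free; `env_rank_world` is our illustrative name, not part of the repository:

```python
import os


def env_rank_world():
    """Mirror dist.initialize's fallback: single process unless a launcher sets RANK/WORLD_SIZE."""
    if "RANK" not in os.environ:
        # dist.py returns early here and keeps rank 0 / world size 1
        return 0, 1
    return int(os.environ["RANK"]), int(os.environ.get("WORLD_SIZE", "1"))
```

`torchrun` exports `RANK` and `WORLD_SIZE` into each worker's environment, which is what makes the multi-process branch fire.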
img/TOKI.png ADDED

Git LFS Details

  • SHA256: 4a3efce3120f63be89e2e17a1dbc80b794e34bb7f5277e0983961356c2fe91e1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
img/masking_consistency.png ADDED

Git LFS Details

  • SHA256: a4abcc4d8218afdab7cec2d353b3f1caba6e6ab85aa899b3f5a79ff668edf0e6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB
main.py ADDED
@@ -0,0 +1,192 @@
import datetime
import math
import sys
import time
import logging
import os
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
from models.encoder import SparseEncoder
from models.decoder import LightDecoder
from models.MambaMIM import MambaMIM
from models import build_sparse_encoder
from utils.sampler import DistInfiniteBatchSampler, worker_init_fn
from utils import arg_util, misc
from utils.med_dataset import get_loader
from utils.lr_control import lr_wd_annealing


cpu_num = 1
os.environ['OMP_NUM_THREADS'] = str(cpu_num)
os.environ['OPENBLAS_NUM_THREADS'] = str(cpu_num)
os.environ['MKL_NUM_THREADS'] = str(cpu_num)
os.environ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num)
os.environ['NUMEXPR_NUM_THREADS'] = str(cpu_num)
torch.set_num_threads(cpu_num)
torch.multiprocessing.set_sharing_strategy('file_system')


class LocalDDP(torch.nn.Module):
    def __init__(self, module):
        super(LocalDDP, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def main_pt():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print('[build data for pre-training] ...\n')
    dataset_train = get_loader(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )

    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # build encoder and decoder
    enc: SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = MambaMIM(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')

    # the model has been randomly initialized at construction time;
    # a checkpoint could be loaded here as weight initialization (this would ONLY load the model weights)

    model = LocalDDP(model_without_ddp)

    # build optimizer and lr_scheduler
    optimizer = torch.optim.AdamW(params=model_without_ddp.parameters(), lr=args.lr, weight_decay=1e-5)

    # try to resume the experiment from some checkpoint.pth; this will load model weights, optimizer states, and last epoch (ep_start)
    # if loaded, ep_start will be greater than 0
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # loaded from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)

            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_ct_pretrained.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_model_weights_only(f'{args.model}_ct_pretrained_mambamim_timm_style.pth', args, model_without_ddp.sparse_encoder.state_dict())

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep - 1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs / 60 / 60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]

    for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)

        # concatenate the crops of each batch element, then forward and backward
        temp = []
        for crop_per_batch in inp:
            temp.append(crop_per_batch["image"])
        inp = torch.cat(temp, dim=0)

        inp = inp.to(args.device, non_blocking=True)
        loss = model(inp, active_b1fff=None, vis=False)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}


if __name__ == '__main__':
    main_pt()
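The training loop above estimates remaining wall time from the last epoch's cost. The formatting logic is stdlib-only and can be checked in isolation; the `eta` helper name is ours, extracted from `main_pt`'s logging:

```python
import datetime


def eta(ep, total_ep, ep_cost_secs):
    """Remaining time after finishing epoch `ep` (0-based), as computed in main_pt's logging."""
    remain_secs = (total_ep - 1 - ep) * ep_cost_secs
    return datetime.timedelta(seconds=round(remain_secs))
```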
mambamim_mask75.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd92f6cfdd2aff93f8942536f333bca7eb612b4238153c9b5accbacd9e4e1989
size 90976893
models/MambaMIM.py ADDED
@@ -0,0 +1,240 @@
from pprint import pformat
from typing import List

import sys
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

import models.encoder as encoder
from models.decoder import LightDecoder
from itertools import accumulate


class MambaMIM(nn.Module):
    def __init__(
            self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
            mask_ratio=0.6, densify_norm='ln', sbn=True,
    ):
        super().__init__()
        input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        self.fmap_h, self.fmap_w, self.fmap_d = input_size // downsample_raito, input_size // downsample_raito, input_size // downsample_raito
        self.mask_ratio = mask_ratio
        self.len_keep = round(self.fmap_h * self.fmap_w * self.fmap_d * (1 - mask_ratio))

        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder

        self.sbn = sbn
        self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
        self.densify_norm_str = densify_norm.lower()
        self.densify_norms = nn.ModuleList()
        self.densify_projs = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # build the `densify` layers
        e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
        e_widths: List[int]
        self.A_interpolation = nn.Parameter(torch.zeros(1, self.sparse_encoder.enc_feat_map_chs[-1], self.sparse_encoder.enc_feat_map_chs[-1]))
        print("self.A_interpolation: ", self.A_interpolation.shape)
        for i in range(self.hierarchy):  # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
            e_width = e_widths.pop()
            # create mask token
            p = nn.Parameter(torch.zeros(1, e_width, 1, 1, 1))
            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(p)

            # create densify norm
            densify_norm = nn.Identity()
            self.densify_norms.append(densify_norm)

            # create densify proj
            if i == 0 and e_width == d_width:
                densify_proj = nn.Identity()  # NOTE: ConvNeXt-S would use this, because its width of 768 equals the decoder's width 768
                print(f'[MambaMIM.__init__, densify {i + 1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
            else:
                kernel_size = 1 if i <= 0 else 3
                densify_proj = nn.Conv3d(e_width, d_width, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, bias=True)
                print(f'[MambaMIM.__init__, densify {i + 1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
            self.densify_projs.append(densify_proj)

            # the decoder's width follows a simple halving rule; it can be changed to any other rule
            d_width //= 2

        print(f'[MambaMIM.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')

    def mask(self, B: int, device, generator=None):
        h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
        idx = torch.rand(B, h * w * d, generator=generator).argsort(dim=1)
        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)
        return torch.zeros(B, h * w * d, dtype=torch.bool, device=device) \
            .scatter_(dim=1, index=idx, value=True).view(B, 1, h, w, d)

    def mask_token_every_batch(self, bcfff, cur_active):
        # force the first and last positions to be active so every masked run has visible endpoints
        flag = cur_active.flatten(2).clone()
        flag[0][0][0] = True
        flag[0][0][-1] = True

        indices = torch.nonzero(flag.squeeze()).squeeze()
        B, N, H, L, W = bcfff.shape

        # A_token: powers of the learnable state matrix between consecutive visible tokens
        A_token = []
        for i in range(0, len(indices) - 1):
            A_power = [torch.linalg.matrix_power(self.A_interpolation, k) for k in range(indices[i + 1] - indices[i])]
            max_power = indices[i + 1] - indices[i] - 1
            for j in range(0, indices[i + 1] - indices[i]):
                A_token.append(A_power[max_power - j])
        A_token.append(self.A_interpolation)
        A_token = torch.cat(A_token, dim=0)

        # X_token: linear interpolation between neighboring visible tokens
        X_token = []
        X_unmask = bcfff.flatten(2).transpose(1, 2).squeeze().unsqueeze(-1)
        for i in range(0, len(indices) - 1):
            alpha = torch.linspace(0, 1, indices[i + 1] - indices[i], dtype=X_unmask.dtype, device=X_unmask.device)
            alpha = alpha.view(-1, 1)
            X_interpolation = (1 - alpha) * X_unmask[indices[i]].transpose(0, 1) + alpha * X_unmask[indices[i + 1]].transpose(0, 1)
            X_token.append(X_interpolation.unsqueeze(-1))
        X_last_token = X_unmask[indices[-1]].unsqueeze(0)
        X_token.append(X_last_token)
        X_token = torch.cat(X_token, dim=0)

        AX = A_token.cuda() @ X_token

        mask_token = AX
        for i in range(0, len(indices) - 1):
            current_sum = list(accumulate(AX[indices[i]:indices[i + 1]]))
            mask_token[indices[i]:indices[i + 1]] = torch.stack(current_sum, dim=0)
        mask_token = AX.reshape(B, N, H, L, W)

        return mask_token

    def mamba_mask(self, bcfff, cur_active):
        """
        S6T: state space token interpolation, applied batch element by batch element
        """
        B, N, H, W, L = cur_active.shape
        cur_active_list = torch.chunk(cur_active, B, dim=0)
        bcfff_list = torch.chunk(bcfff, B, dim=0)
        mask_token_list = []
        for i in range(B):
            mask_token_list.append(self.mask_token_every_batch(bcfff_list[i], cur_active_list[i]))
        mask_token = torch.cat(mask_token_list, dim=0)

        return mask_token

    def forward(self, inp_bchwd: torch.Tensor, active_b1fff=None, vis=False):
        # step1. Mask
        if active_b1fff is None:  # rand mask
            active_b1fff: torch.BoolTensor = self.mask(inp_bchwd.shape[0], inp_bchwd.device)  # (B, 1, f, f, f)
        encoder._cur_active = active_b1fff  # (B, 1, f, f, f)
        active_b1hwd = active_b1fff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3).repeat_interleave(self.downsample_raito, 4)  # (B, 1, H, W, D)
        masked_bchwd = inp_bchwd * active_b1hwd

        # step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales)
        fea_bcfffs: List[torch.Tensor] = self.sparse_encoder(masked_bchwd, active_b1fff)
        fea_bcfffs.reverse()  # after reversion: from the smallest feature map to the largest

        # step3. Densify: get hierarchical dense features for decoding (may need to be modified)
        cur_active = active_b1fff  # (B, 1, f, f, f)
        to_dec = []
        for i, bcfff in enumerate(fea_bcfffs):  # from the smallest feature map to the largest
            if bcfff is not None:
                bcfff = self.densify_norms[i](bcfff)
                mask_tokens = self.mamba_mask(bcfff, cur_active) if i == 0 else self.mask_tokens[i].expand_as(bcfff)
                bcfff = torch.where(cur_active.expand_as(bcfff), bcfff, mask_tokens)  # fill in empty (non-active) positions with [mask] tokens
                bcfff: torch.Tensor = self.densify_projs[i](bcfff)
            to_dec.append(bcfff)
            cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)  # dilate the mask map, from (B, 1, f, f, f) to (B, 1, 2f, 2f, 2f)

        # step4. Decode and reconstruct
        rec_bchwd = self.dense_decoder(to_dec)
        inp, rec = self.patchify(inp_bchwd), self.patchify(rec_bchwd)  # inp and rec: (B, L = f*f*f, N = C*downsample_raito**3)
        mean = inp.mean(dim=-1, keepdim=True)
        var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
        inp = (inp - mean) / var
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, C) ==mean==> (B, L)

        non_active = active_b1fff.logical_not().int().view(active_b1fff.shape[0], -1)  # (B, 1, f, f, f) => (B, L)
        recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # loss only on masked (non-active) patches

        if vis:
            masked_bchwd = inp_bchwd * active_b1hwd
            rec_bchwd = self.unpatchify(rec * var + mean)
            rec_or_inp = torch.where(active_b1hwd, inp_bchwd, rec_bchwd)
            return inp_bchwd, masked_bchwd, rec_or_inp
        else:
            return recon_loss

    def patchify(self, bchwd):
        p = self.downsample_raito
        h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
        B, C = bchwd.shape[:2]
        bchwd = bchwd.reshape(shape=(B, C, h, p, w, p, d, p))
        bchwd = torch.einsum('bchpwqds->bhwdpqsc', bchwd)
        bln = bchwd.reshape(shape=(B, h * w * d, C * p ** 3))  # (B, f*f*f, C*downsample_raito**3)
        return bln

    def unpatchify(self, bln):
        p = self.downsample_raito
        h, w, d = self.fmap_h, self.fmap_w, self.fmap_d
        B, C = bln.shape[0], bln.shape[-1] // p ** 3
        bln = bln.reshape(shape=(B, h, w, d, p, p, p, C))
        bln = torch.einsum('bhwdpqsc->bchpwqds', bln)
        bchwd = bln.reshape(shape=(B, C, h * p, w * p, d * p))
        return bchwd

    def __repr__(self):
        return (
            f'\n'
            f'[MambaMIM.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
            f'[MambaMIM.structure]: {super(MambaMIM, self).__repr__().replace(MambaMIM.__name__, "")}'
        )

    def get_config(self):
        return {
            # self
            'mask_ratio': self.mask_ratio,
            'densify_norm_str': self.densify_norm_str,
            'sbn': self.sbn, 'hierarchy': self.hierarchy,

            # enc
            'sparse_encoder.input_size': self.sparse_encoder.input_size,
            # dec
            'dense_decoder.width': self.dense_decoder.width,
        }

    def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
        state = super(MambaMIM, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        if with_config:
            state['config'] = self.get_config()
        return state

    def load_state_dict(self, state_dict, strict=True):
        config: dict = state_dict.pop('config', None)
        incompatible_keys = super(MambaMIM, self).load_state_dict(state_dict, strict=strict)
        if config is not None:
            for k, v in self.get_config().items():
                ckpt_v = config.get(k, None)
                if ckpt_v != v:
                    err = f'[MambaMIM.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={ckpt_v})'
                    if strict:
                        raise AttributeError(err)
                    else:
                        print(err, file=sys.stderr)
        return incompatible_keys
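`mask` above draws a per-sample random keep-set by arg-sorting uniform noise, and `patchify`/`unpatchify` are exact inverses of each other. Both properties can be checked with a NumPy re-implementation of the same logic (a torch-free sketch for portability, not the repository code; the function names are ours):

```python
import numpy as np


def random_mask(B, f, mask_ratio, rng):
    """Boolean (B, f**3) keep-mask: True = visible, exactly len_keep kept per sample."""
    len_keep = round(f ** 3 * (1 - mask_ratio))
    # argsort of uniform noise gives a random permutation; keep the first len_keep indices
    idx = np.argsort(rng.random((B, f ** 3)), axis=1)[:, :len_keep]
    keep = np.zeros((B, f ** 3), dtype=bool)
    np.put_along_axis(keep, idx, True, axis=1)
    return keep


def patchify(x, p):
    """(B, C, H, W, D) -> (B, h*w*d, C*p**3) with h = H // p (cube patches)."""
    B, C, H, W, D = x.shape
    h, w, d = H // p, W // p, D // p
    x = x.reshape(B, C, h, p, w, p, d, p)
    x = np.einsum('bchpwqds->bhwdpqsc', x)
    return x.reshape(B, h * w * d, C * p ** 3)


def unpatchify(y, p, hwd):
    """Inverse of patchify, given the (h, w, d) patch grid."""
    h, w, d = hwd
    B, C = y.shape[0], y.shape[-1] // p ** 3
    y = y.reshape(B, h, w, d, p, p, p, C)
    y = np.einsum('bhwdpqsc->bchpwqds', y)
    return y.reshape(B, C, h * p, w * p, d * p)
```

With the defaults above (96³ input, downsample ratio 16, mask ratio 0.75), each sample keeps exactly `round(6**3 * 0.25) = 54` of 216 patch positions.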
models/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import torch
+from timm.loss import SoftTargetCrossEntropy
+from timm.models.layers import drop
+
+from models.network.hymamba import Encoder
+
+
+# log more
+def _ex_repr(self):
+    return ', '.join(
+        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
+        for k, v in vars(self).items()
+        if not k.startswith('_') and k != 'training'
+        and not isinstance(v, (torch.nn.Module, torch.Tensor))
+    )
+for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
+    if hasattr(clz, 'extra_repr'):
+        clz.extra_repr = _ex_repr
+    else:
+        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
+
+
+pretrain_default_model_kwargs = {
+    'mambamim': dict(sparse=True, drop_path_rate=0.1),
+}
+for kw in pretrain_default_model_kwargs.values():
+    kw['pretrained'] = False
+    kw['num_classes'] = 0
+    kw['global_pool'] = ''
+
+
+def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
+    from models.encoder import SparseEncoder
+    kwargs = pretrain_default_model_kwargs[name]
+    if drop_path_rate != 0:
+        kwargs['drop_path_rate'] = drop_path_rate
+    print(f'[build_sparse_encoder] model kwargs={kwargs}')
+    encoder = Encoder(
+        in_channel=1,
+        channels=(32, 64, 128, 192, 384),
+        depths=(1, 2, 2, 2, 1),
+        kernels=(3, 3, 3, 3, 3),
+        exp_r=(2, 2, 4, 4, 4),
+        img_size=96,
+        depth=4,
+        sparse=True)
+    return SparseEncoder(encoder=encoder, input_size=input_size, sbn=sbn, verbose=verbose)
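The `_ex_repr` patch above swaps a richer `extra_repr` into timm/torch classes so that losses and `DropPath` print their hyper-parameters in logs. A dependency-free sketch of the same monkeypatching idea (the `Widget` class is hypothetical):

```python
class Widget:
    """Hypothetical module-like class; stands in for nn.CrossEntropyLoss etc."""

    def __init__(self):
        self.p = 0.1        # a float hyper-parameter
        self.name = 'w'     # a non-float hyper-parameter
        self._hidden = 1    # private attrs are skipped by the patch

    def extra_repr(self):
        return ''  # default: prints nothing useful

    def __repr__(self):
        return f'{type(self).__name__}({self.extra_repr()})'


def _ex_repr(self):
    # same filtering rule as the patch above: public attributes only,
    # floats rendered with %g formatting
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_')
    )


# monkeypatch: every Widget now reports its hyper-parameters in repr()
Widget.extra_repr = _ex_repr
```

After the patch, `repr(Widget())` shows the configured values instead of an empty parenthesis, which is exactly what the repo does for loss and drop-path modules.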
models/decoder.py ADDED
@@ -0,0 +1,87 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from timm.models.layers import trunc_normal_
+
+
+class UNetBlock(nn.Module):
+    def __init__(self, cin, cout, bn3d):
+        """
+        a UNet block with 2x up sampling
+        """
+        super().__init__()
+        self.up_sample = nn.ConvTranspose3d(cin, cin, kernel_size=2, stride=2, padding=0, bias=True)
+        self.conv = nn.Sequential(
+            nn.Conv3d(cin, cout, kernel_size=3, stride=1, padding=1, bias=True), bn3d(cout), nn.ReLU(inplace=True),
+            nn.Conv3d(cout, cout, kernel_size=3, stride=1, padding=1, bias=True), bn3d(cout), nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.up_sample(x)
+        return self.conv(x)
+
+
+class FusionBlock(nn.Module):
+    def __init__(self, cin, cout, bn3d):
+        """
+        a fusion block that merges concatenated skip features (no up sampling)
+        """
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv3d(cin, cout, kernel_size=3, stride=1, padding=1, bias=True), bn3d(cout), nn.ReLU(inplace=True),
+            nn.Conv3d(cout, cout, kernel_size=3, stride=1, padding=1, bias=True), bn3d(cout), nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class LightDecoder(nn.Module):
+    def __init__(self, up_sample_ratio, width=768, sbn=True):
+        # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
+        super().__init__()
+        self.width = width
+        n = round(math.log2(up_sample_ratio))
+        channels = [self.width // 2 ** i for i in range(n + 1)]
+        bn3d = nn.BatchNorm3d
+        self.dec = nn.ModuleList([UNetBlock(cin, cout, bn3d) for (cin, cout) in zip(channels[:-1], channels[1:])])
+        self.fuse = nn.ModuleList([FusionBlock(cin * 2, cin, bn3d) for (cin, cout) in zip(channels[:-1], channels[1:])])
+        self.proj = nn.Conv3d(channels[-1], 1, kernel_size=1, stride=1, bias=True)
+
+        self.initialize()
+
+    def forward(self, to_dec: List[torch.Tensor]):
+        x = 0
+        for i, d in enumerate(self.dec):
+            if i < len(to_dec) and to_dec[i] is not None:
+                if isinstance(x, int):
+                    x = x + to_dec[i]
+                else:
+                    x = torch.cat((x, to_dec[i]), dim=1)
+                    x = self.fuse[i](x)
+            x = self.dec[i](x)
+        return self.proj(x)
+
+    def extra_repr(self) -> str:
+        return f'width={self.width}'
+
+    def initialize(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Conv3d):
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.ConvTranspose3d):  # nn.Conv3d is already handled by the branch above
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0.)
+            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
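`LightDecoder` derives its per-stage widths from a simple halving rule: one `UNetBlock` per 2x upsample, each halving the channel count. The rule in isolation (function name is illustrative):

```python
import math


def decoder_channels(width, up_sample_ratio):
    """Channel widths under LightDecoder's halving rule.

    One decoder stage per 2x of upsampling; each stage halves the width,
    so a ratio of 16 (= 2**4) yields 4 stages and 5 widths.
    """
    n = round(math.log2(up_sample_ratio))
    return [width // 2 ** i for i in range(n + 1)]
```

For the default `width=768` and a 16x upsampling ratio this gives `[768, 384, 192, 96, 48]`; consecutive pairs feed the `UNetBlock`/`FusionBlock` lists, and the final width (48 here) is projected to 1 channel by `proj`.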
models/encoder.py ADDED
@@ -0,0 +1,258 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+_cur_active: torch.Tensor = None  # B1fff
+
+
+# todo: try to use `gather` for speed?
+def _get_active_ex_or_ii(H, W, D, returning_active_ex=True):
+    h_repeat, w_repeat, d_repeat = H // _cur_active.shape[-3], W // _cur_active.shape[-2], D // _cur_active.shape[-1]
+    active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3).repeat_interleave(d_repeat, dim=4)
+    return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi, di
+
+
+def sp_conv_forward(self, x: torch.Tensor):
+    x = super(type(self), self).forward(x)
+    x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True)  # (BCHWD) *= (B1HWD), mask the output of conv
+    return x
+
+
+def sp_bn_forward(self, x: torch.Tensor):
+    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)
+
+    bhwdc = x.permute(0, 2, 3, 4, 1)
+    nc = bhwdc[ii]  # select the features on non-masked positions to form a flattened feature `nc`
+    nc = super(type(self), self).forward(nc)  # use BN1d to normalize this flattened feature `nc`
+
+    bchwd = torch.zeros_like(bhwdc)
+    bchwd[ii] = nc
+    bchwd = bchwd.permute(0, 4, 1, 2, 3)
+    return bchwd
+
+
+def sp_in_forward(self, x: torch.Tensor):
+    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)
+    bhwdc = x.permute(0, 2, 3, 4, 1)
+    cn = bhwdc[ii].permute(1, 0)  # select the features on non-masked positions to form a flattened feature, e.g. [17787, 3]
+    C, N = cn.shape
+    bcl = cn.reshape(C, -1, x.shape[0]).permute(2, 0, 1)
+    bcl = super(type(self), self).forward(bcl)  # use IN1d to normalize this flattened feature
+    nc = bcl.permute(1, 2, 0).reshape(C, -1).permute(1, 0)
+    bchwd = torch.zeros_like(bhwdc)
+    bchwd[ii] = nc
+    bchwd = bchwd.permute(0, 4, 1, 2, 3)
+    return bchwd
+
+
+class SparseConv3d(nn.Conv3d):
+    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
+
+
+class SparseMaxPooling(nn.MaxPool3d):
+    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
+
+
+class SparseAvgPooling(nn.AvgPool3d):
+    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
+
+
+class SparseBatchNorm3d(nn.BatchNorm1d):
+    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
+
+
+class SparseSyncBatchNorm3d(nn.SyncBatchNorm):
+    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
+
+
+class SparseInstanceNorm3d(nn.InstanceNorm1d):
+    forward = sp_in_forward  # hack: override the forward function; see `sp_in_forward` above for more details
+
+
+class SparseConvNeXtLayerNorm(nn.LayerNorm):
+    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, depth, channels) while channels_first corresponds to
+    inputs with shape (batch_size, channels, height, width, depth).
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
+        if data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        super().__init__(normalized_shape, eps, elementwise_affine=True)
+        self.data_format = data_format
+        self.sparse = sparse
+
+    def forward(self, x):
+        if x.ndim == 5:  # BHWDC or BCHWD
+            if self.data_format == "channels_last":  # BHWDC
+                if self.sparse:
+                    ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], D=x.shape[3], returning_active_ex=False)
+                    nc = x[ii]
+                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
+
+                    x = torch.zeros_like(x)
+                    x[ii] = nc
+                    return x
+                else:
+                    return super(SparseConvNeXtLayerNorm, self).forward(x)
+            else:  # channels_first, BCHWD
+                if self.sparse:
+                    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=False)
+                    bhwc = x.permute(0, 2, 3, 4, 1)
+                    nc = bhwc[ii]
+                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
+
+                    x = torch.zeros_like(bhwc)
+                    x[ii] = nc
+                    return x.permute(0, 4, 1, 2, 3)
+                else:
+                    u = x.mean(1, keepdim=True)
+                    s = (x - u).pow(2).mean(1, keepdim=True)
+                    x = (x - u) / torch.sqrt(s + self.eps)
+                    x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
+                    return x
+        else:  # BLC or BC
+            if self.sparse:
+                raise NotImplementedError
+            else:
+                return super(SparseConvNeXtLayerNorm, self).forward(x)
+
+    def __repr__(self):
+        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
+
+
+class SparseConvNeXtBlock(nn.Module):
+    r""" ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W, D)
+    (2) DwConv -> Permute to (N, H, W, D, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        in_channels (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=7, exp_r=4, do_res=False, drop_path=0.,
+                 layer_scale_init_value=1e-6, sparse=True):
+        super().__init__()
+
+        self.do_res = do_res
+        self.dwconv = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2,
+                                groups=in_channels)  # depthwise conv
+        self.norm = SparseConvNeXtLayerNorm(in_channels, eps=1e-6, sparse=sparse)
+        self.pwconv1 = nn.Linear(in_channels, exp_r * in_channels)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(exp_r * in_channels, out_channels)
+        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channels)),
+                                  requires_grad=True) if layer_scale_init_value > 0 else None
+        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.sparse = sparse
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)  # GELU(0) == 0, so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 4, 1, 2, 3)  # (N, H, W, D, C) -> (N, C, H, W, D)
+
+        if self.sparse:
+            x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True)
+        if self.do_res:
+            x = input + self.drop_path(x)
+        return x
+
+    def __repr__(self):
+        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
+
+
+class SparseEncoder(nn.Module):
+    def __init__(self, encoder, input_size, sbn=False, verbose=False):
+        super(SparseEncoder, self).__init__()
+        self.embeddings = SparseEncoder.dense_model_to_sparse(m=encoder.embeddings, verbose=verbose, sbn=sbn)
+        self.mae = encoder.mae
+
+        # self.encoder = SparseEncoder.dense_model_to_sparse(m=encoder, verbose=verbose, sbn=sbn)
+        self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, encoder.get_downsample_ratio(), encoder.get_feature_map_channels()
+
+    @staticmethod
+    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
+        oup = m
+        if isinstance(m, nn.Conv3d):
+            m: nn.Conv3d
+            bias = m.bias is not None
+            oup = SparseConv3d(
+                m.in_channels, m.out_channels,
+                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
+                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
+            )
+            oup.weight.data.copy_(m.weight.data)
+            if bias:
+                oup.bias.data.copy_(m.bias.data)
+        elif isinstance(m, nn.MaxPool3d):
+            m: nn.MaxPool3d
+            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation,
+                                   return_indices=m.return_indices, ceil_mode=m.ceil_mode)
+        elif isinstance(m, nn.AvgPool3d):
+            m: nn.AvgPool3d
+            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode,
+                                   count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
+        elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm)):
+            m: nn.BatchNorm3d
+            oup = (SparseSyncBatchNorm3d if sbn else SparseBatchNorm3d)(m.weight.shape[0], eps=m.eps,
+                                                                       momentum=m.momentum, affine=m.affine,
+                                                                       track_running_stats=m.track_running_stats)
+            oup.weight.data.copy_(m.weight.data)
+            oup.bias.data.copy_(m.bias.data)
+            oup.running_mean.data.copy_(m.running_mean.data)
+            oup.running_var.data.copy_(m.running_var.data)
+            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
+            if hasattr(m, "qconfig"):
+                oup.qconfig = m.qconfig
+        elif isinstance(m, nn.InstanceNorm3d):
+            m: nn.InstanceNorm3d
+            oup = SparseInstanceNorm3d(m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
+                                       track_running_stats=m.track_running_stats)
+            if hasattr(m, "qconfig"):
+                oup.qconfig = m.qconfig
+        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
+            m: nn.LayerNorm
+            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
+            oup.weight.data.copy_(m.weight.data)
+            oup.bias.data.copy_(m.bias.data)
+        elif isinstance(m, (nn.Conv1d,)):
+            m: nn.Conv1d
+            bias = m.bias is not None
+            oup = nn.Conv1d(
+                m.in_channels, m.out_channels,
+                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
+                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode)
+            oup.weight.data.copy_(m.weight.data)
+            if bias:
+                oup.bias.data.copy_(m.bias.data)
+        for name, child in m.named_children():
+            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
+        del m
+        return oup
+
+    def forward(self, x, active_b1fff):
+        x1, x2, x3, x4, x5 = self.embeddings(x)
+        _x5 = self.mae(x5, active_b1fff)
+        return [x1, x2, x3, x4, _x5]
+
+
+if __name__ == '__main__':
+    x = torch.randn([1, 96, 24, 24, 24])
+    _cur_active = torch.randn([1, 1, 96 // 16, 96 // 16, 96 // 16])
+    print(x.shape)
+    print(_get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], D=x.shape[4], returning_active_ex=True).shape)
+    print(x.shape)
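The sparse machinery above hinges on `_get_active_ex_or_ii`: a low-resolution `B1fff` patch mask is nearest-neighbour-expanded with `repeat_interleave` to each feature map's resolution, and every conv/pool output is multiplied by it. A 1-D, pure-Python analogue of that expand-and-mask step (function names here are illustrative, not repo API):

```python
def upsample_mask_1d(mask, length):
    """1-D analogue of _get_active_ex_or_ii: nearest-neighbour-expand a patch mask
    to feature-map resolution (torch does this with repeat_interleave)."""
    repeat = length // len(mask)  # assumes length is a multiple of the mask size
    return [m for m in mask for _ in range(repeat)]


def sparse_mask(features, mask):
    """Zero out features at masked positions, as sp_conv_forward does after each conv."""
    active = upsample_mask_1d(mask, len(features))
    return [f * a for f, a in zip(features, active)]
```

Because the expansion is nearest-neighbour, one masked patch zeros a whole contiguous block of the higher-resolution feature map, which is what keeps the masking consistent across encoder stages with different downsampling ratios.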
models/mamba/bi_vision_mamba.py ADDED
@@ -0,0 +1,416 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ bimamba_type="none"
51
+ ):
52
+ factory_kwargs = {"device": device, "dtype": dtype}
53
+ super().__init__()
54
+ self.d_model = d_model
55
+ self.d_state = d_state
56
+ self.d_conv = d_conv
57
+ self.expand = expand
58
+ self.d_inner = int(self.expand * self.d_model)
59
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
60
+ self.use_fast_path = use_fast_path
61
+ self.layer_idx = layer_idx
62
+ self.bimamba_type = bimamba_type
63
+
64
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
83
+
84
+ # Initialize special dt projection to preserve variance at initialization
85
+ dt_init_std = self.dt_rank ** -0.5 * dt_scale
86
+ if dt_init == "constant":
87
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
88
+ elif dt_init == "random":
89
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
90
+ else:
91
+ raise NotImplementedError
92
+
93
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
94
+ dt = torch.exp(
95
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
96
+ + math.log(dt_min)
97
+ ).clamp(min=dt_init_floor)
98
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
99
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
100
+ with torch.no_grad():
101
+ self.dt_proj.bias.copy_(inv_dt)
102
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
103
+ self.dt_proj.bias._no_reinit = True
104
+
105
+ # S4D real initialization
106
+ A = repeat(
107
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
108
+ "n -> d n",
109
+ d=self.d_inner,
110
+ ).contiguous()
111
+ A_log = torch.log(A) # Keep A_log in fp32
112
+ self.A_log = nn.Parameter(A_log)
113
+ self.A_log._no_weight_decay = True
114
+
115
+ # D "skip" parameter
116
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
117
+ self.D._no_weight_decay = True
118
+
119
+ # assert bimamba_type == "v2"
120
+
121
+ A_b = repeat(
122
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
123
+ "n -> d n",
124
+ d=self.d_inner,
125
+ ).contiguous()
126
+ A_b_log = torch.log(A_b) # Keep A_b_log in fp32
127
+ self.A_b_log = nn.Parameter(A_b_log)
128
+ self.A_b_log._no_weight_decay = True
129
+
130
+ self.conv1d_b = nn.Conv1d(
131
+ in_channels=self.d_inner,
132
+ out_channels=self.d_inner,
133
+ bias=conv_bias,
134
+ kernel_size=d_conv,
135
+ groups=self.d_inner,
136
+ padding=d_conv - 1,
137
+ **factory_kwargs,
138
+ )
139
+
140
+ self.x_proj_b = nn.Linear(
141
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
142
+ )
143
+ self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
144
+
145
+ self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
146
+ self.D_b._no_weight_decay = True
147
+
148
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
149
+
150
+ def forward(self, hidden_states, inference_params=None):
151
+ """
152
+ hidden_states: (B, L, D)
153
+ Returns: same shape as hidden_states
154
+ """
155
+ batch, seqlen, dim = hidden_states.shape
156
+
157
+ conv_state, ssm_state = None, None
158
+ if inference_params is not None:
159
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
160
+ if inference_params.seqlen_offset > 0:
161
+ # The states are updated inplace
162
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
163
+ return out
164
+
165
+ # We do matmul and transpose BLH -> HBL at the same time
166
+ xz = rearrange(
167
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
168
+ "d (b l) -> b d l",
169
+ l=seqlen,
170
+ )
171
+ if self.in_proj.bias is not None:
172
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
173
+
174
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
175
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
176
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
177
+ if self.bimamba_type == "v2":
178
+ A_b = -torch.exp(self.A_b_log.float())
179
+ out = mamba_inner_fn_no_out_proj(
180
+ xz,
181
+ self.conv1d.weight,
182
+ self.conv1d.bias,
183
+ self.x_proj.weight,
184
+ self.dt_proj.weight,
185
+ A,
186
+ None, # input-dependent B
187
+ None, # input-dependent C
188
+ self.D.float(),
189
+ delta_bias=self.dt_proj.bias.float(),
190
+ delta_softplus=True,
191
+ )
192
+ out_b = mamba_inner_fn_no_out_proj(
193
+ xz.flip([-1]),
194
+ self.conv1d_b.weight,
195
+ self.conv1d_b.bias,
196
+ self.x_proj_b.weight,
197
+ self.dt_proj_b.weight,
198
+ A_b,
199
+ None,
200
+ None,
201
+ self.D_b.float(),
202
+ delta_bias=self.dt_proj_b.bias.float(),
203
+ delta_softplus=True,
204
+ )
205
+ # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
206
+ out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight,
207
+ self.out_proj.bias)
208
+ else:
209
+ out = mamba_inner_fn(
210
+ xz,
211
+ self.conv1d.weight,
212
+ self.conv1d.bias,
213
+ self.x_proj.weight,
214
+ self.dt_proj.weight,
215
+ self.out_proj.weight,
216
+ self.out_proj.bias,
217
+ A,
218
+ None, # input-dependent B
219
+ None, # input-dependent C
220
+ self.D.float(),
221
+ delta_bias=self.dt_proj.bias.float(),
222
+ delta_softplus=True,
223
+ )
224
+ else:
225
+ x, z = xz.chunk(2, dim=1)
226
+ # Compute short convolution
227
+ if conv_state is not None:
228
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
229
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
230
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
231
+ if causal_conv1d_fn is None:
232
+ x = self.act(self.conv1d(x)[..., :seqlen])
233
+ else:
234
+ assert self.activation in ["silu", "swish"]
235
+ x = causal_conv1d_fn(
236
+ x=x,
237
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
238
+ bias=self.conv1d.bias,
239
+ activation=self.activation,
240
+ )
241
+
242
+ # We're careful here about the layout, to avoid extra transposes.
243
+ # We want dt to have d as the slowest moving dimension
244
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
245
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
246
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
247
+ dt = self.dt_proj.weight @ dt.t()
248
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
249
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
250
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
251
+ assert self.activation in ["silu", "swish"]
252
+ y = selective_scan_fn(
253
+ x,
254
+ dt,
255
+ A,
256
+ B,
257
+ C,
258
+ self.D.float(),
259
+ z=z,
260
+ delta_bias=self.dt_proj.bias.float(),
261
+ delta_softplus=True,
262
+ return_last_state=ssm_state is not None,
263
+ )
264
+ if ssm_state is not None:
265
+ y, last_state = y
266
+ ssm_state.copy_(last_state)
267
+ y = rearrange(y, "b d l -> b l d")
268
+ out = self.out_proj(y)
269
+ return out
270
+
271
+ def step(self, hidden_states, conv_state, ssm_state):
272
+ dtype = hidden_states.dtype
273
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
274
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
275
+ x, z = xz.chunk(2, dim=-1) # (B D)
276
+
277
+ # Conv step
278
+ if causal_conv1d_update is None:
279
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
280
+ conv_state[:, :, -1] = x
281
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
282
+ if self.conv1d.bias is not None:
283
+ x = x + self.conv1d.bias
284
+ x = self.act(x).to(dtype=dtype)
285
+ else:
286
+ x = causal_conv1d_update(
287
+ x,
288
+ conv_state,
289
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
290
+ self.conv1d.bias,
291
+ self.activation,
292
+ )
293
+
294
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
295
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
296
+ # Don't add dt_bias here
297
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
298
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
299
+
300
+ # SSM step
301
+ if selective_state_update is None:
302
+ # Discretize A and B
303
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
304
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
305
+ dB = torch.einsum("bd,bn->bdn", dt, B)
306
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
307
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
308
+ y = y + self.D.to(dtype) * x
309
+ y = y * self.act(z) # (B D)
310
+ else:
311
+ y = selective_state_update(
312
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
313
+ )
314
+
315
+ out = self.out_proj(y)
316
+ return out.unsqueeze(1), conv_state, ssm_state
317
+
318
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
319
+ device = self.out_proj.weight.device
320
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
321
+ conv_state = torch.zeros(
322
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
323
+ )
324
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
325
+ # ssm_dtype = torch.float32
326
+ ssm_state = torch.zeros(
327
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
328
+ )
329
+ return conv_state, ssm_state
330
+
331
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
332
+ assert self.layer_idx is not None
333
+ if self.layer_idx not in inference_params.key_value_memory_dict:
334
+ batch_shape = (batch_size,)
335
+ conv_state = torch.zeros(
336
+ batch_size,
337
+ self.d_model * self.expand,
338
+ self.d_conv,
339
+ device=self.conv1d.weight.device,
340
+ dtype=self.conv1d.weight.dtype,
341
+ )
342
+ ssm_state = torch.zeros(
343
+ batch_size,
344
+ self.d_model * self.expand,
345
+ self.d_state,
346
+ device=self.dt_proj.weight.device,
347
+ dtype=self.dt_proj.weight.dtype,
348
+ # dtype=torch.float32,
349
+ )
350
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
351
+ else:
352
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
353
+ # TODO: What if batch size changes between generation, and we reuse the same states?
354
+ if initialize_states:
355
+ conv_state.zero_()
356
+ ssm_state.zero_()
357
+ return conv_state, ssm_state
358
+
359
+
360
+ class Block(nn.Module):
361
+ def __init__(
362
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
363
+ ):
364
+ """
365
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
366
+
367
+        This Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA/MLP -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Add -> LN -> Mixer, returning both
+        the hidden_states (output of the mixer) and the residual.
+        This is purely for performance reasons, as we can fuse add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+        """
+        super().__init__()
+        self.residual_in_fp32 = residual_in_fp32
+        self.fused_add_norm = fused_add_norm
+        self.mixer = mixer_cls(dim)
+        self.norm = norm_cls(dim)
+        if self.fused_add_norm:
+            assert RMSNorm is not None, "RMSNorm import fails"
+            assert isinstance(
+                self.norm, (nn.LayerNorm, RMSNorm)
+            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+    def forward(
+        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+    ):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states: the sequence to the encoder layer (required).
+            residual: hidden_states = Mixer(LN(residual))
+        """
+        if not self.fused_add_norm:
+            residual = (hidden_states + residual) if residual is not None else hidden_states
+            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+        else:
+            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+            hidden_states, residual = fused_add_norm_fn(
+                hidden_states,
+                self.norm.weight,
+                self.norm.bias,
+                residual=residual,
+                prenorm=True,
+                residual_in_fp32=self.residual_in_fp32,
+                eps=self.norm.eps,
+            )
+        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+        return hidden_states, residual
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
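The Add -> LN -> Mixer ordering described in the docstring can be sketched in plain Python. This is a toy illustration only (lists stand in for tensors, and any callable stands in for the mixer — both are assumptions of mine, not part of this repo):

```python
# Toy sketch of the "Add -> LN -> Mixer" block ordering used above
# (assumption: plain Python lists stand in for tensors).
def layer_norm(v, eps=1e-5):
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / (var + eps) ** 0.5 for x in v]

def block_step(hidden_states, residual, mixer):
    # Add: fold the previous block's output into the residual stream first...
    residual = hidden_states if residual is None else [h + r for h, r in zip(hidden_states, residual)]
    # ...then LN -> Mixer: the mixer only ever sees the normalized residual.
    hidden_states = mixer(layer_norm(residual))
    return hidden_states, residual
```

Keeping the un-normalized `residual` alongside `hidden_states` is what allows the add and the LayerNorm to be fused into a single kernel in the real implementation.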
models/network/hymamba.py ADDED
@@ -0,0 +1,681 @@
+ import torch
+ import torch.nn as nn
+ from models.encoder import SparseConvNeXtLayerNorm, _get_active_ex_or_ii
+ from typing import Optional, Sequence, Tuple, Union, List
+ import numpy as np
+ from models.mamba.bi_vision_mamba import Mamba
+ from monai.networks.blocks.unetr_block import UnetrUpBlock
+
+
+ def build_3d_sincos_position_embedding(grid_size, embed_dim, num_tokens=0, temperature=10000.):
+     grid_size = (grid_size, grid_size, grid_size)
+     h, w, d = grid_size
+     grid_h = torch.arange(h, dtype=torch.float32)
+     grid_w = torch.arange(w, dtype=torch.float32)
+     grid_d = torch.arange(d, dtype=torch.float32)
+
+     grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)
+     assert embed_dim % 6 == 0, 'Embed dimension must be divisible by 6 for 3D sin-cos position embedding'
+     pos_dim = embed_dim // 6
+     omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+     omega = 1. / (temperature ** omega)
+     out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+     out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+     out_d = torch.einsum('m,d->md', [grid_d.flatten(), omega])
+     pos_emb = torch.cat(
+         [torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w), torch.sin(out_d), torch.cos(out_d)],
+         dim=1)[None, :, :]
+
+     assert num_tokens == 1 or num_tokens == 0, "Number of tokens must be 0 or 1"
+     if num_tokens == 1:
+         pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
+         pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+     else:
+         pos_embed = nn.Parameter(pos_emb)
+     pos_embed.requires_grad = False
+     return pos_embed
+
+
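A quick sanity check of the divisibility requirement in `build_3d_sincos_position_embedding`: each of the three axes contributes a sin block and a cos block of width `embed_dim // 6`, so the six concatenated blocks recover `embed_dim` exactly. A dependency-free sketch of the bookkeeping (the helper names are my own, not from this repo):

```python
# Width bookkeeping for the 3D sin-cos embedding: 3 axes x (sin + cos),
# each of width pos_dim = embed_dim // 6.
def sincos_widths(embed_dim):
    assert embed_dim % 6 == 0, "embed_dim must be divisible by 6"
    pos_dim = embed_dim // 6
    return [pos_dim] * 6  # sin_h, cos_h, sin_w, cos_w, sin_d, cos_d

def frequency_schedule(pos_dim, temperature=10000.0):
    # same geometric frequency ladder as `omega` in the function above
    return [1.0 / (temperature ** (i / pos_dim)) for i in range(pos_dim)]
```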
+ class MlpChannel(nn.Module):
+     def __init__(self, hidden_size, mlp_dim):
+         super().__init__()
+         self.fc1 = nn.Linear(hidden_size, mlp_dim)
+         self.act = nn.GELU()
+         self.fc2 = nn.Linear(mlp_dim, hidden_size)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         return x
+
+
+ class MambaLayer(nn.Module):
+     def __init__(self, dim, d_state=16, d_conv=4, expand=2):
+         super().__init__()
+         self.dim = dim
+         self.norm1 = nn.LayerNorm(dim)
+         self.mamba = Mamba(
+             d_model=dim,      # Model dimension d_model
+             d_state=d_state,  # SSM state expansion factor
+             d_conv=d_conv,    # Local convolution width
+             expand=expand,    # Block expansion factor
+             bimamba_type="v1",
+         )
+         self.mlp = MlpChannel(hidden_size=dim, mlp_dim=2 * dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+     def forward(self, x):
+         x = self.mamba(self.norm1(x)) + x
+         x = self.mlp(self.norm2(x)) + x
+         return x
+
+
+ class MaskedAutoencoderMamba(nn.Module):
+     """Masked autoencoder with a Mamba backbone."""
+
+     def __init__(self, img_size=96, downsample_rato=16, embed_dim=384, depth=8, norm_layer=nn.LayerNorm, sparse=True):
+         super().__init__()
+         print("mamba sparse: ", sparse)
+         # --------------------------------------------------------------------------
+         # MAE encoder specifics
+         self.grid_size = img_size // downsample_rato
+         self.num_patches = self.grid_size ** 3
+         self.embed_dim = embed_dim
+         self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim),
+                                       requires_grad=False)  # fixed sin-cos embedding
+
+         self.blocks = nn.ModuleList([
+             MambaLayer(dim=embed_dim)
+             for i in range(depth)])
+         # self.gsc = GSC(in_channels=embed_dim, sparse=sparse)
+
+         self.sparse = sparse
+         if self.sparse:
+             self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         # --------------------------------------------------------------------------
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # initialize (and freeze) pos_embed with a sin-cos embedding
+         pos_embed = build_3d_sincos_position_embedding(self.grid_size, self.embed_dim)
+         self.pos_embed.data.copy_(pos_embed)
+         if self.sparse:
+             torch.nn.init.normal_(self.mask_token, std=.02)
+         # initialize nn.Linear and nn.LayerNorm
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             # we use xavier_uniform following the official JAX ViT:
+             torch.nn.init.xavier_uniform_(m.weight)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def random_masking(self, enc, active_b1fff):
+         """
+         Perform per-sample random masking by per-sample shuffling.
+         Per-sample shuffling is done by argsort of the noise.
+         enc: [N, L, D], sequence
+         """
+         N, L, D = enc.shape  # batch, length, dim
+         mask = torch.tensor(active_b1fff, dtype=torch.int).flatten(2).transpose(1, 2)
+         # sort noise for each sample
+         noise = 1 - mask
+         len_keep = torch.sum(mask)
+         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         # keep the first subset
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(enc, dim=1, index=ids_keep.repeat(1, 1, D))
+
+         # the binary mask: 1 is keep (active), 0 is remove
+         return x_masked, mask, ids_restore
+
+     def unmasking(self, x, ids_restore):
+         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
+         x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
+         x = torch.gather(x_, dim=1, index=ids_restore.repeat(1, 1, x.shape[2]))  # unshuffle
+         return x
+
+     def forward_encoder(self, enc, active_b1fff=None):
+         # enc = self.gsc(enc)
+         B, C, H, W, D = enc.shape
+         x = enc.flatten(2).transpose(1, 2)
+         # add pos embed w/o cls token
+         x = x + self.pos_embed
+         if self.sparse:
+             # masking: length -> length * mask_ratio
+             x, mask, ids_restore = self.random_masking(x, active_b1fff)
+             # apply Mamba blocks
+             for blk in self.blocks:
+                 x = blk(x)
+             x = self.unmasking(x, ids_restore)
+         else:
+             for blk in self.blocks:
+                 x = blk(x)
+         x = x.transpose(1, 2).reshape(B, C, H, W, D)
+         return x
+
+     def forward(self, imgs, active_b1fff=None):
+         return self.forward_encoder(imgs, active_b1fff)
+
+
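The argsort trick used by `random_masking` / `unmasking` above can be illustrated without tensors: sorting `noise = 1 - mask` moves the kept (active) tokens to the front, and the argsort of that permutation gives the indices that undo the shuffle. A toy, pure-Python round trip (the mask values and the `'<M>'` placeholder are made up for illustration):

```python
def argsort(seq):
    return sorted(range(len(seq)), key=seq.__getitem__)

mask = [1, 0, 1, 1, 0]            # 1 = keep (active voxel), 0 = masked out
tokens = ['a', 'b', 'c', 'd', 'e']
noise = [1 - m for m in mask]
ids_shuffle = argsort(noise)      # kept positions come first (stable sort)
ids_restore = argsort(ids_shuffle)
len_keep = sum(mask)

kept = [tokens[i] for i in ids_shuffle[:len_keep]]   # the encoder only sees these
filled = kept + ['<M>'] * (len(tokens) - len_keep)   # append mask tokens
restored = [filled[i] for i in ids_restore]          # unshuffle
```

`restored` puts every kept token back at its original index with mask tokens filling the holes, which is what `unmasking` does with `torch.gather`.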
+ class MedNeXtBlock(nn.Module):
+     def __init__(self,
+                  in_channels: int,
+                  out_channels: int,
+                  exp_r: int = 4,
+                  kernel_size: int = 7,
+                  do_res: bool = True,
+                  n_groups: Optional[int] = None,
+                  sparse=False):
+
+         super().__init__()
+
+         self.do_res = do_res
+         self.sparse = sparse
+         conv = nn.Conv3d
+
+         # First convolution layer with depth-wise convolutions
+         self.conv1 = conv(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             kernel_size=kernel_size,
+             stride=1,
+             padding=kernel_size // 2,
+             groups=in_channels if n_groups is None else n_groups,
+         )
+
+         # Normalization layer (sparse-aware ConvNeXt LayerNorm)
+         self.norm = SparseConvNeXtLayerNorm(normalized_shape=in_channels, data_format='channels_first', sparse=sparse)
+
+         # Second convolution (expansion) layer with Conv3D 1x1x1
+         self.conv2 = conv(
+             in_channels=in_channels,
+             out_channels=exp_r * in_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0
+         )
+
+         # GELU activation
+         self.act = nn.GELU()
+
+         # Third convolution (compression) layer with Conv3D 1x1x1
+         self.conv3 = conv(
+             in_channels=exp_r * in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0
+         )
+
+     def forward(self, x, dummy_tensor=None):
+         x1 = x
+         x1 = self.conv1(x1)
+         x1 = self.act(self.conv2(self.norm(x1)))
+         x1 = self.conv3(x1)
+         if self.sparse:
+             x1 *= _get_active_ex_or_ii(H=x1.shape[2], W=x1.shape[3], D=x1.shape[4], returning_active_ex=True)
+         if self.do_res:
+             x1 = x + x1
+         return x1
+
+
+ class MedNeXtDownBlock(MedNeXtBlock):
+
+     def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7,
+                  do_res=False, sparse=False):
+
+         super().__init__(in_channels, out_channels, exp_r, kernel_size,
+                          do_res=False, sparse=sparse)
+
+         self.resample_do_res = do_res
+         if do_res:
+             self.res_conv = nn.Conv3d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=2
+             )
+
+         self.conv1 = nn.Conv3d(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             kernel_size=kernel_size,
+             stride=2,
+             padding=kernel_size // 2,
+             groups=in_channels,
+         )
+
+     def forward(self, x, dummy_tensor=None):
+         x1 = super().forward(x)
+         if self.resample_do_res:
+             res = self.res_conv(x)
+             x1 = x1 + res
+
+         return x1
+
+
+ class UnetResBlock(nn.Module):
+     """
+     A skip-connection based module that can be used for DynUNet, based on:
+     `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
+     `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.
+
+     Args:
+         sparse: whether to use the sparse-aware normalization layers.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         kernel_size: convolution kernel size.
+         stride: convolution stride.
+     """
+
+     def __init__(
+         self,
+         sparse: bool,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: Union[Sequence[int], int],
+         stride: Union[Sequence[int], int],
+     ):
+         super().__init__()
+         self.conv1 = nn.Conv3d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=kernel_size // 2)
+         self.conv2 = nn.Conv3d(
+             out_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=1,
+             padding=kernel_size // 2,
+         )
+         self.lrelu = nn.LeakyReLU(inplace=True, negative_slope=0.01)
+         self.norm1 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
+         self.norm2 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
+         self.downsample = in_channels != out_channels
+         stride_np = np.atleast_1d(stride)
+         if not np.all(stride_np == 1):
+             self.downsample = True
+         if self.downsample:
+             self.conv3 = nn.Conv3d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 stride=stride)
+             self.norm3 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
+
+     def forward(self, inp):
+         residual = inp
+         out = self.conv1(inp)
+         out = self.norm1(out)
+         out = self.lrelu(out)
+         out = self.conv2(out)
+         out = self.norm2(out)
+         if hasattr(self, "conv3"):
+             residual = self.conv3(residual)
+         if hasattr(self, "norm3"):
+             residual = self.norm3(residual)
+         out += residual
+         out = self.lrelu(out)
+         return out
+
+
+ class MedNeXtUpBlock(MedNeXtBlock):
+
+     def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=3,
+                  do_res=True, sparse=False):
+         super().__init__(in_channels, out_channels, exp_r, kernel_size,
+                          do_res=False, sparse=sparse)
+
+         self.resample_do_res = do_res
+
+         conv = nn.ConvTranspose3d
+         if do_res:
+             self.res_conv = conv(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=2
+             )
+
+         self.conv1 = conv(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             kernel_size=kernel_size,
+             stride=2,
+             padding=kernel_size // 2,
+             groups=in_channels,
+         )
+
+     def forward(self, x, dummy_tensor=None):
+         x1 = super().forward(x)
+         # Asymmetric padding, but necessary to match shapes
+         x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))
+
+         if self.resample_do_res:
+             res = self.res_conv(x)
+             res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))
+             x1 = x1 + res
+         return x1
+
+
+ class UnetOutBlock(nn.Module):
+     def __init__(self, in_channels: int, n_classes: int):
+         super().__init__()
+         self.conv = nn.Conv3d(
+             in_channels,
+             n_classes,
+             kernel_size=1,
+             stride=1,
+             bias=True,
+         )
+
+     def forward(self, inp):
+         return self.conv(inp)
+
+
+ class Embeddings(nn.Module):
+     def __init__(self,
+                  in_channel: int = 3,
+                  channels: Tuple = (32, 64, 96, 128, 192),
+                  depths: Tuple = (1, 1, 3, 1, 1),
+                  kernels: Tuple = (3, 3, 3, 3, 3),
+                  exp_r: Tuple = (2, 4, 4, 4, 2),
+                  sparse=True):
+         super(Embeddings, self).__init__()
+         self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]
+         self.stem = nn.Conv3d(in_channels=in_channel, out_channels=channels[0], kernel_size=3, stride=1, padding=1)
+
+         self.layer2 = nn.Sequential(*[
+             MedNeXtBlock(
+                 in_channels=channels[1],
+                 out_channels=channels[1],
+                 exp_r=exp_r[1],
+                 kernel_size=kernels[1],
+                 do_res=True,
+                 sparse=sparse
+             )
+             for i in range(depths[1])])
+
+         self.layer3 = nn.Sequential(*[
+             MedNeXtBlock(
+                 in_channels=channels[2],
+                 out_channels=channels[2],
+                 exp_r=exp_r[2],
+                 kernel_size=kernels[2],
+                 do_res=True,
+                 sparse=sparse
+             )
+             for i in range(depths[2])])
+
+         self.layer4 = nn.Sequential(*[
+             MedNeXtBlock(
+                 in_channels=channels[3],
+                 out_channels=channels[3],
+                 exp_r=exp_r[3],
+                 kernel_size=kernels[3],
+                 do_res=True,
+                 sparse=sparse
+             )
+             for i in range(depths[3])])
+
+         self.layer5 = nn.Sequential(*[
+             MedNeXtBlock(
+                 in_channels=channels[4],
+                 out_channels=channels[4],
+                 exp_r=exp_r[4],
+                 kernel_size=kernels[4],
+                 do_res=True,
+                 sparse=sparse
+             )
+             for i in range(depths[4])])
+
+         self.down = nn.MaxPool3d((2, 2, 2))
+         self.expend1 = nn.Conv3d(in_channels=channels[0], out_channels=channels[1], kernel_size=3, stride=1, padding=1)
+         self.expend2 = nn.Conv3d(in_channels=channels[1], out_channels=channels[2], kernel_size=3, stride=1, padding=1)
+         self.expend3 = nn.Conv3d(in_channels=channels[2], out_channels=channels[3], kernel_size=3, stride=1, padding=1)
+         self.expend4 = nn.Conv3d(in_channels=channels[3], out_channels=channels[4], kernel_size=3, stride=1, padding=1)
+
+         self.encoder1 = UnetResBlock(
+             in_channels=channels[1],
+             out_channels=channels[1],
+             kernel_size=3,
+             stride=1,
+             sparse=sparse
+         )
+         self.encoder2 = UnetResBlock(
+             in_channels=channels[2],
+             out_channels=channels[2],
+             kernel_size=3,
+             stride=1,
+             sparse=sparse
+         )
+         self.encoder3 = UnetResBlock(
+             in_channels=channels[3],
+             out_channels=channels[3],
+             kernel_size=3,
+             stride=1,
+             sparse=sparse
+         )
+         self.encoder4 = UnetResBlock(
+             in_channels=channels[4],
+             out_channels=channels[4],
+             kernel_size=3,
+             stride=1,
+             sparse=sparse
+         )
+
+     def forward(self, x):
+         x = self.stem(x)
+
+         x1 = self.expend1(x)
+
+         x = self.down(x1)
+         x = self.layer2(x)
+         x2 = self.expend2(x)
+
+         x = self.down(x2)
+         x = self.layer3(x)
+         x3 = self.expend3(x)
+
+         x = self.down(x3)
+         x = self.layer4(x)
+         x4 = self.expend4(x)
+
+         x = self.down(x4)
+         x5 = self.layer5(x)
+
+         return self.encoder1(x1), self.encoder2(x2), self.encoder3(x3), self.encoder4(x4), x5
+
+
+ class Encoder(nn.Module):
+
+     def __init__(self,
+                  in_channel: int = 1,
+                  channels=(32, 64, 128, 192, 384),
+                  depths=(1, 2, 2, 2, 1),
+                  kernels=(3, 3, 3, 3, 3),
+                  exp_r=(2, 2, 4, 4, 4),
+                  img_size=96,
+                  depth=4,
+                  norm_layer=nn.LayerNorm,
+                  sparse=False):
+         super(Encoder, self).__init__()
+         self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]
+
+         self.embeddings = Embeddings(in_channel=in_channel,
+                                      channels=channels,
+                                      depths=depths,
+                                      kernels=kernels,
+                                      exp_r=exp_r,
+                                      sparse=sparse)
+
+         self.mae = MaskedAutoencoderMamba(
+             img_size=img_size,
+             downsample_rato=self.get_downsample_ratio(),
+             embed_dim=channels[-1],
+             depth=depth,
+             norm_layer=norm_layer,
+             sparse=sparse)
+
+     def get_downsample_ratio(self) -> int:
+         """
+         This func would ONLY be used in `SparseEncoder`'s `__init__` (see `pretrain/encoder.py`).
+
+         :return: the TOTAL downsample ratio of the ConvNet.
+             E.g., for a ResNet-50, this should return 32.
+         """
+         return 16
+
+     def get_feature_map_channels(self) -> List[int]:
+         """
+         This func would ONLY be used in `SparseEncoder`'s `__init__` (see `pretrain/encoder.py`).
+
+         :return: a list of the number of channels of each feature map.
+             E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
+         """
+         return self.dim
+
+     def forward(self, x, active_b1fff=None):
+         x1, x2, x3, x4, x5 = self.embeddings(x)
+         _x5 = self.mae(x5, active_b1fff)
+         return x1, x2, x3, x4, _x5
+
+
+ class Decoder(nn.Module):
+     def __init__(self,
+                  n_classes: int = 3,
+                  channels: Tuple = (32, 64, 128, 196, 384),
+                  norm_name="instance",
+                  res_block: bool = True):
+         super(Decoder, self).__init__()
+
+         self.decoder5 = UnetrUpBlock(
+             spatial_dims=3,
+             in_channels=channels[4],
+             out_channels=channels[4],
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder4 = UnetrUpBlock(
+             spatial_dims=3,
+             in_channels=channels[4],
+             out_channels=channels[3],
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder3 = UnetrUpBlock(
+             spatial_dims=3,
+             in_channels=channels[3],
+             out_channels=channels[2],
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder2 = UnetrUpBlock(
+             spatial_dims=3,
+             in_channels=channels[2],
+             out_channels=channels[1],
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder1 = UnetResBlock(
+             in_channels=channels[1],
+             out_channels=channels[0],
+             kernel_size=3,
+             stride=1,
+             sparse=False
+         )
+         self.out = UnetOutBlock(in_channels=channels[0], n_classes=n_classes)
+
+     def forward(self, x1, x2, x3, x4, x5):
+         d4 = self.decoder5(x5, x4)
+         d3 = self.decoder4(d4, x3)
+         d2 = self.decoder3(d3, x2)
+         d1 = self.decoder2(d2, x1)
+         d0 = self.decoder1(d1)
+         return self.out(d0)
+
+
+ class Hybird(nn.Module):
+     def __init__(self,
+                  in_channel: int = 3,
+                  n_classes: int = 3,
+                  channels: Tuple = (32, 64, 96, 128, 192),
+                  depths: Tuple = (1, 1, 3, 3, 1),
+                  kernels: Tuple = (3, 3, 3, 3, 3),
+                  exp_r: Tuple = (2, 4, 4, 4, 2),
+                  img_size=96,
+                  depth=3,
+                  norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.embeddings = Embeddings(in_channel=in_channel,
+                                      channels=channels,
+                                      depths=depths,
+                                      kernels=kernels,
+                                      exp_r=exp_r,
+                                      sparse=False)
+
+         self.mae = MaskedAutoencoderMamba(
+             img_size=img_size,
+             downsample_rato=16,
+             embed_dim=channels[-1],
+             depth=depth,
+             norm_layer=norm_layer,
+             sparse=False)
+
+         self.decoder = Decoder(
+             n_classes=n_classes,
+             channels=channels,
+         )
+
+     def forward(self, x):
+         x1, x2, x3, x4, x5 = self.embeddings(x)
+         x5 = self.mae(x5, None)
+         return self.decoder(x1, x2, x3, x4, x5)
+
+
+ def build_hybird(in_channel=1, n_classes=14, img_size=96):
+     return Hybird(in_channel=in_channel,
+                   n_classes=n_classes,
+                   channels=(32, 64, 128, 192, 384),
+                   depths=(1, 2, 2, 2, 1),
+                   kernels=(3, 3, 3, 3, 3),
+                   exp_r=(2, 2, 4, 4, 4),
+                   img_size=img_size,
+                   depth=4)
+
+
+ if __name__ == '__main__':
+     x = torch.rand((1, 1, 96, 96, 96))
+     network = build_hybird()
+     print(network(x).shape)
+
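A quick arithmetic check of the token-grid bookkeeping in `MaskedAutoencoderMamba`: with the default `img_size=96` and the encoder's total downsample ratio of 16, the bottleneck grid is 6 per axis, i.e. 216 tokens, and `embed_dim=384` satisfies the divisible-by-6 constraint of the 3D sin-cos position embedding:

```python
# Sanity check of the default configuration (values taken from the code above).
img_size, downsample_ratio, embed_dim = 96, 16, 384
grid = img_size // downsample_ratio
num_patches = grid ** 3
assert (grid, num_patches) == (6, 216)
assert embed_dim % 6 == 0  # required by build_3d_sincos_position_embedding
```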
utils/arg_util.py ADDED
@@ -0,0 +1,125 @@
+ import json
+ import os
+ import sys
+
+ from tap import Tap
+
+ import dist
+
+
+ class Args(Tap):
+     # environment
+     exp_name: str = 'mamba'
+     exp_dir: str = ''  # will be created if it does not exist
+     data_path: str = ''
+     init_weight: str = ''  # use some checkpoint as model weight initialization; ONLY load model weights
+     resume_from: str = ''  # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch
+
+     # MambaMIM hyperparameters
+     mask: float = 0.75  # mask ratio, should be in (0, 1)
+
+     # encoder hyperparameters
+     model: str = 'mambamim'
+     input_size: int = 96
+     sbn: bool = True
+
+     # data hyperparameters
+     bs: int = 1
+     dataloader_workers: int = 8
+
+     # pre-training hyperparameters
+     dp: float = 0.0
+     base_lr: float = 1e-4
+     wd: float = 0.04
+     wde: float = 0.2
+     ep: int = 100
+     wp_ep: int = 40
+     clip: float = 5.
+     opt: str = 'adamw'
+     ada: float = 0.
+
+     # NO NEED TO SPECIFY; each of these args is updated automatically at runtime
+     lr: float = 1e-4
+     batch_size_per_gpu: int = 0
+     glb_batch_size: int = 0
+     densify_norm: str = ''
+     device: str = 'gpu'
+     local_rank: int = 0
+     cmd: str = ' '.join(sys.argv[1:])
+     commit_id: str = os.popen('git rev-parse HEAD').read().strip() or '[unknown]'
+     commit_msg: str = (os.popen('git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
+     last_loss: float = 0.
+     cur_ep: str = ''
+     remain_time: str = ''
+     finish_time: str = ''
+     first_logging: bool = True
+     log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
+     tb_lg_dir: str = ''  # tensorboard log directory
+
+     @property
+     def is_convnext(self):
+         return 'convnext' in self.model or 'cnx' in self.model
+
+     @property
+     def is_resnet(self):
+         return 'resnet' in self.model
+
+     def log_epoch(self):
+         if not dist.is_local_master():
+             return
+
+         if self.first_logging:
+             self.first_logging = False
+             with open(self.log_txt_name, 'w') as fp:
+                 json.dump({
+                     'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
+                     'model': self.model,
+                 }, fp)
+                 fp.write('\n\n')
+
+         with open(self.log_txt_name, 'a') as fp:
+             json.dump({
+                 'cur_ep': self.cur_ep,
+                 'last_L': self.last_loss,
+                 'rema': self.remain_time, 'fini': self.finish_time,
+             }, fp)
+             fp.write('\n')
+
+
+ def init_dist_and_get_args():
+     from utils import misc
+
+     # initialize
+     args = Args(explicit_bool=True).parse_args()
+     e = os.path.abspath(args.exp_dir)
+     d, e = os.path.dirname(e), os.path.basename(e)
+     e = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in e)
+     args.exp_dir = os.path.join(d, e)
+
+     os.makedirs(args.exp_dir, exist_ok=True)
+     args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
+     args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
+     try:
+         os.makedirs(args.tb_lg_dir, exist_ok=True)
+     except OSError:
+         pass
+
+     misc.init_distributed_environ(exp_dir=args.exp_dir)
+
+     # update args
+     if not dist.initialized():
+         args.sbn = False
+     args.first_logging = True
+     args.device = dist.get_device()
+     args.batch_size_per_gpu = args.bs // dist.get_world_size()
+     args.glb_batch_size = args.batch_size_per_gpu * dist.get_world_size()
+
+     args.ada = args.ada or 0.999
+     args.densify_norm = 'ln'
+
+     args.opt = args.opt.lower()
+     args.lr = args.base_lr
+     args.wde = args.wde or args.wd
+
+     return args
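The directory-name sanitization inside `init_dist_and_get_args` keeps only alphanumerics and `-`, mapping every other character to `_`. A standalone sketch of that one expression (the helper name and example strings are my own):

```python
def sanitize_exp_name(name: str) -> str:
    # same character filter as the expression in init_dist_and_get_args above
    return ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in name)
```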
utils/lamb.py ADDED
@@ -0,0 +1,151 @@
1
+ """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
2
+ This optimizer code was adapted from the following (starting with latest)
3
+ * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
4
+ * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
5
+ * https://github.com/cybertronai/pytorch-lamb
6
+ Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
7
+ similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
8
+ In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
9
+ Original copyrights for above sources are below.
10
+ Modifications Copyright 2021 Ross Wightman
11
+ """
12
+ import math
13
+
14
+ import torch
15
+ from torch.optim.optimizer import Optimizer
16
+
17
+
18
+ class TheSameAsTimmLAMB(Optimizer):
19
+ """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
20
+ reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
21
+
22
+ LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
23
+
24
+ Arguments:
25
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
26
+ lr (float, optional): learning rate. (default: 1e-3)
27
+ betas (Tuple[float, float], optional): coefficients used for computing
28
+ running averages of gradient and its norm. (default: (0.9, 0.999))
29
+ eps (float, optional): term added to the denominator to improve
30
+ numerical stability. (default: 1e-8)
31
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
32
+ grad_averaging (bool, optional): whether apply (1-beta2) to grad when
33
+ calculating running averages of gradient. (default: True)
34
+ max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
35
+ trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
36
+ always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
37
+ weight decay parameter (default: False)
38
+
39
+ .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
40
+ https://arxiv.org/abs/1904.00962
41
+ .. _On the Convergence of Adam and Beyond:
42
+ https://openreview.net/forum?id=ryQu7f-RZ
43
+ """
44
+
45
+ def __init__(
46
+ self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
47
+ weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
48
+ defaults = dict(
49
+ lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
50
+ grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
51
+ trust_clip=trust_clip, always_adapt=always_adapt)
52
+ super().__init__(params, defaults)
53
+ print(f'[lamb1] max_grad_norm={max_grad_norm}')
54
+ self.global_grad_norm = 0
55
+
56
+ @torch.no_grad()
57
+ def step(self, closure=None):
58
+ """Performs a single optimization step.
59
+ Arguments:
60
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        device = self.param_groups[0]['params'][0].device
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
+        global_grad_norm = torch.zeros(1, device=device)
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                global_grad_norm.add_(grad.pow(2).sum())
+
+        global_grad_norm = torch.sqrt(global_grad_norm)
+        self.global_grad_norm = global_grad_norm.item()
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+        clip_global_grad_norm = 1 / torch.where(
+            global_grad_norm > max_grad_norm,
+            global_grad_norm / max_grad_norm,
+            one_tensor)
+
+        for group in self.param_groups:
+            bias_correction = 1 if group['bias_correction'] else 0
+            beta1, beta2 = group['betas']
+            grad_averaging = 1 if group['grad_averaging'] else 0
+            beta3 = 1 - beta1 if grad_averaging else 1.0
+
+            # assume the same step across the group for now, to simplify things
+            # a per-parameter step could easily be supported by making it a tensor, or passing a list into the kernel
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            if bias_correction:
+                bias_correction1 = 1 - beta1 ** group['step']
+                bias_correction2 = 1 - beta2 ** group['step']
+            else:
+                bias_correction1, bias_correction2 = 1.0, 1.0
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.mul_(clip_global_grad_norm)
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
+
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                update = (exp_avg / bias_correction1).div_(denom)
+
+                weight_decay = group['weight_decay']
+                if weight_decay != 0:
+                    update.add_(p, alpha=weight_decay)
+
+                if weight_decay != 0 or group['always_adapt']:
+                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
+                    # excluded from weight decay; unless always_adapt == True, then it is always enabled.
+                    w_norm = p.norm(2.0)
+                    g_norm = update.norm(2.0)
+                    # FIXME nested where required since logical and/or not working in PT XLA
+                    trust_ratio = torch.where(
+                        w_norm > 0,
+                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
+                        one_tensor,
+                    )
+                    if group['trust_clip']:
+                        # LAMBC trust clipping, upper bound fixed at one
+                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                    update.mul_(trust_ratio)
+
+                p.add_(update, alpha=-group['lr'])
+
+        return loss
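The layer-wise adaptation above boils down to a small scalar rule. Here is a minimal pure-Python sketch of it (the `trust_ratio` helper name is hypothetical, standing in for the nested `torch.where` version in the diff):

```python
def trust_ratio(w_norm, g_norm, clip=False):
    # Mirrors the LAMB layer-wise adaptation: scale the update by
    # ||w|| / ||update||, falling back to 1 when either norm is zero.
    if w_norm > 0 and g_norm > 0:
        r = w_norm / g_norm
    else:
        r = 1.0
    # LAMBC variant: clip the trust ratio at 1
    return min(r, 1.0) if clip else r

print(trust_ratio(4.0, 2.0))             # ratio of norms
print(trust_ratio(0.0, 2.0))             # zero weight norm -> no adaptation
print(trust_ratio(4.0, 2.0, clip=True))  # LAMBC clipping
```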
utils/lr_control.py ADDED
@@ -0,0 +1,47 @@
+ import math
+ from pprint import pformat
+
+
+ def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it):
+     wp_it = round(wp_it)
+     if cur_it < wp_it:
+         cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
+     else:
+         ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
+         cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))
+
+     ratio = cur_it / (max_it - 1)
+     cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio))
+
+     min_lr, max_lr = cur_lr, cur_lr
+     min_wd, max_wd = cur_wd, cur_wd
+     for param_group in optimizer.param_groups:
+         scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1)  # 'lr_scale' could be assigned
+         min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr)
+         scaled_wd = param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1)  # 'weight_decay_scale' could be assigned
+         min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd)
+     return min_lr, max_lr, min_wd, max_wd
+
+
+ def get_param_groups(model, nowd_keys=()):
+     para_groups, para_groups_dbg = {}, {}
+
+     for name, para in model.named_parameters():
+         if not para.requires_grad:
+             continue  # frozen weights
+         if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
+             wd_scale, group_name = 0., 'no_decay'
+         else:
+             wd_scale, group_name = 1., 'decay'
+
+         if group_name not in para_groups:
+             para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
+             para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
+         para_groups[group_name]['params'].append(para)
+         para_groups_dbg[group_name]['params'].append(name)
+
+     for g in para_groups_dbg.values():
+         g['params'] = pformat(', '.join(g['params']), width=200)
+
+     print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
+     return list(para_groups.values())
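The schedule in `lr_wd_annealing` is a linear warmup followed by cosine decay. A minimal stand-alone sketch of the learning-rate part (the `lr_at` helper name is hypothetical, and no optimizer is involved):

```python
import math

def lr_at(cur_it, peak_lr, wp_it, max_it):
    # Mirrors lr_wd_annealing: linear warmup from 0.005*peak_lr up to
    # peak_lr over wp_it iterations, then cosine decay down to
    # 0.001*peak_lr at the final iteration.
    if cur_it < wp_it:
        return 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
    ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
    return 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))

print(lr_at(0, 1.0, 10, 100))   # warmup floor: 0.005
print(lr_at(10, 1.0, 10, 100))  # peak: 1.0
print(lr_at(99, 1.0, 10, 100))  # final floor: ~0.001
```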
utils/med_dataset.py ADDED
@@ -0,0 +1,275 @@
+ import math
+ import os
+ from typing import Any, Callable, Optional, Tuple
+ from monai import data, transforms as med
+ from monai.data import load_decathlon_datalist
+ import PIL.Image as PImage
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
+ from torchvision.transforms import transforms
+ from torch.utils.data import Dataset
+ import torch
+ import numpy as np
+ import cv2
+ try:
+     from torchvision.transforms import InterpolationMode
+     interpolation = InterpolationMode.BICUBIC
+ except ImportError:
+     import PIL
+     interpolation = PIL.Image.BICUBIC
+ from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform
+ import random
+
+
+ def pil_loader(path):
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         img: PImage.Image = PImage.open(f).convert('RGB')
+     return img
+
+
+ class ImageNetDataset(DatasetFolder):
+     def __init__(
+             self,
+             imagenet_folder: str,
+             train: bool,
+             transform: Callable,
+             is_valid_file: Optional[Callable[[str], bool]] = None,
+     ):
+         imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val')
+         super(ImageNetDataset, self).__init__(
+             imagenet_folder,
+             loader=pil_loader,
+             extensions=IMG_EXTENSIONS if is_valid_file is None else None,
+             transform=transform,
+             target_transform=None, is_valid_file=is_valid_file
+         )
+
+         self.samples = tuple(img for (img, label) in self.samples)
+         self.targets = None  # this is self-supervised learning, so we don't need labels
+
+     def __getitem__(self, index: int) -> Any:
+         img_file_path = self.samples[index]
+         return self.transform(self.loader(img_file_path))
+
+
+ def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
+     """
+     You may need to modify this function to return your own dataset.
+     Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
+     Use dataset_path to build your image file path list.
+     Use input_size to create the transformation function for your images; you can refer to `trans_train` below.
+
+     :param dataset_path: the folder of the dataset
+     :param input_size: the input size (image resolution)
+     :return: the dataset used for pretraining
+     """
+     trans_train = transforms.Compose([
+         transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+     ])
+
+     dataset_path = os.path.abspath(dataset_path)
+     for postfix in ('train', 'val'):
+         if dataset_path.endswith(postfix):
+             dataset_path = dataset_path[:-len(postfix)]
+
+     dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
+     print_transform(trans_train, '[pre-train]')
+     return dataset_train
+
+
+ def build_meddataset_to_pretrain(dataset_path, input_size) -> Dataset:
+     """
+     You may need to modify this function to return your own dataset.
+     Define a new class, a subclass of `Dataset`, to replace our MedicalDataSets.
+     Use dataset_path to build your image file path list.
+     Use input_size to create the transformation function for your images; you can refer to `trans_train` below.
+
+     :param dataset_path: the folder of the dataset
+     :param input_size: the input size (image resolution)
+     :return: the dataset used for pretraining
+     """
+     trans_train = transforms.Compose([
+         transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+     ])
+
+     dataset_path = os.path.abspath(dataset_path)
+
+     dataset_train = MedicalDataSets(base_dir=dataset_path, transform=trans_train)
+     print_transform(trans_train, '[pre-train]')
+     return dataset_train
+
+
+ class MedicalDataSets(Dataset):
+     def __init__(
+             self,
+             base_dir=None,
+             transform=None,
+     ):
+         self._base_dir = base_dir
+         self.sample_list = os.listdir(self._base_dir)
+         self.transform = transform
+         print("total {}".format(len(self.sample_list)))
+
+     def __len__(self):
+         return len(self.sample_list)
+
+     def __getitem__(self, idx):
+         case = self.sample_list[idx]
+         img = PImage.open(os.path.join(self._base_dir, case)).convert('RGB')
+         aug = self.transform(img)
+         return aug
+
+
+ def print_transform(transform, s):
+     print(f'Transform {s} = ')
+     for t in transform.transforms:
+         print(t)
+     print('---------------------------\n')
+
+
+ class Sampler(torch.utils.data.Sampler):
+     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
+         if num_replicas is None:
+             if not torch.distributed.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = torch.distributed.get_world_size()
+         if rank is None:
+             if not torch.distributed.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = torch.distributed.get_rank()
+         self.shuffle = shuffle
+         self.make_even = make_even
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+         self.total_size = self.num_samples * self.num_replicas
+         indices = list(range(len(self.dataset)))
+         self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])
+
+     def __iter__(self):
+         if self.shuffle:
+             g = torch.Generator()
+             g.manual_seed(self.epoch)
+             indices = torch.randperm(len(self.dataset), generator=g).tolist()
+         else:
+             indices = list(range(len(self.dataset)))
+         if self.make_even:
+             if len(indices) < self.total_size:
+                 if self.total_size - len(indices) < len(indices):
+                     indices += indices[: (self.total_size - len(indices))]
+                 else:
+                     extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices))
+                     indices += [indices[ids] for ids in extra_ids]
+             assert len(indices) == self.total_size
+         indices = indices[self.rank : self.total_size : self.num_replicas]
+         self.num_samples = len(indices)
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+
+ class RandScaleCropdPlusScaleByMidDimSampled(MapTransform):
+     def __init__(self, keys, mode='area', max_size=128, allow_missing_keys=False, num_samples=4, max_radio=0.8, min_radio=0.5):
+         self.keys = keys
+         self.mode = mode
+         self.allow_missing_keys = allow_missing_keys
+         self.max_size = max_size
+         self.num_samples = num_samples
+         self.max_radio = max_radio
+         self.min_radio = min_radio
+
+     def __call__(self, data):
+         outputs = []
+         for i in range(self.num_samples):
+             random_number = round(random.uniform(self.min_radio, self.max_radio), 2)
+             _data = dict(data)
+             for key in self.keys:
+                 cropper = med.RandScaleCropd(keys=[key], roi_scale=random_number)
+                 _data[key] = cropper(_data)[key]
+                 ct_tensor = _data[key]
+                 sorted_numbers = sorted(ct_tensor.shape[1:])
+                 scale_factor = self.max_size / sorted_numbers[1]
+                 new_size = [int(d * scale_factor) for d in ct_tensor.shape[1:]]
+
+                 resizer = med.Resized(keys=[key],
+                                       spatial_size=new_size,
+                                       mode=self.mode,
+                                       allow_missing_keys=self.allow_missing_keys)
+                 _data[key] = resizer(_data)[key]
+
+             outputs.append(_data)
+
+         return outputs
+
+
+ def get_loader(data_dir, size):
+     datalist_json = os.path.join(data_dir, "dataset.json")
+     train_transform = med.Compose(
+         [
+             med.LoadImaged(keys=["image"], allow_missing_keys=True),
+             med.AddChanneld(keys=["image"], allow_missing_keys=True),
+             med.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
+             med.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear", allow_missing_keys=True),
+             med.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
+             med.CropForegroundd(keys=["image"], source_key="image", allow_missing_keys=True),
+             med.SpatialPadd(keys=["image"], spatial_size=(size, size, size), mode='constant'),
+             med.RandCropByPosNegLabeld(
+                 spatial_size=(size, size, size),
+                 keys=["image"],
+                 label_key="image",
+                 pos=1,
+                 neg=0,
+                 num_samples=4,
+             ),
+             med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=0),
+             med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=1),
+             med.RandFlipd(keys=["image"], prob=0.1, spatial_axis=2),
+             med.ToTensord(keys=["image"]),
+         ])
+     # val_transform = transforms.Compose(
+     #     [
+     #         transforms.LoadImaged(keys=["image", "label"]),
+     #         transforms.AddChanneld(keys=["image", "label"]),
+     #         transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
+     #         transforms.Spacingd(
+     #             keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")
+     #         ),
+     #         transforms.ScaleIntensityRanged(
+     #             keys=["image"], a_min=-175.0, a_max=250.0, b_min=0.0, b_max=1.0, clip=True
+     #         ),
+     #         transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
+     #         transforms.ToTensord(keys=["image", "label"]),
+     #     ]
+     # )
+
+     datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)
+     # train_ds = data.Dataset(data=datalist, transform=train_transform)
+     # train_ds = data.CacheDataset(data=datalist, transform=train_transform)
+     # train_ds = data.SmartCacheDataset(data=datalist, transform=train_transform, replace_rate=0.7, cache_num=256, num_init_workers=4, num_replace_workers=4)
+     train_ds = data.CacheNTransDataset(data=datalist, transform=train_transform, cache_n_trans=6, cache_dir="/fenghetang/3d/pretrain/MM/cache_dataset")
+     return train_ds
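The distributed `Sampler` above assigns each rank a strided slice of the (optionally padded) index list. A minimal pure-Python sketch of that sharding logic, with a hypothetical `shard_for_rank` helper and no torch dependency:

```python
import math

def shard_for_rank(indices, num_replicas, rank, make_even=True):
    # Mirrors the Sampler: pad the index list to a multiple of num_replicas
    # (repeating from the front), then take a strided slice so each rank
    # sees a disjoint, equal-sized share of the dataset.
    num_samples = math.ceil(len(indices) / num_replicas)
    total_size = num_samples * num_replicas
    if make_even and len(indices) < total_size:
        indices = indices + indices[: total_size - len(indices)]
    return indices[rank:total_size:num_replicas]

# 10 samples across 4 ranks: padded to 12, each rank gets 3
print(shard_for_rank(list(range(10)), 4, 0))
print(shard_for_rank(list(range(10)), 4, 3))
```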
utils/misc.py ADDED
@@ -0,0 +1,338 @@
+ import datetime
+ import functools
+ import os
+ import subprocess
+ import sys
+ import time
+ from collections import defaultdict, deque
+ from typing import Iterator
+
+ import numpy as np
+ import pytz
+ import torch
+ from torch.utils.tensorboard import SummaryWriter
+
+ import dist
+
+ os_system = functools.partial(subprocess.call, shell=True)
+ os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
+
+
+ def os_system_get_stdout_stderr(cmd):
+     sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+     return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
+
+
+ def is_pow2n(x):
+     return x > 0 and ((x - 1) & x == 0)
+
+
+ def time_str(for_dirname=False):
+     return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(
+         '%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
+
+
+ def init_distributed_environ(exp_dir):
+     dist.initialize()
+     dist.barrier()
+
+     import torch.backends.cudnn as cudnn
+     cudnn.benchmark = True
+     cudnn.deterministic = False
+
+     _set_print_only_on_master_proc(is_master=dist.is_local_master())
+     if dist.is_local_master() and len(exp_dir):
+         sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False)
+
+
+ def _set_print_only_on_master_proc(is_master):
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def prt(msg, *args, **kwargs):
+         force = kwargs.pop('force', False)
+         clean = kwargs.pop('clean', False)
+         deeper = kwargs.pop('deeper', False)
+         if is_master or force:
+             if not clean:
+                 f_back = sys._getframe().f_back
+                 if deeper and f_back.f_back is not None:
+                     f_back = f_back.f_back
+                 file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
+                 msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}'
+             builtin_print(msg, *args, **kwargs)
+
+     __builtin__.print = prt
+
+
+ class _SyncPrintToFile(object):
+     def __init__(self, exp_dir, stdout=True):
+         self.terminal = sys.stdout if stdout else sys.stderr
+         fname = os.path.join(exp_dir, 'stdout_backup.txt' if stdout else 'stderr_backup.txt')
+         self.log = open(fname, 'w')
+         self.log.flush()
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+         self.log.flush()
+
+     def flush(self):
+         self.terminal.flush()
+         self.log.flush()
+
+
+ class TensorboardLogger(object):
+     def __init__(self, log_dir, is_master, prefix='pt'):
+         self.is_master = is_master
+         self.writer = SummaryWriter(log_dir=log_dir) if self.is_master else None
+         self.step = 0
+         self.prefix = prefix
+         self.log_freq = 300
+
+     def set_step(self, step=None):
+         if step is not None:
+             self.step = step
+         else:
+             self.step += 1
+
+     def get_loggable(self, step=None):
+         if step is None:  # iter-wise
+             step = self.step
+             loggable = step % self.log_freq == 0
+         else:  # epoch-wise
+             loggable = True
+         return step, (loggable and self.is_master)
+
+     def update(self, head='scalar', step=None, **kwargs):
+         step, loggable = self.get_loggable(step)
+         if loggable:
+             head = f'{self.prefix}_{head}'
+             for k, v in kwargs.items():
+                 if v is None:
+                     continue
+                 if isinstance(v, torch.Tensor):
+                     v = v.item()
+                 assert isinstance(v, (float, int))
+                 self.writer.add_scalar(head + "/" + k, v, step)
+
+     def log_distribution(self, tag, values, step=None):
+         step, loggable = self.get_loggable(step)
+         if loggable:
+             if not isinstance(values, torch.Tensor):
+                 values = torch.tensor(values)
+             self.writer.add_histogram(tag=tag, values=values, global_step=step)
+
+     def log_image(self, tag, img, step=None, dataformats='NCHW'):
+         step, loggable = self.get_loggable(step)
+         if loggable:
+             # img = img.cpu().numpy()
+             self.writer.add_image(tag, img, step, dataformats=dataformats)
+
+     def flush(self):
+         if self.is_master: self.writer.flush()
+
+     def close(self):
+         if self.is_master: self.writer.close()
+
+
+ def save_checkpoint_with_meta_info_and_opt_state(save_to, args, epoch, performance_desc, model_without_ddp_state,
+                                                  optimizer_state):
+     checkpoint_path = os.path.join(args.exp_dir, save_to)
+     if dist.is_local_master():
+         to_save = {
+             'args': str(args),
+             'input_size': args.input_size,
+             'arch': args.model,
+             'epoch': epoch,
+             'performance_desc': performance_desc,
+             'module': model_without_ddp_state,
+             'optimizer': optimizer_state,
+             'is_pretrain': True,
+         }
+         torch.save(to_save, checkpoint_path)
+
+
+ def save_checkpoint_model_weights_only(save_to, args, sp_cnn_state):
+     checkpoint_path = os.path.join(args.exp_dir, save_to)
+     if dist.is_local_master():
+         torch.save(sp_cnn_state, checkpoint_path)
+
+
+ def initialize_weight(init_weight: str, model_without_ddp):
+     # use some checkpoint as model weight initialization; ONLY load model weights
+     if len(init_weight):
+         checkpoint = torch.load(init_weight, 'cpu')
+         missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
+         print(f'[initialize_weight] missing_keys={missing}')
+         print(f'[initialize_weight] unexpected_keys={unexpected}')
+
+
+ def load_checkpoint(resume_from: str, model_without_ddp, optimizer):
+     # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch
+     if len(resume_from) == 0:
+         return 0, '[no performance_desc]'
+     print(f'[try to resume from file `{resume_from}`]')
+     checkpoint = torch.load(resume_from, map_location='cpu')
+
+     ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc',
+                                                                                 '[no performance_desc]')
+     missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
+     print(f'[load_checkpoint] missing_keys={missing}')
+     print(f'[load_checkpoint] unexpected_keys={unexpected}')
+     print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')
+
+     if 'optimizer' in checkpoint:
+         optimizer.load_state_dict(checkpoint['optimizer'])
+     return ep_start, performance_desc
+
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.allreduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if v is None:
+                 continue
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, max_iters, itrt, print_freq, header=None):
+         print_iters = set(np.linspace(0, max_iters - 1, print_freq, dtype=int).tolist())
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         self.iter_time = SmoothedValue(fmt='{avg:.4f}')
+         self.data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(max_iters))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'iter: {time}s',
+             'data: {data}s'
+         ]
+         log_msg = self.delimiter.join(log_msg)
+
+         if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
+             for i in range(max_iters):
+                 obj = next(itrt)
+                 self.data_time.update(time.time() - end)
+                 yield obj
+                 self.iter_time.update(time.time() - end)
+                 if i in print_iters:
+                     eta_seconds = self.iter_time.global_avg * (max_iters - i)
+                     eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                     print(log_msg.format(
+                         i, max_iters, eta=eta_string,
+                         meters=str(self),
+                         time=str(self.iter_time), data=str(self.data_time)))
+                 end = time.time()
+         else:
+             for i, obj in enumerate(itrt):
+                 self.data_time.update(time.time() - end)
+                 yield obj
+                 self.iter_time.update(time.time() - end)
+                 if i in print_iters:
+                     eta_seconds = self.iter_time.global_avg * (max_iters - i)
+                     eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                     print(log_msg.format(
+                         i, max_iters, eta=eta_string,
+                         meters=str(self),
+                         time=str(self.iter_time), data=str(self.data_time)))
+                 end = time.time()
+
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.3f} s / it)'.format(
+             header, total_time_str, total_time / max_iters))
utils/sampler.py ADDED
@@ -0,0 +1,68 @@
+ import random
+
+ import numpy as np
+ import torch
+ from torch.utils.data.sampler import Sampler
+
+
+ def worker_init_fn(worker_id):
+     # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
+     worker_seed = torch.initial_seed() % 2 ** 32
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
+
+
+ class DistInfiniteBatchSampler(Sampler):
+     def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
+         assert glb_batch_size % world_size == 0
+         self.world_size, self.rank = world_size, rank
+         self.dataset_len = dataset_len
+         self.glb_batch_size = glb_batch_size
+         self.batch_size = glb_batch_size // world_size
+
+         self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
+         self.filling = filling
+         self.shuffle = shuffle
+         self.epoch = 0
+         self.seed = seed
+         self.indices = self.gener_indices()
+
+     def gener_indices(self):
+         global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
+         if self.shuffle:
+             g = torch.Generator()
+             g.manual_seed(self.epoch + self.seed)
+             global_indices = torch.randperm(self.dataset_len, generator=g)
+         else:
+             global_indices = torch.arange(self.dataset_len)
+         filling = global_max_p - global_indices.shape[0]
+         if filling > 0 and self.filling:
+             global_indices = torch.cat((global_indices, global_indices[:filling]))
+         global_indices = tuple(global_indices.numpy().tolist())
+
+         seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
+         local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
+         self.max_p = len(local_indices)
+         return local_indices
+
+     def __iter__(self):
+         self.epoch = 0
+         while True:
+             self.epoch += 1
+             p, q = 0, 0
+             while p < self.max_p:
+                 q = p + self.batch_size
+                 yield self.indices[p:q]
+                 p = q
+             if self.shuffle:
+                 self.indices = self.gener_indices()
+
+     def __len__(self):
+         return self.iters_per_ep
+
+
+ if __name__ == '__main__':
+     W = 16
+     for rk in range(W):
+         ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
+         print(rk, len(ind))
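`DistInfiniteBatchSampler.__iter__` never terminates: it walks its local shard in `batch_size` chunks and reshuffles for a new "epoch" when the shard is exhausted. A minimal pure-Python sketch of that chunking-and-wraparound behavior (the `infinite_batches` helper is hypothetical, with reshuffling omitted):

```python
def infinite_batches(indices, batch_size):
    # Mirrors the inner loop of DistInfiniteBatchSampler.__iter__: yield
    # batch_size-sized slices of the local shard, restarting from the
    # beginning (a new "epoch") once the shard is exhausted.
    while True:
        p = 0
        while p < len(indices):
            yield indices[p:p + batch_size]
            p += batch_size

it = infinite_batches([0, 1, 2, 3, 4], 2)
batches = [next(it) for _ in range(4)]
print(batches)  # the last (short) batch is followed by a wraparound
```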