Upload 84 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- 2024.09.27/config.yaml +211 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt +0 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json +27 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json +38 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json +0 -0
- 2024.09.27/vision_model/medimageinsigt-v1.0.0.pt +3 -0
- Images/Breast-Cancer-1.jpg +0 -0
- Images/Breast-Cancer-2.jpeg +0 -0
- Images/Peneumonia.png +0 -0
- MedImageInsight/Distributed/Utils.py +344 -0
- MedImageInsight/Distributed/__init__.py +6 -0
- MedImageInsight/Distributed/__pycache__/Utils.cpython-38.pyc +0 -0
- MedImageInsight/Distributed/__pycache__/__init__.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__init__.py +8 -0
- MedImageInsight/ImageDataLoader/__pycache__/__init__.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/blob_storage.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/build.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/constants.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/tsv.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/tsv_file.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/__pycache__/zipdata.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/blob_storage.py +244 -0
- MedImageInsight/ImageDataLoader/build.py +260 -0
- MedImageInsight/ImageDataLoader/constants.py +85 -0
- MedImageInsight/ImageDataLoader/languages/__init__.py +0 -0
- MedImageInsight/ImageDataLoader/languages/__pycache__/__init__.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/languages/__pycache__/prompt_engineering.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/languages/prompt_engineering.py +101 -0
- MedImageInsight/ImageDataLoader/transforms/__init__.py +1 -0
- MedImageInsight/ImageDataLoader/transforms/__pycache__/__init__.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/transforms/__pycache__/autoaugment.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/transforms/__pycache__/build.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/transforms/__pycache__/threeaugment.cpython-38.pyc +0 -0
- MedImageInsight/ImageDataLoader/transforms/autoaugment.py +447 -0
- MedImageInsight/ImageDataLoader/transforms/build.py +261 -0
- MedImageInsight/ImageDataLoader/transforms/threeaugment.py +54 -0
- MedImageInsight/ImageDataLoader/tsv.py +351 -0
- MedImageInsight/ImageDataLoader/tsv_file.py +290 -0
- MedImageInsight/ImageDataLoader/zipdata.py +98 -0
- MedImageInsight/ImageEncoder/__init__.py +8 -0
- MedImageInsight/ImageEncoder/__pycache__/__init__.cpython-38.pyc +0 -0
- MedImageInsight/ImageEncoder/__pycache__/build.cpython-38.pyc +0 -0
- MedImageInsight/ImageEncoder/__pycache__/coswin.cpython-38.pyc +0 -0
- MedImageInsight/ImageEncoder/__pycache__/davit_v1.cpython-38.pyc +0 -0
- MedImageInsight/ImageEncoder/__pycache__/registry.cpython-38.pyc +0 -0
- MedImageInsight/ImageEncoder/build.py +13 -0
- MedImageInsight/ImageEncoder/coswin.py +779 -0
- MedImageInsight/ImageEncoder/davit_v1.py +727 -0
- MedImageInsight/ImageEncoder/registry.py +18 -0
- MedImageInsight/LangEncoder/__init__.py +13 -0
2024.09.27/config.yaml
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
##################
|
| 2 |
+
# Trainer settings
|
| 3 |
+
##################
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
TASK: UniCLTask
|
| 7 |
+
|
| 8 |
+
NAME: 'Example Eval Configuration'
|
| 9 |
+
SAVE_TIMER_LOG: true
|
| 10 |
+
|
| 11 |
+
# TUTORIAL STEP 1: CHOOSE SAVE DIR
|
| 12 |
+
SAVE_DIR: ''
|
| 13 |
+
LOG_EVERY: 10
|
| 14 |
+
LOGLEVEL_OVERRIDE: INFO
|
| 15 |
+
LOG_GPU_MEM: true
|
| 16 |
+
RESUME: False
|
| 17 |
+
RESET_DATA_LOADER: false
|
| 18 |
+
|
| 19 |
+
FP16: true
|
| 20 |
+
ZERO_STAGE: 0
|
| 21 |
+
DEEPSPEED: false
|
| 22 |
+
# ZERO_STAGE: 1
|
| 23 |
+
AMP: PYTORCH
|
| 24 |
+
# USE_APEX_DDP: false
|
| 25 |
+
# USE_APEX_AMP: false
|
| 26 |
+
# USE_HIT: false
|
| 27 |
+
|
| 28 |
+
FIND_UNUSED_PARAMETERS: false
|
| 29 |
+
|
| 30 |
+
SAVE_PER_OPTIM_STEPS: 500
|
| 31 |
+
EVAL_PER_OPTIM_STEPS: 250
|
| 32 |
+
EVAL_AT_START: False
|
| 33 |
+
# SAVE_PER_UPDATE_NUM: -1
|
| 34 |
+
# EVAL_PER_UPDATE_NUM: 0 # 0: do evaluation when saving checkpoint, -1: don't do evaluation
|
| 35 |
+
|
| 36 |
+
NO_AUTO_LR_SCALING: true
|
| 37 |
+
GRAD_CLIPPING: 1.0 #0.07
|
| 38 |
+
|
| 39 |
+
SET_SAMPLER_EPOCH: true
|
| 40 |
+
|
| 41 |
+
DONT_LOAD_MODEL: true
|
| 42 |
+
|
| 43 |
+
user_dir: "./MainzVision" # lower case due to it is used in mainz as such
|
| 44 |
+
|
| 45 |
+
##################
|
| 46 |
+
# Task settings
|
| 47 |
+
##################
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
VERBOSE: true
|
| 52 |
+
WORKERS: 6
|
| 53 |
+
PIN_MEMORY: true
|
| 54 |
+
IMAGE_ENCODER:
|
| 55 |
+
NAME: davit_v1
|
| 56 |
+
NUM_CLASSES: 0
|
| 57 |
+
#IMAGE_SIZE: [384, 384]
|
| 58 |
+
IMAGE_SIZE: [480, 480]
|
| 59 |
+
LOAD_PRETRAINED: true
|
| 60 |
+
PRETRAINED: ''
|
| 61 |
+
PRETRAINED_LAYERS: '*'
|
| 62 |
+
IMAGE_MEAN: [0.485, 0.456, 0.406]
|
| 63 |
+
IMAGE_STD: [0.229, 0.224, 0.225]
|
| 64 |
+
SPEC:
|
| 65 |
+
DROP_RATE: 0.1
|
| 66 |
+
DROP_PATH_RATE: 0.2
|
| 67 |
+
PATCH_SIZE: [7, 3, 3, 3]
|
| 68 |
+
PATCH_STRIDE: [4, 2, 2, 2]
|
| 69 |
+
PATCH_PADDING: [3, 1, 1, 1]
|
| 70 |
+
PATCH_PRENORM: [false, true, true, true]
|
| 71 |
+
DIM_EMBED: [256, 512, 1024, 2048]
|
| 72 |
+
NUM_HEADS: [8, 16, 32, 64]
|
| 73 |
+
NUM_GROUPS: [8, 16, 32, 64]
|
| 74 |
+
DEPTHS: [1, 1, 9, 1]
|
| 75 |
+
WINDOW_SIZE: 12
|
| 76 |
+
ENABLE_CHECKPOINT: true
|
| 77 |
+
|
| 78 |
+
LANG_ENCODER:
|
| 79 |
+
NAME: transformer
|
| 80 |
+
LOAD_PRETRAINED: false
|
| 81 |
+
PRETRAINED: ''
|
| 82 |
+
PRETRAINED_LAYERS: '*'
|
| 83 |
+
TOKENIZER: clip
|
| 84 |
+
CONTEXT_LENGTH: 77
|
| 85 |
+
WIDTH: 1024
|
| 86 |
+
HEADS: 16
|
| 87 |
+
LAYERS: 16
|
| 88 |
+
AUTOGRESSIVE: false
|
| 89 |
+
|
| 90 |
+
UNICL_MODEL:
|
| 91 |
+
DIM_PROJECTION: 1024
|
| 92 |
+
GATHER_TENSORS: true
|
| 93 |
+
LOAD_PRETRAINED: true
|
| 94 |
+
|
| 95 |
+
# TUTORIAL STEP 2: CHOOSE MODEL PATH
|
| 96 |
+
PRETRAINED: ''
|
| 97 |
+
|
| 98 |
+
PRETRAINED_LAYERS: '*'
|
| 99 |
+
|
| 100 |
+
AUG:
|
| 101 |
+
MIXUP_PROB: 0.0
|
| 102 |
+
MIXUP: 0.8
|
| 103 |
+
MIXCUT: 1.0
|
| 104 |
+
MIXCUT_MINMAX: []
|
| 105 |
+
MIXUP_SWITCH_PROB: 0.5
|
| 106 |
+
MIXUP_MODE: 'batch'
|
| 107 |
+
SCALE: [0.8, 1.0]
|
| 108 |
+
RATIO: [0.75, 1.3333333]
|
| 109 |
+
INTERPOLATION: 'bicubic'
|
| 110 |
+
TORCHVISION_AUG:
|
| 111 |
+
AUTO_AUGMENT: ta_wide
|
| 112 |
+
RE_PROB: 0.25
|
| 113 |
+
HFLIP: 0.0
|
| 114 |
+
VFLIP: 0.0
|
| 115 |
+
|
| 116 |
+
LOSS:
|
| 117 |
+
LOSS: UniCL
|
| 118 |
+
DATASET:
|
| 119 |
+
DATASET: 'image_text_pairs_v2'
|
| 120 |
+
TEXT_FORMAT: 'json'
|
| 121 |
+
ROOT: ''
|
| 122 |
+
TRAIN_SET: 'mimic_cxr_v2-chestxray14-chexpertv4-irma2009_v2-rsnaboneage-mura-bingmedicalfewshot'
|
| 123 |
+
DATA_FORMAT: 'tsv'
|
| 124 |
+
SAMPLER: 'default'
|
| 125 |
+
LOADER: 'default'
|
| 126 |
+
TOKEN_FILE: ''
|
| 127 |
+
#PROMPT_ENGINEERING: False
|
| 128 |
+
#SAMPLER: 'chunk'
|
| 129 |
+
#LOADER: 'azcopy'
|
| 130 |
+
#TOKEN_FILE: 'cliptrainingpairs.txt'
|
| 131 |
+
#TEST_SET: 'MarsAtrain'
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# TUTORIAL STEP 3: CHOOSE ALL BELOW EVAL PATHS (THESE ARE ALL OPTIONAL EXTRA EVALS)
|
| 135 |
+
# Note how one eval is ZIP format and the other is TSV format.
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
EVALDATASET_LTCXR_S100_N100_TEXT_CLASSIFIER:
|
| 141 |
+
TEXT_FORMAT: json
|
| 142 |
+
FORMAT: 'zip'
|
| 143 |
+
SPLIT: 'NIH-CXR-LT'
|
| 144 |
+
ZIP_FILE: ''
|
| 145 |
+
ZIP_MAP_FILE: ''
|
| 146 |
+
LABEL_FILE: ''
|
| 147 |
+
IMAGE_TSV: ''
|
| 148 |
+
TEXT_TSV: ''
|
| 149 |
+
CWEIGHT_FILE: ''
|
| 150 |
+
ZS_MODE: 2
|
| 151 |
+
ZS_WEIGHT: 1.0
|
| 152 |
+
KNN: 100
|
| 153 |
+
# CLASSIFICATION_SETS: ['NIH-CXR-LT']
|
| 154 |
+
# NUM_CLASSES: [20]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# TUTORIAL STEP 4: SET THE DEFAULT ZEROSHOT EVAL (THIS IS THE MANDATORY EVAL)
|
| 160 |
+
|
| 161 |
+
ZEROSHOT_EVAL_DATASET:
|
| 162 |
+
FORMAT: 'zip'
|
| 163 |
+
SPLIT: 'NIH-CXR-LT'
|
| 164 |
+
ZIP_FILE: ''
|
| 165 |
+
ZIP_MAP_FILE: ''
|
| 166 |
+
LABEL_FILE: ''
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
EVALUATION_SPLITS: ['cls-zeroshot-eval']
|
| 171 |
+
TEST:
|
| 172 |
+
BATCH_SIZE_PER_GPU: 8
|
| 173 |
+
MODEL_FILE: ''
|
| 174 |
+
CENTER_CROP: false
|
| 175 |
+
TRAIN:
|
| 176 |
+
BATCH_SIZE_TOTAL: 1024
|
| 177 |
+
BATCH_SIZE_PER_GPU: 16
|
| 178 |
+
|
| 179 |
+
SHUFFLE: true
|
| 180 |
+
|
| 181 |
+
WEIGHT_SMOOTHING:
|
| 182 |
+
decay: 0.999
|
| 183 |
+
use_cpu: False
|
| 184 |
+
eval_smoothed_weight: True
|
| 185 |
+
|
| 186 |
+
START_LEARNING_RATE: 0.00001
|
| 187 |
+
# MAX_NUM_EPOCHS: 2
|
| 188 |
+
MAX_NUM_EPOCHS: 100
|
| 189 |
+
OPTIMIZER: AdamW # adam
|
| 190 |
+
OPTIMIZER_PARAMS:
|
| 191 |
+
weight_decay: 0.2 #0.1
|
| 192 |
+
CUSTOMIZED_PARAMS_CONF:
|
| 193 |
+
NO_WEIGHT_DECAY_MODULES: ['dw', 'norm']
|
| 194 |
+
WEIGHT_DECAY_PATTERNS:
|
| 195 |
+
"\\.bias$": 0.0
|
| 196 |
+
"logit_scale": 0.0
|
| 197 |
+
"positional_embedding": 0.0
|
| 198 |
+
"token_embedding": 0.0
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
LR_SCHEDULER: TimmScheduler
|
| 203 |
+
LR_SCHEDULER_PARAMS:
|
| 204 |
+
sched: cosine
|
| 205 |
+
warmup_steps: 5
|
| 206 |
+
warmup_lr: 0.000000001
|
| 207 |
+
min_lr: 0.000000001
|
| 208 |
+
|
| 209 |
+
# GRADIENT_ACCUMULATE_STEP will be updated by:
|
| 210 |
+
# BATCH_SIZE_TOTAL // (BATCH_SIZE_PER_GPU * world_size)
|
| 211 |
+
GRADIENT_ACCUMULATE_STEP: -1
|
2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"single_word": false,
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"special": false
|
| 9 |
+
},
|
| 10 |
+
"eos_token": {
|
| 11 |
+
"content": "<|endoftext|>",
|
| 12 |
+
"single_word": false,
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"rstrip": false,
|
| 15 |
+
"normalized": true,
|
| 16 |
+
"special": false
|
| 17 |
+
},
|
| 18 |
+
"unk_token": {
|
| 19 |
+
"content": "<|endoftext|>",
|
| 20 |
+
"single_word": false,
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"rstrip": false,
|
| 23 |
+
"normalized": true,
|
| 24 |
+
"special": false
|
| 25 |
+
},
|
| 26 |
+
"pad_token": "<|endoftext|>"
|
| 27 |
+
}
|
2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"errors": "replace",
|
| 3 |
+
"unk_token": {
|
| 4 |
+
"content": "<|endoftext|>",
|
| 5 |
+
"single_word": false,
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"special": false,
|
| 10 |
+
"__type": "AddedToken"
|
| 11 |
+
},
|
| 12 |
+
"bos_token": {
|
| 13 |
+
"content": "<|startoftext|>",
|
| 14 |
+
"single_word": false,
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"special": false,
|
| 19 |
+
"__type": "AddedToken"
|
| 20 |
+
},
|
| 21 |
+
"eos_token": {
|
| 22 |
+
"content": "<|endoftext|>",
|
| 23 |
+
"single_word": false,
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"normalized": true,
|
| 27 |
+
"special": false,
|
| 28 |
+
"__type": "AddedToken"
|
| 29 |
+
},
|
| 30 |
+
"pad_token": "<|endoftext|>",
|
| 31 |
+
"add_prefix_space": false,
|
| 32 |
+
"do_lower_case": true,
|
| 33 |
+
"name_or_path": "openai/clip-vit-base-patch32",
|
| 34 |
+
"model_max_length": 77,
|
| 35 |
+
"special_tokens_map_file": "/home/ncodella/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
|
| 36 |
+
"tokenizer_file": "/home/ncodella/.cache/huggingface/transformers/7811def0c53be25ba790cb67ac785669b508a8d1cf8c912b8ac046c5f08aee68.20428ea8b6821af2719b760af844a371643ff49f255c73285f6ea448e15597fe",
|
| 37 |
+
"tokenizer_class": "CLIPTokenizer"
|
| 38 |
+
}
|
2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
2024.09.27/vision_model/medimageinsigt-v1.0.0.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5eeda63bf616a61664bc95b2c09d3b3d7125209e635678bd3f5f324e9bdb1414
|
| 3 |
+
size 2464060700
|
Images/Breast-Cancer-1.jpg
ADDED
|
Images/Breast-Cancer-2.jpeg
ADDED
|
Images/Peneumonia.png
ADDED
|
MedImageInsight/Distributed/Utils.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import requests
|
| 5 |
+
import tenacity
|
| 6 |
+
import time
|
| 7 |
+
import shutil
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torchvision.utils import make_grid
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from fvcore.nn import FlopCountAnalysis
|
| 17 |
+
from fvcore.nn import flop_count_table
|
| 18 |
+
from fvcore.nn import flop_count_str
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
# Normalization layer types that downstream training code treats specially
# (e.g. excluded from weight decay).
NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Class decorator: record *cls* as a normalization module and return it."""
    NORM_MODULES.append(cls)
    return cls
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def is_main_process():
    """Return True when this process is rank 0 (or not launched via OpenMPI).

    The rank is taken from ``OMPI_COMM_WORLD_RANK``, but only when
    ``OMPI_COMM_WORLD_SIZE`` is present; otherwise the rank defaults to 0.
    """
    if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
        return True
    return int(os.environ['OMPI_COMM_WORLD_RANK']) == 0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Log FLOP and parameter statistics for *model*.

    The model is put in eval mode while fvcore counts FLOPs for one forward
    pass on ``dump_input``, then switched back to train mode.

    Args:
        model: the torch module to analyse.
        dump_input: example input forwarded through the model for counting.
        verbose: when True, additionally log the full per-op FLOP breakdown.

    Returns:
        Tuple of (total FLOPs, flop count table string, flop count string).
    """
    model.eval()
    flop_stats = FlopCountAnalysis(model, dump_input)
    total_flops = flop_stats.total()
    model.train()

    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"flop count table:\n {flop_count_table(flop_stats)}")
    if verbose:
        logger.info(f"flop count str:\n {flop_count_str(flop_stats)}")
    logger.info(f" Total flops: {total_flops / 1000 / 1000:.3f}M,")
    logger.info(f" Total params: {n_params / 1000 / 1000:.3f}M,")
    logger.info(f" Learned params: {n_trainable / 1000 / 1000:.3f}M")

    return total_flops, flop_count_table(flop_stats), flop_count_str(flop_stats)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def gather_tensors(tensor):
    """All-gather *tensor* across every rank and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    local rank's slot is overwritten with the original tensor afterwards to
    restore gradient propagation through this rank's contribution.

    World size and rank are read from the ``WORLD_SIZE`` and ``RANK``
    environment variables; requires an initialized process group.
    """
    world_size = int(os.environ['WORLD_SIZE'])
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, async_op=False)
    # need to do this to restore propagation of the gradients
    gathered[int(os.environ['RANK'])] = tensor
    return torch.cat(gathered, dim=0)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def is_valid_url(url):
    """Return True when ``str(url)`` parses with a non-empty URL scheme."""
    try:
        from urllib import parse
        scheme = parse.urlparse(str(url)).scheme
    except Exception:
        return False
    return scheme != ''
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@tenacity.retry(stop=tenacity.stop_after_attempt(3))
def download_file(url, filepath):
    """Stream *url* to *filepath* (a ``pathlib.Path``), retrying up to 3 times.

    Raises:
        RuntimeError: when the HTTP status code is greater than 200.
    """
    logger.info(f'Downloading from {url} to {filepath.absolute()}.')
    with requests.get(url, stream=True, allow_redirects=True, timeout=60) as response:
        if response.status_code > 200:
            raise RuntimeError(f'Failed in downloading from {url}, status code {response.status_code}.')

        # NOTE(review): response.raw bypasses requests' content decoding; if the
        # server applies Content-Encoding (e.g. gzip) the saved bytes stay
        # encoded — confirm this is intended for these artifacts.
        with open(filepath, 'wb') as out:
            shutil.copyfileobj(response.raw, out, length=4194304)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class DistributionGridFactory:
    """Factory that creates, caches and shares DistributionGrid instances.

    A DistributionGrid can be shared across modules only when both of these
    match a previously created grid:
    1. expert parallel group size
    2. expert parallel replica group size
    """
    distribution_grid_cache = {}

    @classmethod
    def get_distribution_grid(cls,
                              expert_parallel_group_size,
                              expert_parallel_replica_group_size,
                              ddp_type):
        """Return the cached DistributionGrid for the given sizes, creating it on first use.

        Args:
            expert_parallel_group_size: expert parallel group size.
            expert_parallel_replica_group_size: expert parallel replica group size.
            ddp_type: distributed data parallel type ("DDP" of the recipe);
                must be "MAINZ", "OSS" or "ShardedDDP" (case-insensitive).

        Returns:
            A newly created or shared DistributionGrid.

        Notes:
            Only "MAINZ", "OSS" and "ShardedDDP" are supported here.
        """
        # TODO: Support cases that "DDP" is "FSDP".
        # For "FSDP", the DG of self.opt['fsdp_expert_grid'] (initialized in
        # DistributedTrainer) is used directly.
        ddp_type = ddp_type.upper()
        assert ddp_type in ["MAINZ", "OSS", "SHARDEDDDP"], f'DistributionGrid Factory only support "DDP" is "MAINZ",' \
                                                           f' "OSS" or "ShardedDDP".' \
                                                           f' But currently "DDP" is {ddp_type}'

        cache_key = (expert_parallel_group_size, expert_parallel_replica_group_size)
        grid = cls.distribution_grid_cache.get(cache_key, None)
        if grid is None:
            from ort_moe.grids import DistributionGrid
            grid = DistributionGrid(
                expert_parallel_group_size=expert_parallel_group_size,
                expert_parallel_replica_group_size=expert_parallel_replica_group_size,
            )
            cls.distribution_grid_cache[cache_key] = grid
        return grid
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if world_size == 1:
        return

    def _send_and_wait(r):
        # Rank r broadcasts 0; every other rank starts at 1 and spins until
        # the broadcast has overwritten its flag (the broadcast itself acts
        # as the rendezvous point).
        flag = 0 if rank == r else 1
        tensor = torch.tensor(flag, device="cuda")
        dist.broadcast(tensor, r)
        while tensor.item() == 1:
            time.sleep(1)

    _send_and_wait(0)
    # now sync on the main process
    _send_and_wait(1)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the payload into a CUDA byte tensor
    payload = pickle.dumps(data)
    byte_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(payload)).to("cuda")

    # exchange payload sizes first, since each rank may hold a different amount
    local_size = torch.LongTensor([byte_tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # torch all_gather requires equal shapes, so pad every tensor to max_size
    tensor_list = [torch.ByteTensor(size=(max_size,)).to("cuda") for _ in size_list]
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        byte_tensor = torch.cat((byte_tensor, padding), dim=0)
    dist.all_gather(tensor_list, byte_tensor)

    # strip the padding and deserialize each rank's payload
    results = []
    for size, received in zip(size_list, tensor_list):
        results.append(pickle.loads(received.cpu().numpy().tobytes()[:size]))
    return results
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def all_gather_cpu(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """

    def _get_global_gloo_group():
        """
        Return a process group based on gloo backend, containing all the ranks
        The result is cached.
        """
        # Gather via a CPU (gloo) group by default, to reduce GPU RAM usage.
        if dist.get_backend() == "nccl":
            return dist.new_group(backend="gloo")
        return dist.group.WORLD

    if get_world_size() == 1:
        return [data]
    group = _get_global_gloo_group()
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    gathered = [None] * world_size
    dist.all_gather_object(gathered, data, group=group)
    return gathered
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    if not input_dict:
        # torch.stack([]) raises on an empty list; nothing to reduce anyway.
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def broadcast_data(data):
    """Broadcast a list of numbers from rank 0 to all ranks.

    A trailing flag (0 on rank 0, 1 elsewhere) is appended before the
    broadcast; non-zero ranks spin until the broadcast has overwritten it.
    Returns *data* unchanged when torch.distributed is not initialized.
    """
    if not torch.distributed.is_initialized():
        return data
    flag = 0 if dist.get_rank() == 0 else 1
    data_tensor = torch.tensor(data + [flag], device="cuda")
    torch.distributed.broadcast(data_tensor, 0)
    while data_tensor.cpu().numpy()[-1] == 1:
        time.sleep(1)

    return data_tensor.cpu().numpy().tolist()[:-1]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def reduce_sum(tensor):
    """Sum *tensor* element-wise across all ranks (no-op when world size is 1).

    The input is cloned before the all-reduce, so the caller's tensor is
    never mutated.
    """
    if get_world_size() <= 1:
        return tensor

    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def save_result(result, filename):
    """Save *result* next to *filename*, as a PNG for image-like tensors.

    A 3- or 4-dim ``torch.Tensor`` is tiled into an image grid and written as
    ``<basename>.png``; anything else is written with ``torch.save`` as
    ``<basename>.pth``. The output directory is created if needed.

    Args:
        result: tensor (image-like) or any torch-serializable object.
        filename: target path; only its directory and stem are used, the
            extension is replaced by ``.png`` / ``.pth``.
    """
    output_folder = os.path.dirname(filename)
    basename = os.path.splitext(os.path.basename(filename))[0]
    os.makedirs(output_folder, exist_ok=True)

    if isinstance(result, torch.Tensor) and result.ndim in [3, 4]:
        if result.ndim == 3 and result.size(0) not in [1, 3]:
            # dim 0 is not a channel dim: treat it as a batch of 1-channel images
            grid = make_grid(result.unsqueeze(1))
        elif result.ndim == 4:
            grid = make_grid(result)
        else:
            grid = make_grid([result])

        # detach/cpu first so CUDA and autograd tensors convert cleanly;
        # out-of-place clamp avoids mutating the grid (the original used
        # clamp_ and .numpy() directly, which fails for CUDA tensors).
        array = grid.detach().cpu().clamp(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
        im = Image.fromarray(array)
        im.save(os.path.join(output_folder, '{}.png'.format(basename)))
    else:
        torch.save(result, os.path.join(output_folder, '{}.pth'.format(basename)))
|
MedImageInsight/Distributed/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .Utils import analysis_model
|
| 2 |
+
from .Utils import is_main_process
|
| 3 |
+
from .Utils import gather_tensors
|
| 4 |
+
from .Utils import register_norm_module
|
| 5 |
+
from .Utils import NORM_MODULES
|
| 6 |
+
from .Utils import DistributionGridFactory
|
MedImageInsight/Distributed/__pycache__/Utils.cpython-38.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
MedImageInsight/Distributed/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (437 Bytes). View file
|
|
|
MedImageInsight/ImageDataLoader/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .build import build_dataloader
|
| 2 |
+
#from .build import build_multitask_dataloader
|
| 3 |
+
from .transforms import build_transforms
|
| 4 |
+
#from .imagenet.real_labels import RealLabelsImagenet
|
| 5 |
+
from .constants import IMAGENET_CLASSES
|
| 6 |
+
from .constants import IMAGENET_DEFAULT_TEMPLATES
|
| 7 |
+
from .zipdata import ZipData
|
| 8 |
+
#from .vision_dataset import VDImageTextDataset, MultiClassTorchDatasetWrapper
|
MedImageInsight/ImageDataLoader/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (433 Bytes). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/blob_storage.cpython-38.pyc
ADDED
|
Binary file (7.95 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/build.cpython-38.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/constants.cpython-38.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/tsv.cpython-38.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/tsv_file.cpython-38.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/__pycache__/zipdata.cpython-38.pyc
ADDED
|
Binary file (3.45 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/blob_storage.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import shutil
|
| 4 |
+
import logging
|
| 5 |
+
import subprocess
|
| 6 |
+
import os.path as op
|
| 7 |
+
from typing import List
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
|
| 10 |
+
import torch.distributed as distributed
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def disk_usage(path: str) -> float:
    """Return the fraction of the filesystem holding *path* that is in use (0.0-1.0)."""
    usage = shutil.disk_usage(path)
    return usage.used / usage.total
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def is_download_successful(stdout: str) -> bool:
    """Return True if an azcopy run reported zero failed transfers.

    Args:
        stdout: Captured standard output of an ``azcopy copy`` invocation.

    Returns:
        True when the exact summary line "Number of Transfers Failed: 0"
        appears in the output; otherwise the full azcopy output is logged
        and False is returned so the caller can retry.
    """
    if any(line == "Number of Transfers Failed: 0" for line in stdout.split('\n')):
        return True
    # Lazy %-style args: the message is only formatted if this level is emitted.
    logger.info("Azcopy message:\n %s", stdout)
    return False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def ensure_directory(path):
    """Guarantee that directory *path* exists, creating parents as needed.

    Empty, current-directory, and falsy paths are silently ignored.

    Args:
        path (str): path of the directory to ensure.
    """
    if path in ('', '.'):
        return
    if not path:
        return
    assert not op.isfile(path), '{} is a file'.format(path)
    if op.exists(path) or op.islink(path):
        return
    os.makedirs(path, exist_ok=True)
    # Verify the creation actually succeeded.
    assert op.isdir(op.abspath(path)), path
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LRU(OrderedDict):
    """Size-bounded LRU cache of closeable objects (e.g. open file handles).

    When an entry is replaced or evicted, its ``close()`` method is called
    (unless the stored value is None). Lookups mark entries as recently used.
    """

    def __init__(self, maxsize=3):
        # Number of entries kept before the least recently used one is purged.
        self.maxsize = maxsize

    def __getitem__(self, key):
        # A successful read promotes the entry to most-recently-used.
        item = super().__getitem__(key)
        self.move_to_end(key)
        return item

    def __setitem__(self, key, value):
        if key in self:
            # Replacing an entry: release the old handle before overwriting.
            previous = self[key]
            if previous is not None:
                previous.close()
            self.move_to_end(key)

        logger.debug('=> Cache {}'.format(key))
        super().__setitem__(key, value)

        if len(self) > self.maxsize:
            evicted = next(iter(self))
            handle = self[evicted]
            if handle is not None:
                handle.close()
            logger.debug('=> Purged {}'.format(evicted))
            del self[evicted]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class BlobStorage(OrderedDict):
    """Pseudo Azure Blob Storage manager.

    The registered blobs are maintained in a LRU cache (this class is itself
    an OrderedDict bounded by ``maxsize``): the least recently looked-up key
    is evicted when full.
    https://docs.python.org/3/library/collections.html#collections.OrderedDict

    Args:
        is_train (bool): training vs. validation mode; validation uses a
            larger ``maxsize`` so val data never gets purged.
        sas_token_path (str): path to a file containing the full SAS URI.
            If falsy, all blob-related state stays None and paths are used
            as plain local paths.
        azcopy_path (str): path to the azcopy binary; defaults to
            ``DEFAULT_AZCOPY_PATH``.
    """
    def __init__(self,
                 is_train: bool,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 *args, **kwds):
        super().__init__(*args, **kwds)
        self.maxsize = 2 if is_train else 10  # Set maxsize to large number such val data never get purged.
        self.is_train = is_train

        if sas_token_path:
            self.sas_token = BlobStorage.read_sas_token(sas_token_path)
            # SAS URI layout: <base_url>?<query_string>; container is the last URL segment.
            self.base_url = self.sas_token[:self.sas_token.index("?")]
            self.query_string = self.sas_token[self.sas_token.index("?"):]
            self.container = BlobStorage.extract_container(self.sas_token)
        else:
            self.sas_token = None
            self.base_url = None
            self.query_string = None
            self.container = None

        # FIX: the three f-strings previously concatenated with no separator,
        # producing one unreadable run-together log line.
        logger.debug(
            f"=> [BlobStorage] Base url: {self.base_url}\n"
            f"=> [BlobStorage] Query string: {self.query_string}\n"
            f"=> [BlobStorage] Container name: {self.container}"
        )

        self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
        # Open file handles for already-downloaded blobs; LRU-bounded so
        # handles get closed as new ones are cached.
        self._cached_files = LRU(3)

    def __getitem__(self, key):
        # A lookup promotes the key to most-recently-used.
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)
        # NOTE: purge the least recently used data if the disk usage is high.
        # ITP restarts GPU clusters when disk usage reaches 80%.
        if len(self) > self.maxsize:
            oldest = next(iter(self))
            del self[oldest]

    @staticmethod
    def read_sas_token(path: str) -> str:
        """Read the SAS token URI from the first line of *path* (stripped)."""
        with open(path, 'r') as f:
            token = f.readline().strip()
        return token

    @staticmethod
    def extract_container(token: str) -> str:
        """Return the container name embedded in a SAS URI.

        Args:
            token (str): the full URI of Shared Access Signature (SAS) in the following format.
                https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
        """
        return os.path.basename(token.split('?')[0])

    def _convert_to_blob_url(self, local_path: str):
        # Local layout mirrors the remote one under an 'azcopy' directory:
        # everything after 'azcopy' is the blob-relative path.
        return self.base_url + local_path.split("azcopy")[1] + self.query_string

    def _convert_to_blob_folder_url(self, local_path: str):
        # '/*' makes azcopy match every blob directly under the folder.
        return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string

    def fetch_blob(self, local_path: str) -> None:
        """Download a single blob to *local_path* via azcopy, retrying until success.

        If the file already exists locally it is just opened and cached.
        The download goes to ``local_path + rank`` first so concurrent ranks
        on one machine do not clobber each other, then is renamed into place.
        """
        if op.exists(local_path):
            logger.info('=> Try to open {}'.format(local_path))
            fp = open(local_path, 'r')
            self._cached_files[local_path] = fp
            logger.debug("=> %s downloaded. Skip." % local_path)
            return
        blob_url = self._convert_to_blob_url(local_path)
        rank = '0' if 'RANK' not in os.environ else os.environ['RANK']
        cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
            % (local_path, curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        # Retry until azcopy reports zero failed transfers.
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        if not op.exists(local_path):
            os.rename(local_path + rank, local_path)
        else:
            # Another rank won the race; drop our duplicate copy.
            os.remove(local_path + rank)
        logger.info(
            "=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, curr_usage * 100, disk_usage('/') * 100)
        )

    def fetch_blob_folder(self, local_path: str, azcopy_args: list = None) -> None:
        """Download a whole blob folder to *local_path*, retrying until success.

        Args:
            local_path (str): local destination folder (mirrors the remote path).
            azcopy_args (list): extra azcopy CLI arguments (e.g.
                ``['--include-pattern', '*.lineidx']``). Defaults to no extras.
        """
        # FIX: was a mutable default argument (azcopy_args=[]); use None sentinel.
        if azcopy_args is None:
            azcopy_args = []
        blob_url = self._convert_to_blob_folder_url(local_path)
        cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
            % (local_path, ' '.join(azcopy_args), curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        logger.info(
            "=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
        )

    def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
        """Rewrite TSV paths for blob use and prefetch their index files.

        With a SAS token, each path's container segment is replaced by
        'azcopy' (the local mirror root). Companion '.lineidx' and
        '.linelist' files are fetched: eagerly per-file for validation, or
        batched per-directory (via fetch_blob_folder) for training.
        Without a SAS token the paths are returned unchanged.
        """
        if self.sas_token:
            tsv_paths_new = []
            lineidx_paths = set()
            linelist_paths = set()
            for path in local_paths:
                tsv_path_az = path.replace(self.container, 'azcopy')
                tsv_paths_new.append(tsv_path_az)
                logger.debug("=> Registering {}".format(tsv_path_az))

                if not self.is_train:
                    logger.info('=> Downloading {}...'.format(tsv_path_az))
                    self.fetch_blob(tsv_path_az)
                    logger.info('=> Downloaded {}'.format(tsv_path_az))

                lineidx = op.splitext(path)[0] + '.lineidx'
                lineidx_ = lineidx.replace(self.container, 'azcopy')
                if self.is_train:
                    # Collect directories; fetched in bulk below.
                    if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
                        lineidx_paths.add(op.dirname(lineidx_))
                else:
                    if not op.isfile(lineidx_):
                        ensure_directory(op.dirname(lineidx_))
                        self.fetch_blob(lineidx_)

                linelist = op.splitext(path)[0] + '.linelist'
                linelist_ = linelist.replace(self.container, 'azcopy')
                # .linelist does not always exist. Check existence before fetch
                if self.is_train:
                    if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
                        linelist_paths.add(op.dirname(linelist_))
                else:
                    if op.isfile(linelist) and not op.isfile(linelist_):
                        ensure_directory(op.dirname(linelist_))
                        self.fetch_blob(linelist_)

            if self.is_train:
                for path in lineidx_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])

                for path in linelist_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])

            return tsv_paths_new
        else:
            return local_paths

    def open(self, local_path: str):
        """Open *local_path* for reading, waiting for an in-flight blob download.

        When a SAS token is configured and the path lives under the local
        'azcopy' mirror, poll once a second until another process finishes
        downloading the file.
        """
        if self.sas_token and 'azcopy' in local_path:
            while not op.exists(local_path):
                time.sleep(1)
        fid = open(local_path, 'r')
        return fid
|
MedImageInsight/ImageDataLoader/build.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import pathlib
|
| 9 |
+
from os.path import basename
|
| 10 |
+
|
| 11 |
+
from timm.data import create_loader
|
| 12 |
+
import torch
|
| 13 |
+
import torch.utils.data
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
import torchvision.datasets as datasets
|
| 16 |
+
from torchvision.io import read_image
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from yacs.config import CfgNode as CN
|
| 20 |
+
|
| 21 |
+
from ..LangEncoder import build_tokenizer
|
| 22 |
+
|
| 23 |
+
from .tsv import TSVImageTextDatasetV2
|
| 24 |
+
from .tsv import TSVMeta
|
| 25 |
+
from .transforms import build_transforms
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_dataset(cfg, is_train):
    """Construct the dataset selected by cfg['DATASET']['DATASET'].

    Only the 'image_text_pairs_v2' dataset is supported; any other name
    raises ValueError.
    """
    dataset_kind = cfg['DATASET']['DATASET']
    if dataset_kind != 'image_text_pairs_v2':
        raise ValueError(f'Unknown dataset: {cfg["DATASET"]["DATASET"]}')
    return _build_pairs_dataset_v2(cfg, is_train)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_tsv_list(cfg, is_train):
    """Resolve the configured TSV file list for the train or test split.

    Entries ending in '.list' are manifest files whose stripped lines are
    expanded in place; other entries are kept verbatim. Returns an empty
    list when neither TRAIN_TSV_LIST nor TEST_TSV_LIST is configured.
    """
    dataset_cfg = cfg['DATASET']
    if is_train and 'TRAIN_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TRAIN_TSV_LIST']
    elif 'TEST_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TEST_TSV_LIST']
    else:
        raw_entries = []

    resolved = []
    for entry in raw_entries:
        if not entry.endswith('.list'):
            resolved.append(entry)
            continue
        with open(entry, 'r') as manifest:
            resolved.extend(line.strip() for line in manifest)

    logger.info(f'tsv list: {resolved}')

    return resolved
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _get_token_file(cfg):
    """Pick this node's SAS token file from the config and validate it.

    When TOKEN_FILE is a list, nodes round-robin across the entries (node
    index derived from the global rank and local GPU count); a single-node
    run falls back to the first entry with a warning. The chosen name is
    joined onto DATASET.ROOT. Returns None when the loader is 'blobfuse'
    or the file does not exist on disk.
    """
    gpus_per_node = torch.cuda.device_count()
    num_nodes = dist.get_world_size() // gpus_per_node
    token_cfg = cfg['DATASET']['TOKEN_FILE']
    if isinstance(token_cfg, list):
        if num_nodes == 1:
            logger.warning('=> Multi token files are provided, but only one node is used for training')
            sas_token_file = token_cfg[0]
        else:
            node_idx = dist.get_rank() // gpus_per_node
            sas_token_file = token_cfg[node_idx % len(token_cfg)]
    else:
        sas_token_file = token_cfg

    sas_token_file = os.path.join(cfg['DATASET']['ROOT'], sas_token_file)

    if cfg['DATASET']['LOADER'] == 'blobfuse' or not os.path.isfile(sas_token_file):
        sas_token_file = None

    return sas_token_file
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _build_pairs_dataset_v2(cfg, is_train):
    """Build a TSV-backed image/text pairs dataset ('image_text_pairs_v2').

    Locates paired image-*/text-* TSV shards under DATASET.ROOT, builds the
    split's transforms and tokenizer, and wraps everything in
    TSVImageTextDatasetV2. Only DATA_FORMAT == 'tsv' is supported.

    Args:
        cfg (dict): full experiment config (DATASET, LANG_ENCODER, AUG keys).
        is_train (bool): selects TRAIN_SET vs TEST_SET and train-time defaults.

    Returns:
        TSVImageTextDatasetV2: the constructed dataset.

    Raises:
        ValueError: if cfg['DATASET']['DATA_FORMAT'] is not 'tsv'.
    """
    transforms = build_transforms(cfg, is_train)
    logger.info('transforms: {}'.format(transforms))

    dataset_name = cfg['DATASET']['TRAIN_SET'] \
        if is_train else cfg['DATASET']['TEST_SET']

    tokenobj = build_tokenizer(cfg['LANG_ENCODER'])

    if cfg['DATASET']['DATA_FORMAT'] != 'tsv':
        raise ValueError('Only support tsv format for pairs dataset v2')

    tsv_list = _get_tsv_list(cfg, is_train)

    if len(tsv_list) > 0:
        # Explicit list configured: resolve each entry under the dataset root.
        tsv_filenames = sorted(
            [
                os.path.join(cfg['DATASET']['ROOT'], dataset_name, f)
                for f in tsv_list
            ]
        )
    else:
        # No explicit list: discover every .tsv recursively under the dataset dir.
        dataset_path = os.path.join(cfg['DATASET']['ROOT'], dataset_name)
        tsv_files = Path(dataset_path).glob('**/*.tsv')

        tsv_filenames = sorted(
            [
                str(path)
                for path in tsv_files
            ]
        )

    # Split shards into image vs. text files by filename convention.
    image_tsv_files = [
        filename
        for filename in tsv_filenames
        if (
            'image-' in basename(filename)
            or 'image_' in basename(filename)
            or '_image' in basename(filename)
            or '-image' in basename(filename)
            or 'images-' in basename(filename)
        )
    ]
    text_tsv_files = [
        filename
        for filename in tsv_filenames
        if (
            'text-' in basename(filename)
            or 'text_' in basename(filename)
            or '_text' in basename(filename)
            or '-text' in basename(filename)
            or 'texts-' in basename(filename)
        )
    ]

    logger.info(
        "=> found %d/%d tsv file(s) to load.",
        len(image_tsv_files), len(text_tsv_files)
    )

    # Training always uses a single caption per image; evaluation may use more.
    num_captions = 1 \
        if is_train else cfg['DATASET'].get('NUM_CAPTIONS', 1)
    text_format = cfg['DATASET'].get('TEXT_FORMAT', 'json')

    sas_token_file = _get_token_file(cfg)
    logger.info("=> SAS token path: %s", sas_token_file)

    # Optional per-source classification metadata (parallel lists in config).
    metas = []
    cfg_data = cfg['DATASET']
    if 'CLASSIFICATION_SETS' in cfg_data and 'NUM_CLASSES' in cfg_data:
        for source, num_classes in zip(cfg_data['CLASSIFICATION_SETS'], cfg_data['NUM_CLASSES']):
            metas.append(
                TSVMeta(
                    source=source,
                    num_classes=num_classes,
                    task='classification'
                )
            )
            logger.info('=> add meta: {}'.format(metas[-1]))

    # COCO caption evaluation overrides: local data, 5 captions, json text.
    if 'coco-caption' in dataset_name:
        logger.info('=> coco caption data is used')
        logger.info('=> update num_captions: 5, text_format: json')
        logger.warning('=> set sas token to None for coco evaluation')
        sas_token_file = None
        num_captions = 5
        text_format = 'json'

    dataset = TSVImageTextDatasetV2(
        image_tsv_files, text_tsv_files,
        transform=transforms,
        tokenize=tokenobj,
        context_length=cfg['LANG_ENCODER']['CONTEXT_LENGTH'],
        num_captions=num_captions,
        text_format=text_format,
        is_train=is_train,
        sas_token_path=sas_token_file,
        metas=metas,
        prompt_engineering=cfg['DATASET'].get('PROMPT_ENGINEERING', True),
        concat_queries=cfg['DATASET'].get('CONCAT_QUERIES', False)
    )

    logger.info(
        "=> %s set size: %d", 'train'
        if is_train else 'val', len(dataset)
    )

    return dataset
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def build_dataloader(cfg, is_train=True, distributed=False):
    """Build a DataLoader for the configured dataset.

    For training with AUG.TIMM_AUG.USE_LOADER enabled, delegates to timm's
    ``create_loader`` (prefetching + timm augmentation pipeline). Otherwise
    builds a plain ``torch.utils.data.DataLoader``, with a
    DistributedSampler when ``distributed`` (or cfg['ALWAYS_ENABLE_SAMPLER'])
    is set.

    Args:
        cfg (dict): full experiment config (DATASET, AUG, TRAIN/TEST,
            IMAGE_ENCODER, WORKERS, PIN_MEMORY keys).
        is_train (bool): selects the train split, batch size, shuffling and
            drop_last behavior.
        distributed (bool): whether to shard the data with a
            DistributedSampler.

    Returns:
        An iterable data loader over the built dataset.
    """
    dataset = build_dataset(cfg, is_train)

    if (
        is_train
        and 'TIMM_AUG' in cfg['AUG']
        and cfg['AUG']['TIMM_AUG']['USE_LOADER']
    ):
        logger.info('=> use timm loader for training')
        # Wrap the TIMM_AUG dict in a yacs node for attribute-style access.
        timm_cfg = CN(init_dict=cfg['AUG']['TIMM_AUG'])
        data_loader = create_loader(
            dataset,
            input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
            batch_size=cfg['TRAIN']['BATCH_SIZE_PER_GPU'],
            is_training=True,
            use_prefetcher=True,
            no_aug=False,
            re_prob=timm_cfg.RE_PROB,
            re_mode=timm_cfg.RE_MODE,
            re_count=timm_cfg.RE_COUNT,
            re_split=timm_cfg.RE_SPLIT,
            scale=cfg['AUG']['SCALE'],
            ratio=cfg['AUG']['RATIO'],
            hflip=timm_cfg.HFLIP,
            vflip=timm_cfg.VFLIP,
            color_jitter=timm_cfg.COLOR_JITTER,
            auto_augment=timm_cfg.AUTO_AUGMENT,
            num_aug_splits=0,
            interpolation=cfg['AUG']['INTERPOLATION'],
            mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
            std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
            num_workers=cfg['WORKERS'],
            distributed=distributed,
            collate_fn=None,
            pin_memory=cfg['PIN_MEMORY'],
            use_multi_epochs_loader=True
        )
    else:
        if is_train:
            batch_size_per_gpu = cfg['TRAIN']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TRAIN'].get('SHUFFLE', True)
        else:
            batch_size_per_gpu = cfg['TEST']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TEST'].get('SHUFFLE', False)

        if distributed or cfg.get('ALWAYS_ENABLE_SAMPLER', False):
            # sampler = build_sampler(cfg, dataset, is_train, shuffle)
            # The sampler owns shuffling, so the loader itself must not shuffle.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
            shuffle = False
        else:
            sampler = None

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_per_gpu,
            shuffle=shuffle,
            num_workers=cfg['WORKERS'],
            pin_memory=cfg['PIN_MEMORY'],
            sampler=sampler,
            drop_last=True if is_train else False,
            prefetch_factor=cfg.get('PREFETCH_FACTOR', 2)
        )

    return data_loader
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
|
MedImageInsight/ImageDataLoader/constants.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white 
stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", 
"siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", 
"combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine 
cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping 
cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", 
"cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
| 2 |
+
|
| 3 |
+
# Prompt templates for zero-shot classification: every entry contains a
# '{}' placeholder that is substituted with a class name before being fed
# to the text encoder. Scores are typically averaged over all templates.
IMAGENET_DEFAULT_TEMPLATES = [
    '{}.',
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
|
MedImageInsight/ImageDataLoader/languages/__init__.py
ADDED
|
File without changes
|
MedImageInsight/ImageDataLoader/languages/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (208 Bytes). View file
|
|
|
MedImageInsight/ImageDataLoader/languages/__pycache__/prompt_engineering.cpython-38.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/languages/prompt_engineering.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_prompt_templates():
    """Return the list of CLIP-style prompt templates.

    Each template contains a '{}' placeholder to be substituted with a
    class name (see ``prompt_engineering`` below). The list is rebuilt on
    every call; callers that need it repeatedly may want to cache it.
    """
    prompt_templates = [
        '{}.',
        'a photo of a {}.',
        'a bad photo of a {}.',
        'a photo of many {}.',
        'a sculpture of a {}.',
        'a photo of the hard to see {}.',
        'a low resolution photo of the {}.',
        'a rendering of a {}.',
        'graffiti of a {}.',
        'a bad photo of the {}.',
        'a cropped photo of the {}.',
        'a tattoo of a {}.',
        'the embroidered {}.',
        'a photo of a hard to see {}.',
        'a bright photo of a {}.',
        'a photo of a clean {}.',
        'a photo of a dirty {}.',
        'a dark photo of the {}.',
        'a drawing of a {}.',
        'a photo of my {}.',
        'the plastic {}.',
        'a photo of the cool {}.',
        'a close-up photo of a {}.',
        'a black and white photo of the {}.',
        'a painting of the {}.',
        'a painting of a {}.',
        'a pixelated photo of the {}.',
        'a sculpture of the {}.',
        'a bright photo of the {}.',
        'a cropped photo of a {}.',
        'a plastic {}.',
        'a photo of the dirty {}.',
        'a jpeg corrupted photo of a {}.',
        'a blurry photo of the {}.',
        'a photo of the {}.',
        'a good photo of the {}.',
        'a rendering of the {}.',
        'a {} in a video game.',
        'a photo of one {}.',
        'a doodle of a {}.',
        'a close-up photo of the {}.',
        'the origami {}.',
        'the {} in a video game.',
        'a sketch of a {}.',
        'a doodle of the {}.',
        'a origami {}.',
        'a low resolution photo of a {}.',
        'the toy {}.',
        'a rendition of the {}.',
        'a photo of the clean {}.',
        'a photo of a large {}.',
        'a rendition of a {}.',
        'a photo of a nice {}.',
        'a photo of a weird {}.',
        'a blurry photo of a {}.',
        'a cartoon {}.',
        'art of a {}.',
        'a sketch of the {}.',
        'a embroidered {}.',
        'a pixelated photo of a {}.',
        'itap of the {}.',
        'a jpeg corrupted photo of the {}.',
        'a good photo of a {}.',
        'a plushie {}.',
        'a photo of the nice {}.',
        'a photo of the small {}.',
        'a photo of the weird {}.',
        'the cartoon {}.',
        'art of the {}.',
        'a drawing of the {}.',
        'a photo of the large {}.',
        'a black and white photo of a {}.',
        'the plushie {}.',
        'a dark photo of a {}.',
        'itap of a {}.',
        'graffiti of the {}.',
        'a toy {}.',
        'itap of my {}.',
        'a photo of a cool {}.',
        'a photo of a small {}.',
        'a tattoo of the {}.',
    ]
    return prompt_templates
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def prompt_engineering(classnames):
    """Build one randomly templated prompt for a class name.

    Args:
        classnames: either a single class-name string, or a list of
            synonym strings from which one is picked at random.

    Returns:
        The chosen template with '{}' replaced by the cleaned class name
        (commas stripped, '+' turned into spaces).
    """
    templates = get_prompt_templates()
    # Template index comes from numpy's global RNG, synonym choice from
    # the stdlib RNG — kept as-is to preserve seeding behaviour.
    picked = np.random.randint(len(templates))

    name = random.choice(classnames) if isinstance(classnames, list) else classnames

    # '+' acts as a word separator in some label sources; commas would
    # leak list punctuation into the prompt text.
    cleaned = name.replace(',', '').replace('+', ' ')
    return templates[picked].replace('{}', cleaned)
|
MedImageInsight/ImageDataLoader/transforms/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .build import build_transforms
|
MedImageInsight/ImageDataLoader/transforms/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (255 Bytes). View file
|
|
|
MedImageInsight/ImageDataLoader/transforms/__pycache__/autoaugment.cpython-38.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/transforms/__pycache__/build.cpython-38.pyc
ADDED
|
Binary file (6.24 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/transforms/__pycache__/threeaugment.cpython-38.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
MedImageInsight/ImageDataLoader/transforms/autoaugment.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import List, Tuple, Optional, Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from torchvision.transforms import functional as F
|
| 9 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 10 |
+
|
| 11 |
+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply the augmentation op named *op_name* to *img* with the given magnitude.

    Geometric ops route through ``F.affine``/``F.rotate``; photometric ops use
    the matching ``torchvision.transforms.functional`` adjustment, where the
    enhancement factor is ``1.0 + magnitude``. Raises ``ValueError`` for an
    unknown op name.
    """
    # --- geometric ops: magnitude drives exactly one affine parameter ---
    if op_name == "ShearX":
        return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0,
                        shear=[math.degrees(magnitude), 0.0],
                        interpolation=interpolation, fill=fill)
    if op_name == "ShearY":
        return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0,
                        shear=[0.0, math.degrees(magnitude)],
                        interpolation=interpolation, fill=fill)
    if op_name == "TranslateX":
        return F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
    if op_name == "TranslateY":
        return F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
    if op_name == "Rotate":
        return F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    # --- photometric ops: enhancement factor is 1.0 + magnitude ---
    if op_name == "Brightness":
        return F.adjust_brightness(img, 1.0 + magnitude)
    if op_name == "Color":
        return F.adjust_saturation(img, 1.0 + magnitude)
    if op_name == "Contrast":
        return F.adjust_contrast(img, 1.0 + magnitude)
    if op_name == "Sharpness":
        return F.adjust_sharpness(img, 1.0 + magnitude)
    # --- parameterless / thresholded ops ---
    if op_name == "Posterize":
        return F.posterize(img, int(magnitude))
    if op_name == "Solarize":
        return F.solarize(img, magnitude)
    if op_name == "AutoContrast":
        return F.autocontrast(img)
    if op_name == "Equalize":
        return F.equalize(img)
    if op_name == "Invert":
        return F.invert(img)
    if op_name == "Identity":
        return img
    raise ValueError(f"The provided operator {op_name} is not recognized.")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class AutoAugmentPolicy(Enum):
    """AutoAugment policies learned on different datasets.
    Available policies are IMAGENET, CIFAR10 and SVHN.
    """

    # Values are the lowercase dataset names, so the enum can be built
    # directly from a config string, e.g. AutoAugmentPolicy("imagenet").
    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
|
| 95 |
+
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the sub-policy table once at construction time.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        # Each entry is a pair of (op_name, probability, magnitude_index) sub-policies;
        # magnitude_index is None for ops that take no magnitude (e.g. Equalize, Invert).
        # Tables are the learned policies from the AutoAugment paper.
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Return, per op, the table of magnitudes (indexed by magnitude bin) and
        whether the op's magnitude may be negated ("signed")."""
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            # Posterize counts down from 8 bits; Solarize counts down from 255.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            # 0-d tensors: ops with no magnitude.
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Get parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation
        """
        policy_id = int(torch.randint(transform_num, (1,)).item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))

        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor backend requires fill as a per-channel float list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        # Pick one sub-policy pair and random probabilities/signs per call.
        transform_id, probs, signs = self.get_params(len(self.policies))

        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            if probs[i] <= p:
                # AutoAugment always uses 10 magnitude bins.
                op_meta = self._augmentation_space(10, F.get_image_size(img))
                magnitudes, signed = op_meta[op_name]
                magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                if signed and signs[i] == 0:
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        return self.__class__.__name__ + f"(policy={self.policy}, fill={self.fill})"
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Return, per op, the magnitude table (indexed by bin) and whether the
        magnitude may be negated ("signed"). Unlike AutoAugment this space
        includes "Identity" and excludes "Invert"."""
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor backend requires fill as a per-channel float list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        # Apply `num_ops` uniformly sampled ops in sequence, all at the fixed
        # `self.magnitude` bin, with a random sign when the op is signed.
        for _ in range(self.num_ops):
            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
            op_index = int(torch.randint(len(op_meta), (1,)).item())
            op_name = list(op_meta.keys())[op_index]
            magnitudes, signed = op_meta[op_name]
            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_ops={num_ops}"
        s += ", magnitude={magnitude}"
        s += ", num_magnitude_bins={num_magnitude_bins}"
        s += ", interpolation={interpolation}"
        s += ", fill={fill}"
        s += ")"
        return s.format(**self.__dict__)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Return the "wide" magnitude ranges; note these are independent of the
        image size (translations are fixed at up to 32 px)."""
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor backend requires fill as a per-channel float list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        # TrivialAugment applies exactly ONE uniformly sampled op, at a
        # uniformly sampled magnitude bin, with a random sign when signed.
        op_meta = self._augmentation_space(self.num_magnitude_bins)
        op_index = int(torch.randint(len(op_meta), (1,)).item())
        op_name = list(op_meta.keys())[op_index]
        magnitudes, signed = op_meta[op_name]
        magnitude = (
            float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
            if magnitudes.ndim > 0
            else 0.0
        )
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_magnitude_bins={num_magnitude_bins}"
        s += ", interpolation={interpolation}"
        s += ", fill={fill}"
        s += ")"
        return s.format(**self.__dict__)
|
MedImageInsight/ImageDataLoader/transforms/build.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
import timm
|
| 6 |
+
from timm.data import create_transform
|
| 7 |
+
|
| 8 |
+
from yacs.config import CfgNode as CN
|
| 9 |
+
from PIL import ImageFilter
|
| 10 |
+
import logging
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from .autoaugment import AutoAugmentPolicy
|
| 18 |
+
from .autoaugment import AutoAugment
|
| 19 |
+
from .autoaugment import RandAugment
|
| 20 |
+
from .autoaugment import TrivialAugmentWide
|
| 21 |
+
from .threeaugment import deitIII_Solarization
|
| 22 |
+
from .threeaugment import deitIII_gray_scale
|
| 23 |
+
from .threeaugment import deitIII_GaussianBlur
|
| 24 |
+
|
| 25 |
+
from PIL import ImageOps
|
| 26 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Blurs a PIL image with a Gaussian kernel whose radius is drawn uniformly
    from ``[sigma[0], sigma[1]]`` on every call.

    Args:
        sigma: two-element sequence giving the (min, max) blur radius.
    """

    # Fix: the original used a mutable list as the default argument; a tuple
    # default avoids the shared-mutable-default pitfall without changing values.
    def __init__(self, sigma=(.1, 2.)):
        self.sigma = sigma

    def __call__(self, x):
        """Return *x* blurred with a randomly sampled Gaussian radius."""
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_resolution(original_resolution):
    """Takes (H,W) and returns (precrop, crop)."""
    height, width = original_resolution[0], original_resolution[1]
    # Small images (< 96x96 pixels) get the small precrop/crop pair.
    if height * width < 96 * 96:
        return (160, 128)
    return (512, 480)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Map config interpolation-name strings (cfg['AUG']['INTERPOLATION']) to
# torchvision InterpolationMode enum members used by the transforms below.
INTERPOLATION_MODES = {
    'bilinear': T.InterpolationMode.BILINEAR,
    'bicubic': T.InterpolationMode.BICUBIC,
    'nearest': T.InterpolationMode.NEAREST,
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_transforms(cfg, is_train=True):
    """Build a torchvision transform pipeline from the experiment config.

    Args:
        cfg: nested config mapping; reads ``IMAGE_ENCODER`` (IMAGE_SIZE,
            IMAGE_MEAN, IMAGE_STD), ``AUG`` (branch selection and parameters)
            and, for evaluation, ``TEST['CENTER_CROP']``.
        is_train: build the training pipeline when True, otherwise the
            evaluation (resize / center-crop) pipeline.

    Returns:
        A ``torchvision.transforms.Compose`` (or timm-created transform).

    Training branch precedence: THREE_AUG > TIMM_AUG > TORCHVISION_AUG >
    RANDOM_CENTER_CROP > MAE_FINETUNE_AUG > MAE_PRETRAIN_AUG > ThreeAugment.
    """
    normalize = T.Normalize(
        mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
        std=cfg['IMAGE_ENCODER']['IMAGE_STD']
    )

    transforms = None
    if is_train:
        if 'THREE_AUG' in cfg['AUG']:
            img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE']
            remove_random_resized_crop = cfg['AUG']['THREE_AUG']['SRC']
            # NOTE(review): this branch hard-codes ImageNet mean/std rather
            # than using cfg IMAGE_MEAN/IMAGE_STD — presumably intentional
            # (matches DeiT III); confirm against training recipes.
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            scale = (0.08, 1.0)
            interpolation = 'bicubic'
            if remove_random_resized_crop:
                # "Simple Random Crop" (SRC) variant from DeiT III.
                primary_tfl = [
                    T.Resize(img_size, interpolation=3),
                    T.RandomCrop(img_size, padding=4, padding_mode='reflect'),
                    T.RandomHorizontalFlip()
                ]
            else:
                primary_tfl = [
                    RandomResizedCropAndInterpolation(
                        img_size, scale=scale, interpolation=interpolation),
                    T.RandomHorizontalFlip()
                ]
            # BUG FIX: the original referenced undefined names gray_scale,
            # Solarization and GaussianBlurDeiTv3 (NameError at runtime); the
            # transforms actually imported at the top of this module are the
            # deitIII_* variants.
            secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
                                             deitIII_Solarization(p=1.0),
                                             deitIII_GaussianBlur(p=1.0)])]
            color_jitter = cfg['AUG']['THREE_AUG']['COLOR_JITTER']
            if color_jitter is not None and not color_jitter == 0:
                secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
            final_tfl = [
                T.ToTensor(),
                T.Normalize(
                    mean=torch.tensor(mean),
                    std=torch.tensor(std))
            ]
            # Early return: this branch bypasses the shared logging below.
            return T.Compose(primary_tfl + secondary_tfl + final_tfl)
        elif 'TIMM_AUG' in cfg['AUG'] and cfg['AUG']['TIMM_AUG']['USE_TRANSFORM']:
            logger.info('=> use timm transform for training')
            timm_cfg = cfg['AUG']['TIMM_AUG']
            transforms = create_transform(
                input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
                is_training=True,
                use_prefetcher=False,
                no_aug=False,
                re_prob=timm_cfg.get('RE_PROB', 0.),
                re_mode=timm_cfg.get('RE_MODE', 'const'),
                re_count=timm_cfg.get('RE_COUNT', 1),
                re_num_splits=0 if not timm_cfg.get('RE_SPLITS', False) else timm_cfg['RE_SPLITS'],  # if false or 0, return 0
                scale=cfg['AUG'].get('SCALE', None),
                ratio=cfg['AUG'].get('RATIO', None),
                hflip=timm_cfg.get('HFLIP', 0.5),
                vflip=timm_cfg.get('VFLIP', 0.),
                color_jitter=timm_cfg.get('COLOR_JITTER', 0.4),
                auto_augment=timm_cfg.get('AUTO_AUGMENT', None),
                interpolation=cfg['AUG']['INTERPOLATION'],
                mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
                std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
            )
        elif 'TORCHVISION_AUG' in cfg['AUG']:
            # Typo fix: log message previously read "fro training".
            logger.info('=> use torchvision transform for training')
            crop_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
            interpolation = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
            trans = [
                T.RandomResizedCrop(
                    crop_size, scale=cfg['AUG']['SCALE'], ratio=cfg['AUG']['RATIO'],
                    interpolation=interpolation
                )
            ]
            hflip_prob = cfg['AUG']['TORCHVISION_AUG']['HFLIP']
            auto_augment_policy = cfg['AUG']['TORCHVISION_AUG'].get('AUTO_AUGMENT', None)
            if hflip_prob > 0:
                trans.append(T.RandomHorizontalFlip(hflip_prob))
            if auto_augment_policy is not None:
                if auto_augment_policy == "ra":
                    trans.append(RandAugment(interpolation=interpolation))
                elif auto_augment_policy == "ta_wide":
                    trans.append(TrivialAugmentWide(interpolation=interpolation))
                else:
                    # Anything else is treated as an AutoAugmentPolicy value
                    # ("imagenet", "cifar10", "svhn").
                    aa_policy = AutoAugmentPolicy(auto_augment_policy)
                    trans.append(AutoAugment(policy=aa_policy, interpolation=interpolation))
            trans.extend(
                [
                    T.ToTensor(),
                    normalize,
                ]
            )
            random_erase_prob = cfg['AUG']['TORCHVISION_AUG']['RE_PROB']
            random_erase_scale = cfg['AUG']['TORCHVISION_AUG'].get('RE_SCALE', 0.33)
            if random_erase_prob > 0:
                # NCFC (4/26/2023): Added scale parameter to random erasing for medical imaging
                trans.append(T.RandomErasing(p=random_erase_prob, scale=(0.02, random_erase_scale)))

            # NOTE: rotation is appended AFTER ToTensor/normalize, so it runs
            # on the tensor representation (supported with BILINEAR).
            # Cleanup: dropped the redundant local
            # `from torchvision.transforms import InterpolationMode`;
            # `T.InterpolationMode` is the same class.
            rotation = cfg['AUG']['TORCHVISION_AUG'].get('ROTATION', 0.0)
            if rotation > 0.0:
                trans.append(T.RandomRotation(rotation, interpolation=T.InterpolationMode.BILINEAR))
                logger.info(" TORCH AUG: Rotation: " + str(rotation))

            transforms = T.Compose(trans)
        elif cfg['AUG'].get('RANDOM_CENTER_CROP', False):
            logger.info('=> use random center crop data augmenation')
            # precrop, crop = get_resolution(cfg.TRAIN.IMAGE_SIZE)
            crop = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
            padding = cfg['AUG'].get('RANDOM_CENTER_CROP_PADDING', 32)
            precrop = crop + padding
            mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
            transforms = T.Compose([
                T.Resize(
                    (precrop, precrop),
                    interpolation=mode
                ),
                T.RandomCrop((crop, crop)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ])
        elif cfg['AUG'].get('MAE_FINETUNE_AUG', False):
            mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
            std = cfg['IMAGE_ENCODER']['IMAGE_STD']
            transforms = create_transform(
                input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
                is_training=True,
                color_jitter=cfg['AUG'].get('COLOR_JITTER', None),
                auto_augment=cfg['AUG'].get('AUTO_AUGMENT', 'rand-m9-mstd0.5-inc1'),
                interpolation='bicubic',
                re_prob=cfg['AUG'].get('RE_PROB', 0.25),
                re_mode=cfg['AUG'].get('RE_MODE', "pixel"),
                re_count=cfg['AUG'].get('RE_COUNT', 1),
                mean=mean,
                std=std,
            )
        elif cfg['AUG'].get('MAE_PRETRAIN_AUG', False):
            mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
            std = cfg['IMAGE_ENCODER']['IMAGE_STD']
            transforms = T.Compose([
                T.RandomResizedCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0], scale=tuple(cfg['AUG']['SCALE']), interpolation=INTERPOLATION_MODES["bicubic"]),  # 3 is bicubic
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(mean=mean, std=std)])
        elif cfg['AUG'].get('ThreeAugment', False):  # from DeiT III
            mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
            std = cfg['IMAGE_ENCODER']['IMAGE_STD']
            img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
            remove_random_resized_crop = cfg['AUG'].get('src', False)
            # Hard-coded ImageNet statistics (overrides cfg values above).
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            scale = (0.08, 1.0)
            interpolation = 'bicubic'
            if remove_random_resized_crop:
                primary_tfl = [
                    T.Resize(img_size, interpolation=3),  # bicubic
                    T.RandomCrop(img_size, padding=4, padding_mode='reflect'),
                    T.RandomHorizontalFlip()
                ]
            else:
                primary_tfl = [
                    timm.data.transforms.RandomResizedCropAndInterpolation(
                        img_size, scale=scale, interpolation=interpolation),
                    T.RandomHorizontalFlip()
                ]

            secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
                                             deitIII_Solarization(p=1.0),
                                             deitIII_GaussianBlur(p=1.0)])]
            color_jitter = cfg['AUG']['COLOR_JITTER']
            secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
            final_tfl = [
                T.ToTensor(),
                T.Normalize(
                    mean=torch.tensor(mean),
                    std=torch.tensor(std))
            ]
            transforms = T.Compose(primary_tfl + secondary_tfl + final_tfl)
        logger.info('=> training transformers: {}'.format(transforms))
    else:
        mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
        if cfg['TEST']['CENTER_CROP']:
            transforms = T.Compose([
                T.Resize(
                    int(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0] / 0.875),
                    # the same behavior as in deit: size = int((256 / 224) * args.input_size)
                    # 224 / 256 = 0.875
                    interpolation=mode
                ),
                T.CenterCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
                T.ToTensor(),
                normalize,
            ])
        else:
            transforms = T.Compose([
                T.Resize(
                    (cfg['IMAGE_ENCODER']['IMAGE_SIZE'][1], cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
                    interpolation=mode
                ),
                T.ToTensor(),
                normalize,
            ])
        logger.info('=> testing transformers: {}'.format(transforms))

    return transforms
|
| 261 |
+
|
MedImageInsight/ImageDataLoader/transforms/threeaugment.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from PIL import ImageFilter, ImageOps
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class deitIII_GaussianBlur(object):
    """Randomly apply Gaussian blur to a PIL image (DeiT III recipe).

    With probability ``p`` the image is blurred with a Gaussian kernel whose
    radius is drawn uniformly from ``[radius_min, radius_max]``; otherwise the
    image is returned untouched.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Skip the blur entirely when the coin flip fails.
        if random.random() > self.prob:
            return img
        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class deitIII_Solarization(object):
    """Randomly apply PIL solarization to an image with probability ``p``."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        # Solarize only when the coin flip succeeds; otherwise pass through.
        return ImageOps.solarize(img) if random.random() < self.p else img
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class deitIII_gray_scale(object):
    """Randomly convert a PIL image to 3-channel grayscale with probability ``p``.

    Doc fix: the original docstring said "Apply Solarization" — a copy/paste
    error; this transform actually applies ``transforms.Grayscale(3)``.
    """

    def __init__(self, p=0.2):
        self.p = p
        # Grayscale(3) keeps three output channels so downstream normalization
        # expecting RGB shapes still works.
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img
|
MedImageInsight/ImageDataLoader/tsv.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import base64
|
| 10 |
+
import random
|
| 11 |
+
from typing import Callable, List, Tuple, Union, NamedTuple
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from PIL import ImageFile
|
| 14 |
+
import torch.utils.data as data
|
| 15 |
+
from .languages.prompt_engineering import prompt_engineering
|
| 16 |
+
from .tsv_file import TSVFile, CompositeTSVFile
|
| 17 |
+
|
| 18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TSVDataset(data.Dataset):
    """Image-classification dataset stored as TSV rows of (key, label, base64 image).

    A single path ending in ``.tsv`` is wrapped in :class:`TSVFile`; any other
    string, or a list of paths, is wrapped in :class:`CompositeTSVFile`
    (which supports fetching chunks from blob storage).
    """

    def __init__(self,
                 tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 map_file: str = None,
                 token_file: str = None,
                 is_train: bool = True,
                 azcopy_path: str = None):
        self.transform = transform
        self._chunk_sizes = None
        # Optional "class-name<TAB>index" mapping file.
        self.label2idx = self._load_map(map_file)
        # With a label map present, only rows of those classes are selected.
        self.class_selector = list(self.label2idx.keys()) if self.label2idx else None

        if isinstance(tsv_file, str):
            if os.path.splitext(tsv_file)[1] == '.tsv':
                self.tsv_file = TSVFile(
                    tsv_file, class_selector=self.class_selector
                )
            else:
                # A non-.tsv string is treated as a composite file list.
                self.tsv_file = CompositeTSVFile(
                    tsv_file,
                    class_selector=self.class_selector,
                    is_train=is_train,
                    sas_token_path=token_file,
                    azcopy_path=azcopy_path
                )
                self._chunk_sizes = self.tsv_file.get_chunk_size()
        elif isinstance(tsv_file, list):
            self.tsv_file = CompositeTSVFile(
                tsv_file,
                class_selector=self.class_selector,
                is_train=is_train,
                sas_token_path=token_file,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames")

        logger.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th underlying TSV chunk from blob storage.
        # NOTE(review): assumes the backing store is a CompositeTSVFile
        # (plain TSVFile has no file_list/blob_storage) -- confirm callers.
        image_tsv = self.tsv_file.file_list[idx]
        self.tsv_file.blob_storage.fetch_blob(image_tsv)

    def num_classes(self):
        # NOTE(review): assumes a map_file was supplied; class_selector is
        # None otherwise and len() would fail -- confirm callers.
        return len(self.class_selector)

    def get_chunk_sizes(self):
        # None for a single plain TSVFile; per-chunk row counts otherwise.
        return self._chunk_sizes

    def get_class_boundaries(self):
        # The samples of each class are organized class-by-class.
        # _class_boundaries stores the lower- and upper-bound of each class.
        return self.tsv_file.get_class_boundaries()

    def get_filenames(self):
        """Return the row key of every sample in the dataset."""
        filenames = [
            self.tsv_file.get_key(i)
            for i in range(self.tsv_file.num_rows())
        ]

        return filenames

    def _load_map(self, map_file: str):
        """Load a tab-separated "label<TAB>index" file; None when no file given."""
        if not map_file:
            return None

        label2idx = {}
        with open(map_file) as f:
            for line in f:
                items = line.strip().split('\t')
                label2idx[items[0]] = int(items[1])

        return label2idx

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, target)`` for the given row, applying ``transform``."""
        items = self.tsv_file[index]
        _, target, img = self._decode_data(items)

        if self.transform:
            img = self.transform(img)

        return img, target

    def _decode_data(self, items: Tuple[str, str, str]):
        # Row layout: (key, label-or-json, base64-encoded image bytes).
        key = items[0]
        label = self._get_label(items[1])
        image = Image.open(BytesIO(base64.b64decode(items[2]))).convert('RGB')

        return key, label, image

    def _get_label(self, item: str):
        # Without a label map the column holds the integer class id itself;
        # with a map it is JSON whose first entry carries the 'class' name.
        if not self.label2idx:
            return int(item)

        js = json.loads(item)
        return self.label2idx[js[0]['class']]

    def __len__(self):
        return len(self.tsv_file)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class TSVMeta(NamedTuple):
    """Metadata for one supervised (classification) TSV data source.

    Consumed when building per-source label offsets so that class ids from
    different sources do not collide.
    """
    # Source identifier (matched against the 'source' field of data rows).
    source: str
    # Number of distinct class ids contributed by this source.
    num_classes: int
    # Task type tag, e.g. 'classification'.
    task: str
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TSVImageTextDatasetV2(data.Dataset):
    """
    This class is intended for encapsulating Image/Text pair data for contrastive learning described in
    the following paper,
    "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)
    V2: support image text pairs and supervised classification data

    Image rows and text rows live in parallel TSV files and must pair up
    row-for-row (keys are asserted equal at access time).
    """

    def __init__(self,
                 image_tsv_file: Union[str, List[str]],
                 text_tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 tokenize: Callable = None,
                 context_length: int = 77,
                 num_captions: int = 1,
                 text_format: str = 'txt',
                 is_train: bool = True,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 metas: List[NamedTuple] = None,
                 prompt_engineering=True,
                 concat_queries=False):
        self.transform = transform
        self.tokenize = tokenize
        self._chunk_sizes = None
        # Max token length passed to the tokenizer (CLIP default: 77).
        self.context_length = context_length
        self.num_captions = num_captions
        self.text_format = text_format
        self.tsv_file_list = []
        self.metas = metas
        # Per-source label-id offsets for mixed classification sources.
        self.label_offsets = self.build_label_offsets()
        self.prompt_engineering = prompt_engineering
        self.concat_queries = concat_queries

        if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
            # single tsv file
            if (
                os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
                and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
            ):
                self.tsv_file_list.append((image_tsv_file, text_tsv_file))
                self.image_tsv_file = TSVFile(
                    image_tsv_file, if_generate_lineidx=True
                )
                self.text_tsv_file = TSVFile(
                    text_tsv_file, if_generate_lineidx=True
                )
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        # multiple tsv files specified in a list
        elif (
            isinstance(image_tsv_file, list)
            and isinstance(text_tsv_file, list)
        ):
            assert len(image_tsv_file) == len(text_tsv_file), \
                "Inconsistent number of Image/Text tsv files!"
            self.tsv_file_list = [
                (txt, img)
                for img, txt in zip(image_tsv_file, text_tsv_file)
            ]
            self.image_tsv_file = CompositeTSVFile(
                image_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self.text_tsv_file = CompositeTSVFile(
                text_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.image_tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames.")

        assert len(self.image_tsv_file) == len(self.text_tsv_file), \
            "Inconsistent size of Image/Text ({}/{}) data!".format(
                len(self.image_tsv_file), len(self.text_tsv_file)
            )

    def build_label_offsets(self):
        """Assign each classification source a disjoint label-id range.

        Offsets start at 1 (label 0 is used for plain image/text pairs).
        Returns None when no source metadata was provided.
        """
        if self.metas is None:
            return None

        label_offsets = {}
        offset = 1
        for meta in self.metas:
            # NOTE(review): debug prints left in; consider logger.debug instead.
            print(meta)
            print(label_offsets)
            label_offsets[meta.source] = offset
            offset += meta.num_classes

        return label_offsets

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th image and text TSV chunks from blob storage.
        # image_tsv, text_tsv = self.tsv_file_list[idx]
        image_tsv = self.image_tsv_file.file_list[idx]
        text_tsv = self.text_tsv_file.file_list[idx]
        self.image_tsv_file.blob_storage.fetch_blob(image_tsv)
        self.text_tsv_file.blob_storage.fetch_blob(text_tsv)

    def get_chunk_sizes(self):
        return self._chunk_sizes

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, tokens, label)``; a None index yields empty tensors."""
        if index is None:
            import torch
            return torch.tensor([], dtype=torch.float32), \
                torch.tensor([], dtype=torch.int64), \
                torch.tensor([], dtype=torch.int64)

        items_image = self.image_tsv_file[index]
        items_text = self.text_tsv_file[index]

        # Image and text rows must refer to the same sample.
        assert items_text[0] == items_image[0], \
            'keys do not match for image and text {} vs {}'.format(
                items_text[0], items_image[0]
            )

        _, img = self._decode_image(items_image)
        _, txt, label = self._decode_text(items_text)

        if self.transform:
            img = self.transform(img)

        tokens = self.tokenize(
            txt, padding='max_length', truncation=True, max_length=self.context_length,
            return_tensors='pt'
        ) if self.tokenize else txt

        # Drop the leading batch dim added by return_tensors='pt'.
        # NOTE(review): assumes self.tokenize is not None; with tokenize=None
        # `tokens` is a plain string and these calls would fail -- confirm callers.
        tokens['input_ids'].squeeze_()
        tokens['attention_mask'].squeeze_()

        return img, tokens, label

    def _decode_image(self, items: Tuple[str, str]):
        # Row layout: (key, base64-encoded image bytes).
        key = items[0]
        image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')

        return key, image

    def _decode_text(self, items: Tuple[str, Union[str, dict]]):
        """Decode a text row into ``(key, caption text, class label)``.

        Supports three JSON payload shapes: 'captions' (image/text pairs),
        'tags' (prompt-engineered tag list), and 'task' == 'classification'
        (class name + offset-adjusted class id).
        """
        key = items[0]
        text = ''

        if self.text_format != 'json':
            raise ValueError('Only support json format')

        # Do some reasonable handing of occasionally bad data.
        try:
            js = json.loads(items[1])
        except Exception as e:

            # empty dictionary
            js = {}

            # Record the data error in the log.
            logger.info("JSON parsing error on: " + items[1])
            logger.info(str(e))

            # do not raise the exception
            # raise e

            # put some text in and continue processing data (do not kill job)
            # NOTE(review): estr is an index relative to items[1][sstr:] but is
            # used as an absolute slice end below, so the extracted span is
            # usually empty and falls through to the placeholder text -- confirm.
            sstr = items[1].find("\"")
            if (sstr < 0):
                sstr = 0

            estr = items[1][sstr:].find("\"")
            if (estr < 0):
                estr = len(items[1])

            text = items[1][sstr:estr]
            if (len(text) < 2):
                text = "A picture showing some content."

        label = 0

        if 'captions' in js:
            captions = js['captions']
            if isinstance(captions, list):
                # One caption: sample randomly; several: truncate to the cap.
                if self.num_captions == 1:
                    text = random.choice(captions)
                else:
                    text = captions
                    if len(captions) > self.num_captions:
                        text = captions[:self.num_captions]
            elif isinstance(captions, str):
                text = captions
            else:
                raise ValueError('captions should be str or list')
            label = 0
        elif 'tags' in js:
            text = prompt_engineering(js['tags'])
            label = 0
        elif 'task' in js and js['task'] == 'classification':
            if (self.prompt_engineering):
                text = prompt_engineering(js['class_name'])
            else:
                text = js['class_name']
            label = js['class_id']

            # Shift the class id into this source's disjoint label range.
            if (self.label_offsets is not None):
                if (js['source'] in self.label_offsets):
                    label += self.label_offsets[js['source']]

        # Optionally prepend any search queries associated with the sample.
        if (self.concat_queries):
            if ('queries' in js) and (len(js['queries']) > 0):
                q = ''
                for item in js['queries']:
                    q = q + item + ' '

                text = q + ', ' + text

        return key, text, label

    def __len__(self):
        return len(self.image_tsv_file)
|
MedImageInsight/ImageDataLoader/tsv_file.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
import os.path as op
|
| 5 |
+
import json
|
| 6 |
+
from typing import List
|
| 7 |
+
from .blob_storage import BlobStorage, disk_usage
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def generate_lineidx(filein: str, idxout: str) -> None:
    """Write the starting byte offset of every line of ``filein`` to ``idxout``.

    The index is written to a ``.tmp`` file first and renamed into place, so
    readers never observe a partially written index.
    """
    tmp_path = idxout + '.tmp'
    with open(filein, 'r') as src, open(tmp_path, 'w') as dst:
        total = os.fstat(src.fileno()).st_size
        offset = 0
        # One offset per line, until the file position reaches EOF.
        while offset != total:
            dst.write('{}\n'.format(offset))
            src.readline()
            offset = src.tell()
    os.rename(tmp_path, idxout)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def read_to_character(fp, c):
    """Read from ``fp`` up to (but not including) the next occurrence of ``c``.

    Reads in fixed 32-character chunks; asserts if EOF is reached before
    ``c`` is found.  The file position is left just past the chunk that
    contained ``c``.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        assert chunk != ''
        if c not in chunk:
            pieces.append(chunk)
            continue
        pieces.append(chunk[:chunk.index(c)])
        return ''.join(pieces)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TSVFile(object):
    """Random-access reader for a TSV file via a precomputed line-offset index.

    Sidecar files (sharing the TSV's stem):
      * ``.lineidx``  -- byte offset of every line (generated on demand),
      * ``.linelist`` -- optional subset of line numbers to expose,
      * ``.chunks``   -- optional JSON mapping class name -> (first, last) row,
        used with ``class_selector`` to expose only selected classes.
    """

    def __init__(self,
                 tsv_file: str,
                 if_generate_lineidx: bool = True,
                 lineidx: str = None,
                 class_selector: List[str] = None,
                 blob_storage: BlobStorage = None):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
            if not lineidx else lineidx
        self.linelist = op.splitext(tsv_file)[0] + '.linelist'
        self.chunks = op.splitext(tsv_file)[0] + '.chunks'
        self._fp = None
        self._lineidx = None
        # Row numbers actually exposed (after linelist/chunks filtering).
        self._sample_indices = None
        self._class_boundaries = None
        self._class_selector = class_selector
        self._blob_storage = blob_storage
        # Cached row count (num_rows may free the index after counting).
        self._len = None
        # Remember which process opened the file handle; if the current pid
        # differs (e.g. after a DataLoader fork), the file is re-opened.
        self.pid = None
        # generate lineidx if not exist
        if not op.isfile(self.lineidx) and if_generate_lineidx:
            generate_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        self.gcidx()
        if self._fp:
            self._fp.close()
        # physically remove the tsv file if it is retrieved by BlobStorage
        if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
            try:
                original_usage = disk_usage('/')
                os.remove(self.tsv_file)
                logger.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
                            (self.tsv_file, original_usage, disk_usage('/') * 100))
            except:
                # Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
                # TODO: try Threadling.Lock to better handle the race condition
                pass

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def gcidx(self):
        # Drop the (potentially large) in-memory index and force a GC pass.
        logger.debug('Run gc collect')
        self._lineidx = None
        self._sample_indices = None
        #self._class_boundaries = None
        return gc.collect()

    def get_class_boundaries(self):
        return self._class_boundaries

    def num_rows(self, gcf=False):
        """Return the number of exposed rows; ``gcf=True`` frees the index after counting."""
        if (self._len is None):
            self._ensure_lineidx_loaded()
            retval = len(self._sample_indices)

            if (gcf):
                self.gcidx()

            self._len = retval

        return self._len

    def seek(self, idx: int):
        """Return the idx-th exposed row split into its tab-separated fields."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[self._sample_indices[idx]]
        except:
            logger.info('=> {}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx: int):
        # Reads only up to the first tab -- cheaper than seek() for keys.
        # NOTE(review): indexes self._lineidx directly (no _sample_indices
        # mapping), unlike seek(); confirm intended when a linelist is used.
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx: int):
        return self.seek_first_column(idx)

    def __getitem__(self, index: int):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        """Lazily load the line-offset index and compute the exposed row set."""
        if self._lineidx is None:
            logger.debug('=> loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                lines = fp.readlines()
                lines = [line.strip() for line in lines]
                self._lineidx = [int(line) for line in lines]

            # read the line list if exists
            linelist = None
            if op.isfile(self.linelist):
                with open(self.linelist, 'r') as fp:
                    linelist = sorted(
                        [
                            int(line.strip())
                            for line in fp.readlines()
                        ]
                    )

            if op.isfile(self.chunks):
                # Keep only rows of selected classes, recording each class's
                # (start, end) range within the filtered sample list.
                # NOTE(review): assumes class_selector is not None whenever a
                # .chunks file exists -- confirm.
                self._sample_indices = []
                self._class_boundaries = []
                class_boundaries = json.load(open(self.chunks, 'r'))
                for class_name, boundary in class_boundaries.items():
                    start = len(self._sample_indices)
                    if class_name in self._class_selector:
                        for idx in range(boundary[0], boundary[1] + 1):
                            # NOTE: potentially slow when linelist is long, try to speed it up
                            if linelist and idx not in linelist:
                                continue
                            self._sample_indices.append(idx)
                    end = len(self._sample_indices)
                    self._class_boundaries.append((start, end))
            else:
                if linelist:
                    self._sample_indices = linelist
                else:
                    self._sample_indices = list(range(len(self._lineidx)))

    def _ensure_tsv_opened(self):
        """Open the TSV handle, re-opening after a fork (pid change)."""
        if self._fp is None:
            if self._blob_storage:
                self._fp = self._blob_storage.open(self.tsv_file)
            else:
                self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            logger.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CompositeTSVFile:
    """Concatenation of several :class:`TSVFile` chunks behind one flat index.

    A global row index is mapped to (chunk, row-within-chunk) via accumulated
    chunk sizes.  Chunks may live in Azure blob storage; ``blob_storage`` also
    serves as a per-chunk cache of opened ``TSVFile`` readers.
    """

    def __init__(self,
                 file_list: List[str],
                 root: str = '.',
                 class_selector: List[str] = None,
                 is_train: bool = True,
                 sas_token_path: str = None,
                 azcopy_path: str = None):
        self.root = root
        self.tsvs = None
        self.chunk_sizes = None
        # Running totals of chunk_sizes, used for index -> chunk mapping.
        self.accum_chunk_sizes = None
        self._class_selector = class_selector
        self._class_boundaries = None
        self.initialized = False
        assert isinstance(file_list, list)
        self.blob_storage = BlobStorage(is_train, sas_token_path, azcopy_path)
        self.file_list = self.blob_storage.register_local_tsv_paths(file_list)
        logger.info('=> Init CompositeTSVFile...')
        self.initialize()
        logger.info('=> Init CompositeTSVFile Done...')

    def get_key(self, index: int):
        # Prefix the row key with its source file to keep keys globally unique.
        idx_source, idx_row = self._calc_chunk_idx_row(index)
        k = self.tsvs[idx_source].get_key(idx_row)
        return '_'.join([self.file_list[idx_source], k])

    def get_class_boundaries(self):
        return self._class_boundaries

    def get_chunk_size(self):
        return self.chunk_sizes

    def num_rows(self):
        return sum(self.chunk_sizes)

    def _calc_chunk_idx_row(self, index: int):
        """Map a flat index to ``(chunk index, row index within that chunk)``."""
        # Linear scan over accumulated sizes; chunk counts are small.
        idx_chunk = 0
        idx_row = index
        while index >= self.accum_chunk_sizes[idx_chunk]:
            idx_chunk += 1
            idx_row = index - self.accum_chunk_sizes[idx_chunk-1]
        return idx_chunk, idx_row

    def __getitem__(self, index: int):
        idx_source, idx_row = self._calc_chunk_idx_row(index)
        # blob_storage doubles as a dict-like cache of per-chunk TSVFile
        # readers, so chunks are opened (and fetched) lazily on first access.
        if idx_source not in self.blob_storage:
            self.blob_storage[idx_source] = TSVFile(
                op.join(self.root, self.file_list[idx_source]),
                class_selector=self._class_selector,
                blob_storage=self.blob_storage,
                if_generate_lineidx=True
            )
        return self.blob_storage[idx_source].seek(idx_row)

    def __len__(self):
        return sum(self.chunk_sizes)

    def initialize(self):
        """
        this function has to be called in init function if cache_policy is
        enabled. Thus, let's always call it in init funciton to make it simple.
        """
        if self.initialized:
            return
        self.tsvs = [
            TSVFile(
                op.join(self.root, f),
                class_selector=self._class_selector
            ) for f in self.file_list
        ]
        logger.debug("=> Calculating chunk sizes ...")
        # gcf=True frees each chunk's line index right after counting rows.
        self.chunk_sizes = [tsv.num_rows(gcf=True) for tsv in self.tsvs]

        # Build running totals, then drop the leading 0 sentinel.
        self.accum_chunk_sizes = [0]
        for size in self.chunk_sizes:
            self.accum_chunk_sizes += [self.accum_chunk_sizes[-1] + size]
        self.accum_chunk_sizes = self.accum_chunk_sizes[1:]

        if (
            self._class_selector
            and all([tsv.get_class_boundaries() for tsv in self.tsvs])
        ):
            """
            Note: When using CompositeTSVFile, make sure that the classes contained in each
            tsv file do not overlap. Otherwise, the class boundaries won't be correct.
            """
            # Shift each chunk's per-class boundaries into the global index space.
            self._class_boundaries = []
            offset = 0
            for tsv in self.tsvs:
                boundaries = tsv.get_class_boundaries()
                for bound in boundaries:
                    self._class_boundaries.append((bound[0] + offset, bound[1] + offset))
                offset += len(tsv)
        self.initialized = True
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def load_list_file(fname: str) -> List[str]:
    """Return the stripped lines of ``fname``, dropping one trailing blank entry."""
    with open(fname, 'r') as fp:
        entries = [raw.strip() for raw in fp.readlines()]
    # A final newline in the file produces one empty entry at the end.
    if entries and entries[-1] == '':
        entries.pop()
    return entries
|
MedImageInsight/ImageDataLoader/zipdata.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as op
|
| 2 |
+
from zipfile import ZipFile, BadZipFile
|
| 3 |
+
import torch.utils.data as data
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
import multiprocessing
|
| 7 |
+
|
| 8 |
+
_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ZipData(data.Dataset):
|
| 12 |
+
_IGNORE_ATTRS = {'_zip_file'}
|
| 13 |
+
|
| 14 |
+
def __init__(self, path, map_file,
|
| 15 |
+
transform=None, target_transform=None,
|
| 16 |
+
extensions=None):
|
| 17 |
+
self._path = path
|
| 18 |
+
if not extensions:
|
| 19 |
+
extensions = _VALID_IMAGE_TYPES
|
| 20 |
+
self._zip_file = ZipFile(path)
|
| 21 |
+
self.zip_dict = {}
|
| 22 |
+
self.samples = []
|
| 23 |
+
self.transform = transform
|
| 24 |
+
self.target_transform = target_transform
|
| 25 |
+
self.class_to_idx = {}
|
| 26 |
+
with open(map_file, 'r') as f:
|
| 27 |
+
for line in iter(f.readline, ""):
|
| 28 |
+
line = line.strip()
|
| 29 |
+
if not line:
|
| 30 |
+
continue
|
| 31 |
+
cls_idx = [l for l in line.split('\t') if l]
|
| 32 |
+
if not cls_idx:
|
| 33 |
+
continue
|
| 34 |
+
if (len(cls_idx) < 2):
|
| 35 |
+
cls_idx = [l for l in line.split(' ') if l]
|
| 36 |
+
if not cls_idx:
|
| 37 |
+
continue
|
| 38 |
+
assert len(cls_idx) >= 2, "invalid line: {}".format(line)
|
| 39 |
+
idx = int(cls_idx[1])
|
| 40 |
+
cls = cls_idx[0]
|
| 41 |
+
del cls_idx
|
| 42 |
+
at_idx = cls.find('@')
|
| 43 |
+
assert at_idx >= 0, "invalid class: {}".format(cls)
|
| 44 |
+
cls = cls[at_idx + 1:]
|
| 45 |
+
if cls.startswith('/'):
|
| 46 |
+
# Python ZipFile expects no root
|
| 47 |
+
cls = cls[1:]
|
| 48 |
+
assert cls, "invalid class in line {}".format(line)
|
| 49 |
+
prev_idx = self.class_to_idx.get(cls)
|
| 50 |
+
assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
|
| 51 |
+
cls, idx, prev_idx
|
| 52 |
+
)
|
| 53 |
+
self.class_to_idx[cls] = idx
|
| 54 |
+
|
| 55 |
+
for fst in self._zip_file.infolist():
|
| 56 |
+
fname = fst.filename
|
| 57 |
+
target = self.class_to_idx.get(fname)
|
| 58 |
+
if target is None:
|
| 59 |
+
continue
|
| 60 |
+
if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
|
| 61 |
+
continue
|
| 62 |
+
ext = op.splitext(fname)[1].lower()
|
| 63 |
+
if ext in extensions:
|
| 64 |
+
self.samples.append((fname, target))
|
| 65 |
+
assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
|
| 66 |
+
|
| 67 |
+
def __repr__(self):
|
| 68 |
+
return 'ZipData({}, size={})'.format(self._path, len(self))
|
| 69 |
+
|
| 70 |
+
def __getstate__(self):
|
| 71 |
+
return {
|
| 72 |
+
key: val if key not in self._IGNORE_ATTRS else None
|
| 73 |
+
for key, val in self.__dict__.iteritems()
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
def __getitem__(self, index):
|
| 77 |
+
proc = multiprocessing.current_process()
|
| 78 |
+
pid = proc.pid # get pid of this process.
|
| 79 |
+
if pid not in self.zip_dict:
|
| 80 |
+
self.zip_dict[pid] = ZipFile(self._path)
|
| 81 |
+
zip_file = self.zip_dict[pid]
|
| 82 |
+
|
| 83 |
+
if index >= len(self) or index < 0:
|
| 84 |
+
raise KeyError("{} is invalid".format(index))
|
| 85 |
+
path, target = self.samples[index]
|
| 86 |
+
try:
|
| 87 |
+
sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
|
| 88 |
+
except BadZipFile:
|
| 89 |
+
print("bad zip file")
|
| 90 |
+
return None, None
|
| 91 |
+
if self.transform is not None:
|
| 92 |
+
sample = self.transform(sample)
|
| 93 |
+
if self.target_transform is not None:
|
| 94 |
+
target = self.target_transform(target)
|
| 95 |
+
return sample, target
|
| 96 |
+
|
| 97 |
+
def __len__(self):
    # Number of (filename, target) pairs collected from the archive.
    return len(self.samples)
|
MedImageInsight/ImageEncoder/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
from .build import build_image_encoder
|
| 6 |
+
|
| 7 |
+
from .coswin import *
|
| 8 |
+
from .davit_v1 import *
|
MedImageInsight/ImageEncoder/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (410 Bytes). View file
|
|
|
MedImageInsight/ImageEncoder/__pycache__/build.cpython-38.pyc
ADDED
|
Binary file (574 Bytes). View file
|
|
|
MedImageInsight/ImageEncoder/__pycache__/coswin.cpython-38.pyc
ADDED
|
Binary file (24.4 kB). View file
|
|
|
MedImageInsight/ImageEncoder/__pycache__/davit_v1.cpython-38.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
MedImageInsight/ImageEncoder/__pycache__/registry.cpython-38.pyc
ADDED
|
Binary file (659 Bytes). View file
|
|
|
MedImageInsight/ImageEncoder/build.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .registry import image_encoders
|
| 2 |
+
from .registry import is_image_encoder
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def build_image_encoder(config_encoder, verbose, **kwargs):
    """Instantiate a registered image encoder from its configuration.

    Args:
        config_encoder: encoder config dict; ``config_encoder['NAME']``
            selects the registered builder. An optional ``cls_`` prefix is
            stripped because the registry stores base names.
        verbose: forwarded to the encoder builder.
        **kwargs: extra keyword arguments forwarded to the encoder builder.

    Returns:
        The constructed encoder module.

    Raises:
        ValueError: if no encoder is registered under the resolved name.
    """
    model_name = config_encoder['NAME']
    if model_name.startswith('cls_'):
        # Drop the classification-head prefix before the registry lookup.
        model_name = model_name[4:]

    if not is_image_encoder(model_name):
        # Fixed typo in the error message ('Unkown' -> 'Unknown').
        raise ValueError(f'Unknown model: {model_name}')

    return image_encoders(model_name)(config_encoder, verbose, **kwargs)
|
MedImageInsight/ImageEncoder/coswin.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# CoSwin: Convolutional Swin Transformer
|
| 3 |
+
# Copyright (c) 2021 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Written by Ze Liu
|
| 6 |
+
# Modified by Bin Xiao
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.utils.checkpoint as checkpoint
|
| 14 |
+
import numpy as np
|
| 15 |
+
from einops import rearrange, repeat
|
| 16 |
+
from einops.layers.torch import Rearrange
|
| 17 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 18 |
+
|
| 19 |
+
from .registry import register_image_encoder
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    A single dropout module (shared rate) is applied after the activation
    and again after the second projection. Hidden and output widths default
    to the input width when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Expand -> activate -> (drop) -> project back -> (drop).
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def window_partition(x, window_size):
    """Split a batched feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    # View as a grid of (height/ws, ws, width/ws, ws) tiles, then bring the
    # two intra-window axes next to each other before flattening.
    grid = x.view(
        batch,
        height // window_size, window_size,
        width // window_size, window_size,
        channels,
    )
    return grid.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, channels)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    wins_h = H // window_size
    wins_w = W // window_size
    batch = windows.shape[0] // (wins_h * wins_w)
    # Undo the partition: restore the (grid_h, grid_w, ws, ws) layout, then
    # interleave window rows/cols back into full spatial axes.
    x = windows.view(batch, wins_h, wins_w, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(batch, H, W, -1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        # One learnable bias per head for every possible (dh, dw) offset.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Fold the 2-D offset into a single flat index into the bias table.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # Project once to q, k, v and split heads: 3, B_, nH, N, C//nH.
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Look up the learned bias for each token pair and add it to the logits.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window shift mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        layer_scale (bool, optional): If True, scale residual branches by a
            learnable per-channel gamma initialized to 1e-4. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            # Label each pixel with a region id so that, after the cyclic
            # shift, tokens from different regions cannot attend to each other.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Non-zero differences mark cross-region pairs; -100 effectively
            # zeroes their attention weight after softmax.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # gamma is 1.0 (no-op) unless layer_scale enables a learnable scale.
        self.gamma = 1.0
        if layer_scale:
            logger.info('=> enable layer scale')
            self.gamma = nn.Parameter(
                1e-4*torch.ones(dim), requires_grad=True
            )

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        # Two residual branches, each optionally scaled by gamma (layer scale).
        x = shortcut + self.drop_path(self.gamma*x)
        x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Downsamples by 2x in each spatial dimension: the four pixels of every
    2x2 neighborhood are concatenated channel-wise (giving 4*dim channels),
    normalized, and linearly projected down to 2*dim channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        height, width = self.input_resolution
        batch, length, channels = x.shape
        assert length == height * width, "input feature has wrong size"
        assert height % 2 == 0 and width % 2 == 0, f"x size ({height}*{width}) are not even."

        grid = x.view(batch, height, width, channels)

        # Gather the four interleaved sub-grids of each 2x2 neighborhood in
        # the order (even,even), (odd,even), (even,odd), (odd,odd) — the same
        # concatenation order as the reference Swin implementation.
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quadrants, -1)  # B H/2 W/2 4*C
        merged = merged.view(batch, -1, 4 * channels)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        height, width = self.input_resolution
        total = height * width * self.dim
        total += (height // 2) * (width // 2) * 4 * self.dim * 2 * self.dim
        return total
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        layer_scale (bool): Forwarded to each block's layer-scale option. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False, layer_scale=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        # Alternate regular (shift 0) and shifted (shift ws//2) windows;
        # drop_path may be a per-block list or a single shared rate.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                layer_scale=layer_scale
            )
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
            # NOTE(review): the downsample module is constructed with a
            # ConvEmbed-style signature (patch_size/stride/padding), not the
            # classic PatchMerging signature — confirm the intended class.
            self.downsample = downsample(
                input_resolution=input_resolution, patch_size=3, in_chans=dim, embed_dim=dim*2,
                stride=2, padding=1, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Trade compute for memory: recompute activations in backward.
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Splits an image into non-overlapping patches with a strided convolution
    and flattens the result into a (B, num_patches, embed_dim) sequence.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [side // patch for side, patch in zip(img_size, patch_size)]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Kernel == stride == patch size, so patches do not overlap.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return tokens if self.norm is None else self.norm(tokens)

    def flops(self):
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class ConvEmbed(nn.Module):
    """ Image to Patch Embedding

    Overlapping convolutional patch embedding. Accepts either a (B, C, H, W)
    image or a flattened (B, H*W, C) token sequence (reshaped back using
    ``input_resolution``), projects it with a strided convolution, and
    returns a flattened (B, H'*W', embed_dim) sequence, optionally normalized.
    """

    def __init__(
            self,
            input_resolution=(224, 224),
            patch_size=7,
            in_chans=3,
            embed_dim=64,
            stride=4,
            padding=2,
            norm_layer=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.input_resolution = input_resolution

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        if x.dim() == 3:
            # (B, H*W, C) token sequence -> (B, C, H, W) image layout,
            # equivalent to einops 'b (h w) c -> b c h w'.
            rows, cols = self.input_resolution
            x = x.view(x.size(0), rows, cols, x.size(2)).permute(0, 3, 1, 2)

        x = self.proj(x)

        # (B, C', H', W') -> (B, H'*W', C'), equivalent to 'b c h w -> b (h w) c'.
        x = x.flatten(2).transpose(1, 2)
        if self.norm:
            x = self.norm(x)

        return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class SwinTransformer(nn.Module):
    r""" Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
      https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=7, patch_padding=2, patch_stride=4, in_chans=3,
                 num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, layer_scale=False, **kwargs):
        # NOTE: depths/num_heads defaults are tuples (not lists) to avoid the
        # shared-mutable-default pitfall; they are only indexed and iterated,
        # so callers passing lists are unaffected.
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # Channel width doubles at every stage; the last stage fixes num_features.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # Split the image into (possibly overlapping) patches with a strided conv.
        self.patch_embed = ConvEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding,
            norm_layer=norm_layer if self.patch_norm else None
        )

        img_size = to_2tuple(img_size)
        # Standard conv output-size formula: floor((size + 2*pad - kernel)/stride + 1).
        # BUGFIX: the second entry previously reused img_size[0], which was only
        # correct for square inputs; it now uses img_size[1].
        patches_resolution = (
            int(np.floor(float(img_size[0] + 2 * patch_padding - patch_size) / patch_stride + 1)),
            int(np.floor(float(img_size[1] + 2 * patch_padding - patch_size) / patch_stride + 1))
        )
        num_patches = patches_resolution[0] * patches_resolution[1]
        self.patches_resolution = patches_resolution

        # absolute position embedding (optional)
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: drop-path rate increases linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer)
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # Downsampling between stages is a strided ConvEmbed rather
                # than the original PatchMerging.
                downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                layer_scale=layer_scale
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @property
    def dim_out(self):
        # Feature dimension produced by forward_features().
        return self.num_features

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            # Simplified: the redundant second isinstance check was removed.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def from_pretrained(self, pretrained='', pretrained_layers=(), verbose=True):
        """Load weights from a checkpoint file on disk (silent no-op if the
        path is not a regular file)."""
        if os.path.isfile(pretrained):
            # Consistency fix: use the module-level `logger` like the rest of
            # this class (was `logging.info`).
            logger.info(f'=> loading pretrained model {pretrained}')
            pretrained_dict = torch.load(pretrained, map_location='cpu')

            self.from_state_dict(pretrained_dict, pretrained_layers, verbose)

    def from_state_dict(self, pretrained_dict, pretrained_layers=(), verbose=True):
        """Selectively copy weights from a (possibly prefixed) state dict.

        Keys carrying an 'image_encoder.' prefix are stripped. Relative
        position bias tables and absolute position embeddings whose spatial
        size differs from the current model are bicubic-resized.

        Args:
            pretrained_dict: checkpoint state dict.
            pretrained_layers: top-level module names to load; '*' loads all.
            verbose: log each parameter that is initialized.
        """
        model_dict = self.state_dict()
        stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x

        pretrained_dict = {
            stripped_key(k): v for k, v in pretrained_dict.items()
            if stripped_key(k) in model_dict.keys()
        }
        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                (
                    k.split('.')[0] in pretrained_layers
                    # Robustness: guard the [0] access so an empty
                    # pretrained_layers no longer raises IndexError.
                    or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
                )
                and 'relative_position_index' not in k
                and 'attn_mask' not in k
            )

            if need_init:
                if verbose:
                    logger.info(f'=> init {k} from pretrained state dict')

                if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
                    # Resize the relative position bias table when the
                    # pretrained window size differs from the current one.
                    relative_position_bias_table_pretrained = v
                    relative_position_bias_table_current = model_dict[k]
                    L1, nH1 = relative_position_bias_table_pretrained.size()
                    L2, nH2 = relative_position_bias_table_current.size()
                    if nH1 != nH2:
                        logger.info(f"Error in loading {k}, passing")
                    else:
                        if L1 != L2:
                            logger.info(
                                '=> load_pretrained: resized variant: {} to {}'
                                .format((L1, nH1), (L2, nH2))
                            )
                            S1 = int(L1 ** 0.5)
                            S2 = int(L2 ** 0.5)
                            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                                size=(S2, S2),
                                mode='bicubic')
                            v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

                if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
                    # Resize the absolute position embedding for a new grid size.
                    absolute_pos_embed_pretrained = v
                    absolute_pos_embed_current = model_dict[k]
                    _, L1, C1 = absolute_pos_embed_pretrained.size()
                    _, L2, C2 = absolute_pos_embed_current.size()
                    # BUGFIX: the original compared C1 with itself (`C1 != C1`),
                    # so a channel mismatch was never detected.
                    if C1 != C2:
                        logger.info(f"Error in loading {k}, passing")
                    else:
                        if L1 != L2:
                            logger.info(
                                '=> load_pretrained: resized variant: {} to {}'
                                .format((1, L1, C1), (1, L2, C2))
                            )
                            S1 = int(L1 ** 0.5)
                            S2 = int(L2 ** 0.5)
                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                            v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)

                need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Keep the absolute position embedding out of weight decay.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # Keep relative position bias tables out of weight decay.
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Map an image batch (B, C, H, W) to pooled features (B, num_features)."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        """Features followed by the classification head (identity if num_classes == 0)."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
@register_image_encoder
def image_encoder(config_encoder, verbose, **kwargs):
    """Instantiate a SwinTransformer backbone from an encoder config dict.

    Architecture hyper-parameters are read from config_encoder['SPEC']; when
    LOAD_PRETRAINED is set, weights are loaded from the referenced checkpoint.
    """
    spec = config_encoder['SPEC']

    model_args = dict(
        img_size=config_encoder['IMAGE_SIZE'],
        patch_size=spec['PATCH_SIZE'],
        patch_padding=spec['PATCH_PADDING'],
        patch_stride=spec['PATCH_STRIDE'],
        in_chans=spec['IN_CHANS'],
        num_classes=0,  # headless backbone: forward() returns pooled features
        embed_dim=spec['EMBED_DIM'],
        depths=spec['DEPTHS'],
        num_heads=spec['NUM_HEADS'],
        window_size=spec['WINDOW_SIZE'],
        mlp_ratio=spec['MLP_RATIO'],
        qkv_bias=spec['QKV_BIAS'],
        qk_scale=spec.get('QK_SCALE', None),
        drop_rate=spec['DROP_RATE'],
        drop_path_rate=spec['DROP_PATH_RATE'],
        ape=spec['APE'],
        patch_norm=spec['PATCH_NORM'],
        layer_scale=spec.get('LAYER_SCALE', False),
        use_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
    )
    model = SwinTransformer(**model_args)

    if config_encoder['LOAD_PRETRAINED']:
        model.from_pretrained(
            config_encoder['PRETRAINED'],
            config_encoder['PRETRAINED_LAYERS'],
            verbose
        )

    return model
|
MedImageInsight/ImageEncoder/davit_v1.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.utils.checkpoint as checkpoint
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from timm.models.layers import DropPath, trunc_normal_
|
| 13 |
+
|
| 14 |
+
# helper methods
|
| 15 |
+
from .registry import register_image_encoder
|
| 16 |
+
|
| 17 |
+
import mup.init
|
| 18 |
+
from mup import MuReadout, set_base_shapes
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MySequential(nn.Sequential):
    """Sequential container whose modules may consume/produce multiple values.

    If the running value is a tuple it is unpacked into the next module's
    positional arguments; otherwise it is passed as a single argument.
    """

    def forward(self, *inputs):
        for module in self._modules.values():
            # isinstance instead of an exact `type(...) == tuple` comparison,
            # so tuple subclasses (e.g. NamedTuple results) are unpacked too.
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class PreNorm(nn.Module):
    """Pre-normalization residual wrapper: x + drop_path(fn(norm(x))).

    Args:
        norm: normalization module applied before fn, or None to skip.
        fn: callable taking (x, *args, **kwargs) and returning (x, size).
        drop_path: optional stochastic-depth module applied to fn's output.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        shortcut = x
        # Idiom fix: `self.norm != None` relied on default object equality;
        # `is not None` is the correct None check.
        if self.norm is not None:
            x, size = self.fn(self.norm(x), *args, **kwargs)
        else:
            x, size = self.fn(x, *args, **kwargs)

        if self.drop_path:
            x = self.drop_path(x)

        # Residual connection around the (normalized) function.
        x = shortcut + x

        return x, size
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block (fc1 -> activation -> fc2), as used in
    Vision Transformer, MLP-Mixer and related networks.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
    ):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden, out)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        # Token-wise MLP; the spatial `size` is passed through untouched.
        return self.net(x), size
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2D convolution over a flattened token sequence.

    Takes tokens of shape (B, N, C) with N == H * W, convolves each channel
    independently, and returns the re-flattened tokens plus the new (H, W).
    """

    def __init__(
            self,
            dim_in,
            kernel_size,
            padding,
            stride,
            bias=True,
    ):
        super().__init__()
        self.dw = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,  # one filter per channel -> depth-wise
            stride=stride,
            bias=bias,
        )

    def forward(self, x, size):
        batch, n_tokens, channels = x.shape
        height, width = size
        assert n_tokens == height * width

        # (B, N, C) -> (B, C, H, W) for the convolution.
        feat = self.dw(x.transpose(1, 2).view(batch, channels, height, width))
        new_size = (feat.size(-2), feat.size(-1))
        # Back to token layout (B, N', C).
        return feat.flatten(2).transpose(1, 2), new_size
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ConvEmbed(nn.Module):
    """Image-to-patch embedding via a strided convolution.

    Accepts either image tensors (B, C, H, W) or token tensors (B, H*W, C);
    always returns tokens (B, H'*W', embed_dim) together with the new (H', W').
    When pre_norm is True, normalization (if any) is applied to the input
    channels before the projection; otherwise to the embedding afterwards.
    """

    def __init__(
            self,
            patch_size=7,
            in_chans=3,
            embed_dim=64,
            stride=4,
            padding=2,
            norm_layer=None,
            pre_norm=True
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )

        # Pre-norm normalizes the incoming channels, post-norm the embedding.
        normalized_dim = in_chans if pre_norm else embed_dim
        self.norm = norm_layer(normalized_dim) if norm_layer else None
        self.pre_norm = pre_norm

    def forward(self, x, size):
        height, width = size
        if x.dim() == 3:
            # Token input: optionally pre-normalize, then fold back into a
            # feature map. Equivalent to rearrange('b (h w) c -> b c h w').
            batch, _, channels = x.shape
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = x.transpose(1, 2).reshape(batch, channels, height, width)

        x = self.proj(x)

        _, _, height, width = x.shape
        # Flatten to tokens: equivalent to rearrange('b c h w -> b (h w) c').
        x = x.flatten(2).transpose(1, 2)
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (height, width)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class ChannelAttention(nn.Module):
    """Channel-group self-attention: attention is computed across channels
    (transposed attention) within each group, rather than across tokens.

    Args:
        dim: embedding dimension of this layer.
        base_dim: embedding dimension of the muP base model.
        groups: number of channel groups attended independently.
        base_groups: number of groups in the muP base model.
        qkv_bias: add a learnable bias to the qkv projection.
        dynamic_scale: scale attention by token count N instead of dim.
        standparam: standard parametrization (True) or muP (False).
    """

    def __init__(self, dim, base_dim, groups=8, base_groups=8, qkv_bias=True, dynamic_scale=True, standparam=True):
        super().__init__()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.dynamic_scale = dynamic_scale

        self.dim = dim
        self.groups = groups
        self.group_dim = dim // groups

        self.base_dim = base_dim
        self.base_groups = base_groups
        self.base_group_dim = base_dim // base_groups

        self.group_wm = self.group_dim / self.base_group_dim  # Width multiplier for each group.
        self.standparam = standparam

    def forward(self, x, size):
        B, N, C = x.shape
        assert C == self.dim

        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Shape: [B, groups, N, group_dim].

        # Choose the attention scaling factor.
        # Ref: examples/Transformer/model.py in muP.
        # Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/18.
        # (A redundant unconditional computation of `scale` that preceded this
        # branch was removed; it was always overwritten.)
        if self.standparam:
            scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
        else:
            assert self.dynamic_scale  # Currently only support dynamic scale.
            scale = N ** -0.5

        q = q * scale
        # Transposed attention: channels attend to channels.
        attention = q.transpose(-1, -2) @ k
        attention = attention.softmax(dim=-1)

        if not self.standparam:
            # Follow https://github.com/microsoft/mup/issues/18.
            attention = attention / self.group_wm

        x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x, size
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class ChannelBlock(nn.Module):
    """DaViT channel-attention block.

    Pipeline: optional depth-wise conv, channel attention, optional depth-wise
    conv, MLP — each stage wrapped in a residual PreNorm.
    """

    def __init__(self, dim, base_dim, groups, base_groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True, dynamic_scale=True, standparam=True):
        super().__init__()

        stochastic_depth = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        attention = ChannelAttention(
            dim, base_dim, groups=groups, base_groups=base_groups,
            qkv_bias=qkv_bias, dynamic_scale=dynamic_scale, standparam=standparam
        )
        self.channel_attn = PreNorm(norm_layer(dim), attention, stochastic_depth)
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        feed_forward = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
        self.ffn = PreNorm(norm_layer(dim), feed_forward, stochastic_depth)

    def forward(self, x, size):
        # Each sub-layer both transforms the tokens and may update `size`.
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def window_partition(x, window_size: int):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    Returns a tensor of shape (num_windows * B, window_size, window_size, C).
    H and W must be divisible by window_size.
    """
    batch, height, width, channels = x.shape
    grid = x.view(
        batch,
        height // window_size, window_size,
        width // window_size, window_size,
        channels,
    )
    # Bring the two window-grid axes together, then fold them into the batch dim.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of window_partition: reassemble windows into a (B, H, W, C) map.

    `windows` has shape (num_windows * B, window_size, window_size, C) and
    H, W are the padded spatial dimensions (multiples of window_size).
    """
    windows_per_image = (H * W) // window_size // window_size
    batch = windows.shape[0] // windows_per_image
    grid = windows.view(
        batch,
        H // window_size, W // window_size,
        window_size, window_size,
        -1,
    )
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch, H, W, -1)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class WindowAttention(nn.Module):
    """Multi-head self-attention within fixed-size local windows, with
    muP-aware attention scaling.

    Args:
        dim: embedding dimension of this layer.
        base_dim: embedding dimension of the muP base model.
        num_heads: attention heads in this layer.
        base_num_heads: heads in the muP base model.
        window_size: side length of each square attention window.
        qkv_bias: add a learnable bias to the qkv projection.
        standparam: standard parametrization (True) or muP (False).
    """

    def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=True, standparam=True):

        super().__init__()

        self.window_size = window_size

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.base_dim = base_dim
        self.base_num_heads = base_num_heads
        base_head_dim = base_dim // base_num_heads

        # Change the scaling factor.
        # Ref: examples/Transformer/model.py in muP.
        # Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/17.
        if standparam:
            scale = float(head_dim) ** -0.5
        else:
            # TODO: Here we ensure backward compatibility, which may not be optimal.
            # We may add an argument called backward_comp. If it is set as False, we use
            # float(head_dim) ** -1 * math.sqrt(attn_mult)
            # as in the Transformer example in muP.
            base_scale = float(base_head_dim) ** -0.5  # The same as scaling in standard parametrization.
            head_wm = head_dim / base_head_dim  # Width multiplier for each head.
            scale = base_scale / head_wm
            # scale_1 = (float(base_head_dim) ** 0.5) * (float(head_dim) ** -1) # Equivalent implementation as shown in the muP paper.
            # assert np.isclose(scale, scale_1)
        self.scale = scale

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        # x: (B, H*W, C) token sequence; size: spatial (H, W) of those tokens.

        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # Pad on the right/bottom so H and W become multiples of window_size.
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # Partition into (num_windows*B, window_size^2, C) token groups.
        x = window_partition(x, self.window_size)
        x = x.view(-1, self.window_size * self.window_size, C)

        # Standard multi-head attention, applied independently per window.
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        # merge windows
        x = x.view(
            -1, self.window_size, self.window_size, C
        )
        x = window_reverse(x, self.window_size, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            # Drop the padded rows/cols to restore the original resolution.
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # `size` is returned unchanged: padding is internal to the attention.
        return x, size
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class SpatialBlock(nn.Module):
    """DaViT spatial (window-attention) block.

    Pipeline: optional depth-wise conv, windowed self-attention, optional
    depth-wise conv, MLP — each stage wrapped in a residual PreNorm.
    """

    def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True, standparam=True):
        super().__init__()

        stochastic_depth = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        attention = WindowAttention(
            dim, base_dim, num_heads, base_num_heads, window_size,
            qkv_bias=qkv_bias, standparam=standparam
        )
        self.window_attn = PreNorm(norm_layer(dim), attention, stochastic_depth)
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        feed_forward = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
        self.ffn = PreNorm(norm_layer(dim), feed_forward, stochastic_depth)

    def forward(self, x, size):
        # Each sub-layer both transforms the tokens and may update `size`.
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class DaViT(nn.Module):
|
| 374 |
+
""" DaViT: Dual-Attention Transformer
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
img_size (int | tuple(int)): Input image size. Default: 224
|
| 378 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
| 379 |
+
in_chans (int): Number of input image channels. Default: 3
|
| 380 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
| 381 |
+
depths (tuple(int)): Number of spatial and channel blocks in different stages. Default: (1, 1, 3, 1)
|
| 382 |
+
patch_size (tuple(int)): Patch sizes in different stages. Default: (7, 2, 2, 2)
|
| 383 |
+
patch_stride (tuple(int)): Patch strides in different stages. Default: (4, 2, 2, 2)
|
| 384 |
+
patch_padding (tuple(int)): Patch padding sizes in different stages. Default: (3, 0, 0, 0)
|
| 385 |
+
patch_prenorm (tuple(bool)): Use pre-normalization or not in different stages. Default: (False, False, False, False)
|
| 386 |
+
embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
|
| 387 |
+
base_embed_dims (tuple(int)): Patch embedding dimension (base case for muP). Default: (64, 128, 192, 256)
|
| 388 |
+
num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
|
| 389 |
+
base_num_heads (tuple(int)): Number of attention heads in different layers (base case for muP). Default: (4, 8, 12, 16)
|
| 390 |
+
num_groups (tuple(int)): Number of groups in channel attention in different layers. Default: (3, 6, 12, 24)
|
| 391 |
+
base_num_groups (tuple(int)): Number of groups in channel attention in different layers (base case for muP). Default: (3, 6, 12, 24)
|
| 392 |
+
window_size (int): Window size. Default: 7
|
| 393 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
| 394 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 395 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
| 396 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 397 |
+
enable_checkpoint (bool): If True, enabling checkpoint. Default: False
|
| 398 |
+
conv_at_attn (bool): If True, add convolution layer before attention. Default: True
|
| 399 |
+
conv_at_ffn (bool): If True, add convolution layer before ffn. Default: True
|
| 400 |
+
dynamic_scale (bool): If True, scale of channel attention is respect to the number of tokens. Default: True
|
| 401 |
+
standparam (bool): Use standard parametrization or mu-parametrization. Default: True (i.e., use standard paramerization)
|
| 402 |
+
"""
|
| 403 |
+
|
| 404 |
+
    def __init__(
            self,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            depths=(1, 1, 3, 1),
            patch_size=(7, 2, 2, 2),
            patch_stride=(4, 2, 2, 2),
            patch_padding=(3, 0, 0, 0),
            patch_prenorm=(False, False, False, False),
            embed_dims=(64, 128, 192, 256),
            base_embed_dims=(64, 128, 192, 256),
            num_heads=(3, 6, 12, 24),
            base_num_heads=(3, 6, 12, 24),
            num_groups=(3, 6, 12, 24),
            base_num_groups=(3, 6, 12, 24),
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            enable_checkpoint=False,
            conv_at_attn=True,
            conv_at_ffn=True,
            dynamic_scale=True,
            standparam=True
    ):
        """Build the multi-stage DaViT backbone; see the class docstring for
        the meaning of each argument. ``base_*`` arguments are only meaningful
        for muP (``standparam=False``)."""
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        # One stage per entry in embed_dims.
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        # All per-stage tuples must describe the same number of stages.
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        num_stages = len(embed_dims)
        self.img_size = img_size
        # Per-block stochastic-depth rates, linearly increasing to
        # drop_path_rate. Each depth unit expands into a (spatial, channel)
        # block pair, hence the factor of 2.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]

        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            # Patch (conv) embedding for stage i; stage 0 consumes the raw
            # image channels, later stages consume the previous stage's dim.
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)

            logger.info(f'=> Depth offset in stage {i}: {depth_offset}')
            # Stage i: depths[i] repetitions of a spatial (window attention)
            # block followed by a channel (group attention) block.
            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                base_embed_dims[i],
                                num_heads[i],
                                base_num_heads[i],
                                window_size,
                                # Even dpr indices go to spatial blocks.
                                drop_path_rate=dpr[depth_offset + j * 2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                                standparam=standparam
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                base_embed_dims[i],
                                num_groups[i],
                                base_num_groups[i],
                                # Odd dpr indices go to channel blocks.
                                drop_path_rate=dpr[depth_offset + j * 2 + 1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                                dynamic_scale=dynamic_scale,
                                standparam=standparam
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        # Final norm over the last-stage embedding; applied after pooling in
        # forward_features.
        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        if standparam:
            self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        else:
            # muP readout layer; zero-init per the muP ResNet example.
            self.head = MuReadout(self.embed_dims[-1], num_classes,
                                  readout_zero_init=True)  # Follow examples/ResNet/resnet.py in muP.

        # NOTE(review): device is hardcoded to cuda:0 when CUDA is available;
        # multi-GPU placement must be handled by the caller -- confirm.
        if torch.cuda.is_available():
            self.device = torch.device(type="cuda", index=0)
        else:
            self.device = torch.device(type="cpu")
|
| 516 |
+
|
| 517 |
+
def custom_init_weights(self, use_original_init=True):
|
| 518 |
+
self.use_original_init = use_original_init
|
| 519 |
+
logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
|
| 520 |
+
self.apply(self._custom_init_weights)
|
| 521 |
+
|
| 522 |
+
@property
|
| 523 |
+
def dim_out(self):
|
| 524 |
+
return self.embed_dims[-1]
|
| 525 |
+
|
| 526 |
+
def _custom_init_weights(self, m):
|
| 527 |
+
# Customized initialization for weights.
|
| 528 |
+
if self.use_original_init:
|
| 529 |
+
# Original initialization.
|
| 530 |
+
# Note: This is not SP init. We do not implement SP init here.
|
| 531 |
+
custom_trunc_normal_ = trunc_normal_
|
| 532 |
+
custom_normal_ = nn.init.normal_
|
| 533 |
+
else:
|
| 534 |
+
# muP.
|
| 535 |
+
custom_trunc_normal_ = mup.init.trunc_normal_
|
| 536 |
+
custom_normal_ = mup.init.normal_
|
| 537 |
+
|
| 538 |
+
# These initializations will overwrite the existing inializations from the modules and adjusted by set_base_shapes().
|
| 539 |
+
if isinstance(m, MuReadout):
|
| 540 |
+
pass # Note: MuReadout is already zero initialized due to readout_zero_init=True.
|
| 541 |
+
elif isinstance(m, nn.Linear):
|
| 542 |
+
custom_trunc_normal_(m.weight, std=0.02)
|
| 543 |
+
if m.bias is not None:
|
| 544 |
+
nn.init.constant_(m.bias, 0)
|
| 545 |
+
elif isinstance(m, nn.Conv2d):
|
| 546 |
+
custom_normal_(m.weight, std=0.02)
|
| 547 |
+
for name, _ in m.named_parameters():
|
| 548 |
+
if name in ['bias']:
|
| 549 |
+
nn.init.constant_(m.bias, 0)
|
| 550 |
+
elif isinstance(m, nn.LayerNorm): # Follow P24 Layernorm Weights and Biases.
|
| 551 |
+
nn.init.constant_(m.weight, 1.0)
|
| 552 |
+
nn.init.constant_(m.bias, 0)
|
| 553 |
+
elif isinstance(m, nn.BatchNorm2d): # Follow P24 Layernorm Weights and Biases.
|
| 554 |
+
nn.init.constant_(m.weight, 1.0)
|
| 555 |
+
nn.init.constant_(m.bias, 0)
|
| 556 |
+
|
| 557 |
+
def _try_remap_keys(self, pretrained_dict):
|
| 558 |
+
remap_keys = {
|
| 559 |
+
"conv_embeds": "convs",
|
| 560 |
+
"main_blocks": "blocks",
|
| 561 |
+
"0.cpe.0.proj": "spatial_block.conv1.fn.dw",
|
| 562 |
+
"0.attn": "spatial_block.window_attn.fn",
|
| 563 |
+
"0.cpe.1.proj": "spatial_block.conv2.fn.dw",
|
| 564 |
+
"0.mlp": "spatial_block.ffn.fn.net",
|
| 565 |
+
"1.cpe.0.proj": "channel_block.conv1.fn.dw",
|
| 566 |
+
"1.attn": "channel_block.channel_attn.fn",
|
| 567 |
+
"1.cpe.1.proj": "channel_block.conv2.fn.dw",
|
| 568 |
+
"1.mlp": "channel_block.ffn.fn.net",
|
| 569 |
+
"0.norm1": "spatial_block.window_attn.norm",
|
| 570 |
+
"0.norm2": "spatial_block.ffn.norm",
|
| 571 |
+
"1.norm1": "channel_block.channel_attn.norm",
|
| 572 |
+
"1.norm2": "channel_block.ffn.norm"
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
full_key_mappings = {}
|
| 576 |
+
for k in pretrained_dict.keys():
|
| 577 |
+
old_k = k
|
| 578 |
+
for remap_key in remap_keys.keys():
|
| 579 |
+
if remap_key in k:
|
| 580 |
+
logger.info(f'=> Repace {remap_key} with {remap_keys[remap_key]}')
|
| 581 |
+
k = k.replace(remap_key, remap_keys[remap_key])
|
| 582 |
+
|
| 583 |
+
full_key_mappings[old_k] = k
|
| 584 |
+
|
| 585 |
+
return full_key_mappings
|
| 586 |
+
|
| 587 |
+
def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
|
| 588 |
+
model_dict = self.state_dict()
|
| 589 |
+
stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
|
| 590 |
+
full_key_mappings = self._try_remap_keys(pretrained_dict)
|
| 591 |
+
|
| 592 |
+
pretrained_dict = {
|
| 593 |
+
stripped_key(full_key_mappings[k]): v.to(self.device) for k, v in pretrained_dict.items()
|
| 594 |
+
if stripped_key(full_key_mappings[k]) in model_dict.keys()
|
| 595 |
+
}
|
| 596 |
+
need_init_state_dict = {}
|
| 597 |
+
for k, v in pretrained_dict.items():
|
| 598 |
+
need_init = (
|
| 599 |
+
k.split('.')[0] in pretrained_layers
|
| 600 |
+
or pretrained_layers[0] == '*'
|
| 601 |
+
)
|
| 602 |
+
if need_init:
|
| 603 |
+
if verbose:
|
| 604 |
+
logger.info(f'=> init {k} from pretrained state dict')
|
| 605 |
+
|
| 606 |
+
need_init_state_dict[k] = v.to(self.device)
|
| 607 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
| 608 |
+
|
| 609 |
+
def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
| 610 |
+
if os.path.isfile(pretrained):
|
| 611 |
+
logger.info(f'=> loading pretrained model {pretrained}')
|
| 612 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
| 613 |
+
|
| 614 |
+
self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
|
| 615 |
+
|
| 616 |
+
    def forward_features(self, x):
        # Run the staged backbone and return a pooled feature vector.
        # assumes x is a 4-D image batch (N, C, H, W), since spatial size is
        # read from dims 2 and 3 -- TODO confirm with callers.
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            # Each stage: patch embedding, then its transformer blocks; both
            # return the tensor together with the updated spatial size.
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                # Activation checkpointing trades compute for memory.
                x, input_size = checkpoint.checkpoint(block, x, input_size)
            else:
                x, input_size = block(x, input_size)

        # Global average pool across the token dimension, then flatten to
        # (N, embed_dims[-1]).
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        # NOTE(review): the final LayerNorm is applied to the pooled vector,
        # not to the token sequence before pooling -- confirm intentional.
        x = self.norms(x)

        return x
|
| 630 |
+
|
| 631 |
+
def forward(self, x):
|
| 632 |
+
x = self.forward_features(x)
|
| 633 |
+
x = self.head(x)
|
| 634 |
+
return x
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def create_encoder(config_encoder):
    """Instantiate a DaViT encoder from a config dictionary.

    Args:
        config_encoder (dict): Encoder section of the config; its 'SPEC'
            sub-dict carries the architecture hyper-parameters.

    Returns:
        DaViT: the constructed (uninitialized) encoder.
    """
    spec = config_encoder['SPEC']
    standparam = spec.get('STANDPARAM', True)

    if standparam:
        # Standard parametrization ignores base shapes, so mirror the
        # actual widths as dummy base values.
        base_embed_dims = spec['DIM_EMBED']
        base_num_heads = spec['NUM_HEADS']
        base_num_groups = spec['NUM_GROUPS']
    else:
        base_embed_dims = spec['BASE_DIM_EMBED']
        base_num_heads = spec['BASE_NUM_HEADS']
        base_num_groups = spec['BASE_NUM_GROUPS']

    return DaViT(
        num_classes=config_encoder['NUM_CLASSES'],
        depths=spec['DEPTHS'],
        embed_dims=spec['DIM_EMBED'],
        base_embed_dims=base_embed_dims,
        num_heads=spec['NUM_HEADS'],
        base_num_heads=base_num_heads,
        num_groups=spec['NUM_GROUPS'],
        base_num_groups=base_num_groups,
        patch_size=spec['PATCH_SIZE'],
        patch_stride=spec['PATCH_STRIDE'],
        patch_padding=spec['PATCH_PADDING'],
        patch_prenorm=spec['PATCH_PRENORM'],
        drop_path_rate=spec['DROP_PATH_RATE'],
        img_size=config_encoder['IMAGE_SIZE'],
        window_size=spec.get('WINDOW_SIZE', 7),
        enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
        conv_at_attn=spec.get('CONV_AT_ATTN', True),
        conv_at_ffn=spec.get('CONV_AT_FFN', True),
        dynamic_scale=spec.get('DYNAMIC_SCALE', True),
        standparam=standparam,
    )
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def create_mup_encoder(config_encoder):
    """Create a muP-parametrized encoder and register its base shapes.

    Builds the target model plus width-multiplied base/delta models
    (wm=1.0 and wm=2.0) and calls mup.set_base_shapes on the target.
    """
    def gen_config(config, wm):
        # Copy the config with widths set to base widths times multiplier wm.
        scaled = copy.deepcopy(config)
        for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
            base_values = config['SPEC']['BASE_' + name]
            # New value = base value * width multiplier.
            new_values = [round(base_value * wm) for base_value in base_values]
            logger.info(f'config["SPEC"]["{name}"]: {scaled["SPEC"][name]} -> {new_values}')
            scaled['SPEC'][name] = new_values
        return scaled

    logger.info('muP: Create models and set base shapes')
    logger.info('=> Create model')
    model = create_encoder(config_encoder)

    logger.info('=> Create base model')
    base_model = create_encoder(gen_config(config_encoder, wm=1.0))

    logger.info('=> Create delta model')
    delta_model = create_encoder(gen_config(config_encoder, wm=2.0))

    logger.info('=> Set base shapes in model for training')
    set_base_shapes(model, base=base_model, delta=delta_model)

    return model
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
@register_image_encoder
def image_encoder(config_encoder, verbose, **kwargs):
    """Registry entry point: build, initialize, and optionally load weights.

    Args:
        config_encoder (dict): Encoder config; STANDPARAM in its 'SPEC'
            selects standard vs. mu parametrization.
        verbose (bool): Forwarded to the pretrained-weight loader.

    Returns:
        DaViT: the ready-to-use encoder.
    """
    spec = config_encoder['SPEC']
    use_standparam = spec.get('STANDPARAM', True)

    if use_standparam:
        logger.info('Create model with standard parameterization')
        model = create_encoder(config_encoder)
    else:
        logger.info('Create model with mu parameterization')
        model = create_mup_encoder(config_encoder)
    # Original init for standard parametrization, muP init otherwise.
    model.custom_init_weights(use_original_init=use_standparam)

    logger.info('Load model from pretrained checkpoint')
    if config_encoder['LOAD_PRETRAINED']:
        model.from_pretrained(
            config_encoder['PRETRAINED'],
            config_encoder['PRETRAINED_LAYERS'],
            verbose,
        )

    return model
|
MedImageInsight/ImageEncoder/registry.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_image_encoders = {}
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def register_image_encoder(fn):
|
| 5 |
+
module_name_split = fn.__module__.split('.')
|
| 6 |
+
model_name = module_name_split[-1]
|
| 7 |
+
|
| 8 |
+
_image_encoders[model_name] = fn
|
| 9 |
+
|
| 10 |
+
return fn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def image_encoders(model_name):
    """Return the encoder factory registered under *model_name*.

    Raises KeyError if no such encoder has been registered.
    """
    factory = _image_encoders[model_name]
    return factory
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def is_image_encoder(model_name):
    """Return True when an encoder factory named *model_name* is registered."""
    registered = _image_encoders
    return model_name in registered
|
MedImageInsight/LangEncoder/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
from .build import build_lang_encoder
|
| 6 |
+
from .build import build_tokenizer
|
| 7 |
+
|
| 8 |
+
from .transformer import *
|
| 9 |
+
# from .hf_model import *
|
| 10 |
+
# from .zcode import *
|
| 11 |
+
# from .pretrain import *
|
| 12 |
+
# from .tulrv6 import *
|
| 13 |
+
# from .t5 import *
|