Spaces:
Runtime error
Runtime error
Upload 123 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- configs/multi_mo_multi_task.yaml +156 -0
- configs/multi_mo_multi_task_sar_prompt.yaml +174 -0
- datasets/__init__.py +3 -0
- datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- datasets/__pycache__/__init__.cpython-37.pyc +0 -0
- datasets/__pycache__/datasets.cpython-310.pyc +0 -0
- datasets/__pycache__/datasets.cpython-37.pyc +0 -0
- datasets/__pycache__/image_folder.cpython-310.pyc +0 -0
- datasets/__pycache__/image_folder.cpython-37.pyc +0 -0
- datasets/__pycache__/wrappers.cpython-310.pyc +0 -0
- datasets/__pycache__/wrappers.cpython-37.pyc +0 -0
- datasets/data_loader_multi_tasks.py +26 -0
- datasets/data_simmim_pt.py +271 -0
- datasets/datasets.py +21 -0
- datasets/image_folder.py +370 -0
- datasets/wrappers.py +231 -0
- models/__init__.py +4 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/__init__.cpython-37.pyc +0 -0
- models/__pycache__/iou_loss.cpython-37.pyc +0 -0
- models/__pycache__/models.cpython-310.pyc +0 -0
- models/__pycache__/models.cpython-37.pyc +0 -0
- models/__pycache__/sam.cpython-310.pyc +0 -0
- models/__pycache__/sam.cpython-37.pyc +0 -0
- models/__pycache__/sam_single.cpython-37.pyc +0 -0
- models/__pycache__/utils_prompt.cpython-37.pyc +0 -0
- models/block.py +128 -0
- models/bn_helper.py +16 -0
- models/iou_loss.py +21 -0
- models/mmseg/__init__.py +33 -0
- models/mmseg/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mmseg/__pycache__/__init__.cpython-37.pyc +0 -0
- models/mmseg/__pycache__/version.cpython-310.pyc +0 -0
- models/mmseg/__pycache__/version.cpython-37.pyc +0 -0
- models/mmseg/apis/__init__.py +9 -0
- models/mmseg/apis/inference.py +118 -0
- models/mmseg/apis/test.py +235 -0
- models/mmseg/apis/train.py +115 -0
- models/mmseg/core/__init__.py +3 -0
- models/mmseg/core/evaluation/__init__.py +8 -0
- models/mmseg/core/evaluation/class_names.py +152 -0
- models/mmseg/core/evaluation/eval_hooks.py +107 -0
- models/mmseg/core/evaluation/metrics.py +229 -0
- models/mmseg/core/seg/__init__.py +4 -0
- models/mmseg/core/seg/builder.py +8 -0
- models/mmseg/core/seg/sampler/__init__.py +4 -0
- models/mmseg/core/seg/sampler/base_pixel_sampler.py +13 -0
- models/mmseg/core/seg/sampler/ohem_pixel_sampler.py +76 -0
- models/mmseg/core/utils/__init__.py +3 -0
- models/mmseg/core/utils/misc.py +17 -0
configs/multi_mo_multi_task.yaml
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train_dataset:
|
| 2 |
+
dataset:
|
| 3 |
+
name: paired-image-folders-multi-task
|
| 4 |
+
args:
|
| 5 |
+
# root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image
|
| 6 |
+
# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
|
| 7 |
+
# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
|
| 8 |
+
# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
|
| 9 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
|
| 10 |
+
root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
|
| 11 |
+
# root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt
|
| 12 |
+
# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
|
| 13 |
+
# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
|
| 14 |
+
# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
|
| 15 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
|
| 16 |
+
root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
|
| 17 |
+
cache: nones
|
| 18 |
+
split_key: train
|
| 19 |
+
wrapper:
|
| 20 |
+
name: train_multi_task
|
| 21 |
+
args:
|
| 22 |
+
inp_size: 1024
|
| 23 |
+
augment: false
|
| 24 |
+
# batch_size: 2
|
| 25 |
+
batch_size: 2
|
| 26 |
+
|
| 27 |
+
val_dataset:
|
| 28 |
+
dataset:
|
| 29 |
+
name: paired-image-folders-multi-task
|
| 30 |
+
args:
|
| 31 |
+
# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
|
| 32 |
+
# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
|
| 33 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
|
| 34 |
+
root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
|
| 35 |
+
# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
|
| 36 |
+
# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
|
| 37 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
|
| 38 |
+
root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
|
| 39 |
+
cache: none
|
| 40 |
+
split_key: test
|
| 41 |
+
wrapper:
|
| 42 |
+
name: val_multi_task
|
| 43 |
+
args:
|
| 44 |
+
inp_size: 1024
|
| 45 |
+
# batch_size: 2
|
| 46 |
+
batch_size: 1
|
| 47 |
+
|
| 48 |
+
test_dataset:
|
| 49 |
+
dataset:
|
| 50 |
+
name: paired-image-folders
|
| 51 |
+
args:
|
| 52 |
+
|
| 53 |
+
# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
|
| 54 |
+
# root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image
|
| 55 |
+
# root_path_1: ./SAM_DATA_UNIFY3/SAR2020/split_image_ov500
|
| 56 |
+
# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
|
| 57 |
+
# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
|
| 58 |
+
# root_path_1: ./SAM_DATA_UNIFY4/GAOFEN/split_image
|
| 59 |
+
# root_path_1: ./SAM_DATA_UNIFY4/Vaihingen/image1
|
| 60 |
+
# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
|
| 61 |
+
# root_path_1: ./SAM_DATA_UNIFY4/Potsdam/image1
|
| 62 |
+
# root_path_1: ./SAM_DATA_UNIFY4/whu-opt-sar/image_sar
|
| 63 |
+
root_path_1: /workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/image
|
| 64 |
+
|
| 65 |
+
# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
|
| 66 |
+
# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1
|
| 67 |
+
# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2
|
| 68 |
+
# root_path_2: ./SAM_DATA_UNIFY3/SAR2020/gt_decoder2
|
| 69 |
+
# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
|
| 70 |
+
# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
|
| 71 |
+
# root_path_2: ./SAM_DATA_UNIFY4/GAOFEN/gt_decoder1_update
|
| 72 |
+
# root_path_2: ./SAM_DATA_UNIFY4/Vaihingen/gt2
|
| 73 |
+
# root_path_2: ./SAM_DATA_UNIFY4/Potsdam/gt1
|
| 74 |
+
# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
|
| 75 |
+
root_path_2: /workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/gt
|
| 76 |
+
# root_path_2: ./SAM_DATA_UNIFY4/whu-opt-sar/gt_sar
|
| 77 |
+
cache: none
|
| 78 |
+
split_key: test
|
| 79 |
+
wrapper:
|
| 80 |
+
name: val
|
| 81 |
+
args:
|
| 82 |
+
# inp_size: 1024
|
| 83 |
+
inp_size: 1024
|
| 84 |
+
batch_size: 1
|
| 85 |
+
|
| 86 |
+
#eval_type: cod
|
| 87 |
+
eval_type: f1
|
| 88 |
+
#sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth
|
| 89 |
+
sam_checkpoint: sam_vit_h_4b8939.pth
|
| 90 |
+
data_norm:
|
| 91 |
+
inp:
|
| 92 |
+
sub:
|
| 93 |
+
- 0.5
|
| 94 |
+
div:
|
| 95 |
+
- 0.5
|
| 96 |
+
gt:
|
| 97 |
+
sub:
|
| 98 |
+
- 0.5
|
| 99 |
+
div:
|
| 100 |
+
- 0.5
|
| 101 |
+
gt_rgb:
|
| 102 |
+
sub:
|
| 103 |
+
- 0.5
|
| 104 |
+
div:
|
| 105 |
+
- 0.5
|
| 106 |
+
model:
|
| 107 |
+
name: sam
|
| 108 |
+
args:
|
| 109 |
+
inp_size: 1024
|
| 110 |
+
# loss: iou
|
| 111 |
+
loss: cr
|
| 112 |
+
encoder_mode:
|
| 113 |
+
name: sam
|
| 114 |
+
img_size: 1024
|
| 115 |
+
mlp_ratio: 4
|
| 116 |
+
patch_size: 16
|
| 117 |
+
qkv_bias: true
|
| 118 |
+
use_rel_pos: true
|
| 119 |
+
window_size: 14
|
| 120 |
+
out_chans: 256
|
| 121 |
+
scale_factor: 32
|
| 122 |
+
input_type: fft
|
| 123 |
+
freq_nums: 0.25
|
| 124 |
+
prompt_type: highpass
|
| 125 |
+
prompt_embed_dim: 256
|
| 126 |
+
tuning_stage: 1234
|
| 127 |
+
handcrafted_tune: true
|
| 128 |
+
embedding_tune: true
|
| 129 |
+
adaptor: adaptor
|
| 130 |
+
embed_dim: 1280
|
| 131 |
+
depth: 32
|
| 132 |
+
num_heads: 16
|
| 133 |
+
global_attn_indexes:
|
| 134 |
+
- 7
|
| 135 |
+
- 15
|
| 136 |
+
- 23
|
| 137 |
+
- 31
|
| 138 |
+
optimizer:
|
| 139 |
+
name: adamw
|
| 140 |
+
args:
|
| 141 |
+
# lr: 0.0002
|
| 142 |
+
# lr: 0.00002
|
| 143 |
+
lr: 0.00008
|
| 144 |
+
lr_min: 1.0e-8
|
| 145 |
+
#epoch_max: 20
|
| 146 |
+
epoch_max: 100
|
| 147 |
+
|
| 148 |
+
multi_step_lr:
|
| 149 |
+
milestones:
|
| 150 |
+
- 1
|
| 151 |
+
gamma: 0.1
|
| 152 |
+
epoch_val: 100
|
| 153 |
+
epoch_save: 1
|
| 154 |
+
|
| 155 |
+
#resume: 60
|
| 156 |
+
#start_epoch: 60
|
configs/multi_mo_multi_task_sar_prompt.yaml
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train_dataset:
|
| 2 |
+
dataset:
|
| 3 |
+
name: paired-image-folders
|
| 4 |
+
args:
|
| 5 |
+
# root_path_1: ./ISAID/train/trainprompt/sub_images
|
| 6 |
+
# root_path_1: ./ISAID/train/trainprompt/images
|
| 7 |
+
root_path_1: ./SAR_prompt/image
|
| 8 |
+
# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
|
| 9 |
+
# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
|
| 10 |
+
# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
|
| 11 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
|
| 12 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
|
| 13 |
+
# root_path_2: ./ISAID/train/trainprompt/sub_gt
|
| 14 |
+
root_path_2: ./SAR_prompt/gt
|
| 15 |
+
# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
|
| 16 |
+
# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
|
| 17 |
+
# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
|
| 18 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
|
| 19 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
|
| 20 |
+
cache: none
|
| 21 |
+
split_key: train
|
| 22 |
+
wrapper:
|
| 23 |
+
name: train
|
| 24 |
+
args:
|
| 25 |
+
inp_size: 1024
|
| 26 |
+
augment: false
|
| 27 |
+
# batch_size: 2
|
| 28 |
+
batch_size: 1
|
| 29 |
+
|
| 30 |
+
val_dataset:
|
| 31 |
+
dataset:
|
| 32 |
+
name: paired-image-folders
|
| 33 |
+
args:
|
| 34 |
+
# root_path_1: ./ISAID/train/trainprompt/images
|
| 35 |
+
root_path_1: ./SAR_prompt/image
|
| 36 |
+
# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
|
| 37 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
|
| 38 |
+
# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
|
| 39 |
+
# root_path_2: ./ISAID/train/trainprompt/gt
|
| 40 |
+
root_path_2: ./SAR_prompt/gt
|
| 41 |
+
# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
|
| 42 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
|
| 43 |
+
# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
|
| 44 |
+
cache: none
|
| 45 |
+
split_key: test
|
| 46 |
+
wrapper:
|
| 47 |
+
name: val
|
| 48 |
+
args:
|
| 49 |
+
inp_size: 1024
|
| 50 |
+
# batch_size: 2
|
| 51 |
+
batch_size: 1
|
| 52 |
+
|
| 53 |
+
test_dataset:
|
| 54 |
+
dataset:
|
| 55 |
+
name: paired-image-folders
|
| 56 |
+
args:
|
| 57 |
+
# root_path_1: ./ISAID/train/trainprompt/images
|
| 58 |
+
# root_path_1: ./ISAID/train/trainprompt/sub_images
|
| 59 |
+
root_path_1: ./save/SAR_prompt/image
|
| 60 |
+
# root_path_1: ./SAM_DATA_UNIFY/Vaihingen/split_image
|
| 61 |
+
# root_path_1: ./SAM_DATA_UNIFY/SAR2020/split_image_ov500
|
| 62 |
+
# root_path_1: ./SAM_DATA_UNIFY/POLARIS_SAR/split_image
|
| 63 |
+
# root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image
|
| 64 |
+
# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
|
| 65 |
+
# root_path_1: ./SAM_DATA_UNIFY2/whu-sar-test/split_image
|
| 66 |
+
# root_path_1: ./SAM_DATA_UNIFY2/WHU-SAR/split_image
|
| 67 |
+
# root_path_1: ./SAM_DATA_UNIFY2/WHU_ALL/split_image
|
| 68 |
+
# root_path_1: ./SAM_DATA_UNIFY3/WHU_SAR/split_image
|
| 69 |
+
# root_path_1: ./SAM_DATA_UNIFY3/WHU_OPT/split_image
|
| 70 |
+
# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
|
| 71 |
+
# root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image
|
| 72 |
+
# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
|
| 73 |
+
|
| 74 |
+
# root_path_2: ./ISAID/train/trainprompt/gt
|
| 75 |
+
# root_path_2: ./ISAID/train/trainprompt/sub_gt
|
| 76 |
+
root_path_2: ./save/SAR_prompt/gt
|
| 77 |
+
# root_path_2: ./SAM_DATA_UNIFY/Vaihingen/split_gt
|
| 78 |
+
# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
|
| 79 |
+
# root_path_2: ./SAM_DATA_UNIFY/POLARIS_SAR/split_gt
|
| 80 |
+
# root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt
|
| 81 |
+
# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
|
| 82 |
+
# root_path_2: ./SAM_DATA_UNIFY2/whu-sar-test/split_gt
|
| 83 |
+
# root_path_2: ./SAM_DATA_UNIFY2/WHU-SAR/split_gt
|
| 84 |
+
# root_path_2: ./SAM_DATA_UNIFY2/WHU_ALL/split_gt
|
| 85 |
+
# root_path_2: ./SAM_DATA_UNIFY3/WHU_SAR/split_gt
|
| 86 |
+
# root_path_2: ./SAM_DATA_UNIFY3/WHU_OPT/split_gt
|
| 87 |
+
# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
|
| 88 |
+
# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1
|
| 89 |
+
# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2
|
| 90 |
+
# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
|
| 91 |
+
cache: none
|
| 92 |
+
split_key: test
|
| 93 |
+
wrapper:
|
| 94 |
+
name: val
|
| 95 |
+
args:
|
| 96 |
+
# inp_size: 1024
|
| 97 |
+
inp_size: 1024
|
| 98 |
+
batch_size: 1
|
| 99 |
+
|
| 100 |
+
#eval_type: cod
|
| 101 |
+
eval_type: f1
|
| 102 |
+
#sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth
|
| 103 |
+
#sam_checkpoint: sam_vit_h_4b8939.pth
|
| 104 |
+
sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth
|
| 105 |
+
#sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth
|
| 106 |
+
data_norm:
|
| 107 |
+
inp:
|
| 108 |
+
sub:
|
| 109 |
+
- 0.5
|
| 110 |
+
div:
|
| 111 |
+
- 0.5
|
| 112 |
+
gt:
|
| 113 |
+
sub:
|
| 114 |
+
- 0.5
|
| 115 |
+
div:
|
| 116 |
+
- 0.5
|
| 117 |
+
gt_rgb:
|
| 118 |
+
sub:
|
| 119 |
+
- 0.5
|
| 120 |
+
div:
|
| 121 |
+
- 0.5
|
| 122 |
+
model:
|
| 123 |
+
name: sam
|
| 124 |
+
args:
|
| 125 |
+
inp_size: 1024
|
| 126 |
+
# loss: iou
|
| 127 |
+
loss: cr
|
| 128 |
+
encoder_mode:
|
| 129 |
+
name: sam
|
| 130 |
+
img_size: 1024
|
| 131 |
+
mlp_ratio: 4
|
| 132 |
+
patch_size: 16
|
| 133 |
+
qkv_bias: true
|
| 134 |
+
use_rel_pos: true
|
| 135 |
+
window_size: 14
|
| 136 |
+
out_chans: 256
|
| 137 |
+
scale_factor: 32
|
| 138 |
+
input_type: fft
|
| 139 |
+
freq_nums: 0.25
|
| 140 |
+
prompt_type: highpass
|
| 141 |
+
prompt_embed_dim: 256
|
| 142 |
+
tuning_stage: 1234
|
| 143 |
+
handcrafted_tune: true
|
| 144 |
+
embedding_tune: true
|
| 145 |
+
adaptor: adaptor
|
| 146 |
+
embed_dim: 1280
|
| 147 |
+
depth: 32
|
| 148 |
+
num_heads: 16
|
| 149 |
+
global_attn_indexes:
|
| 150 |
+
- 7
|
| 151 |
+
- 15
|
| 152 |
+
- 23
|
| 153 |
+
- 31
|
| 154 |
+
optimizer:
|
| 155 |
+
name: adamw
|
| 156 |
+
args:
|
| 157 |
+
# lr: 0.0002
|
| 158 |
+
# lr: 0.00002
|
| 159 |
+
# lr: 0.00004
|
| 160 |
+
# lr: 0.00008
|
| 161 |
+
lr: 0.0002
|
| 162 |
+
lr_min: 1.0e-8
|
| 163 |
+
#epoch_max: 20
|
| 164 |
+
epoch_max: 200
|
| 165 |
+
|
| 166 |
+
multi_step_lr:
|
| 167 |
+
milestones:
|
| 168 |
+
- 1
|
| 169 |
+
gamma: 0.1
|
| 170 |
+
epoch_val: 200
|
| 171 |
+
epoch_save: 1
|
| 172 |
+
|
| 173 |
+
#resume: 60
|
| 174 |
+
#start_epoch: 60
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .datasets import register, make
|
| 2 |
+
from . import image_folder
|
| 3 |
+
from . import wrappers
|
datasets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (261 Bytes). View file
|
|
|
datasets/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (255 Bytes). View file
|
|
|
datasets/__pycache__/datasets.cpython-310.pyc
ADDED
|
Binary file (683 Bytes). View file
|
|
|
datasets/__pycache__/datasets.cpython-37.pyc
ADDED
|
Binary file (656 Bytes). View file
|
|
|
datasets/__pycache__/image_folder.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
datasets/__pycache__/image_folder.cpython-37.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
datasets/__pycache__/wrappers.cpython-310.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
datasets/__pycache__/wrappers.cpython-37.pyc
ADDED
|
Binary file (5.45 kB). View file
|
|
|
datasets/data_loader_multi_tasks.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
def build_loader_simmim(config):
|
| 3 |
+
############ single model #####################
|
| 4 |
+
# transform = SimMIMTransform(config)
|
| 5 |
+
# dataset = ImageFolder(config.DATA.DATA_PATH, transform)
|
| 6 |
+
# sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
|
| 7 |
+
# dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
| 8 |
+
|
| 9 |
+
############## multi model ####################
|
| 10 |
+
datasets = []
|
| 11 |
+
### 数据增强 ######
|
| 12 |
+
model_paths = config.DATA.TYPE_PATH[0]
|
| 13 |
+
for i in model_paths.keys():
|
| 14 |
+
a = config.DATA.SCALE[0][i].split(',')
|
| 15 |
+
scale_model = (float(a[0].split('(')[1]) ,float(a[1].split(')')[0]))
|
| 16 |
+
transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
|
| 17 |
+
dataset = CachedImageFolder(model_paths[i], transform = transform, model = i)
|
| 18 |
+
datasets.append(dataset)
|
| 19 |
+
multi_task_train_dataset = MultiTaskDataset(datasets)
|
| 20 |
+
print(len(datasets))
|
| 21 |
+
multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True ,shuffle =True)
|
| 22 |
+
dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
|
| 23 |
+
# dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
|
| 24 |
+
print(len(dataloader))
|
| 25 |
+
|
| 26 |
+
return dataloader
|
datasets/data_simmim_pt.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# SimMIM
|
| 3 |
+
# Copyright (c) 2021 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Written by Zhenda Xie
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import random
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 16 |
+
from torch.utils.data._utils.collate import default_collate
|
| 17 |
+
from torchvision.datasets import ImageFolder
|
| 18 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 19 |
+
from torch.utils.data import Dataset, BatchSampler
|
| 20 |
+
from torchvision.io import read_image
|
| 21 |
+
from .cached_image_folder import CachedImageFolder
|
| 22 |
+
|
| 23 |
+
class MultiTaskDataset(Dataset):
|
| 24 |
+
"""
|
| 25 |
+
useage example:
|
| 26 |
+
train_datasets = [SemData_Single(), SemData_Single()]
|
| 27 |
+
multi_task_train_dataset = MultiTaskDataset(train_datasets)
|
| 28 |
+
multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
|
| 29 |
+
multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
|
| 30 |
+
for i, (task_id, input, target) in enumerate(multi_task_train_data):
|
| 31 |
+
pre = model(input)
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, datasets):
|
| 34 |
+
self._datasets = datasets
|
| 35 |
+
task_id_2_data_set_dic = {}
|
| 36 |
+
for i, dataset in enumerate(datasets):
|
| 37 |
+
task_id = i
|
| 38 |
+
assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
|
| 39 |
+
task_id_2_data_set_dic[task_id] = dataset
|
| 40 |
+
|
| 41 |
+
self._task_id_2_data_set_dic = task_id_2_data_set_dic
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
return sum(len(dataset) for dataset in self._datasets)
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, idx):
|
| 47 |
+
task_id, sample_id = idx
|
| 48 |
+
return self._task_id_2_data_set_dic[task_id][sample_id]
|
| 49 |
+
|
| 50 |
+
class DistrubutedMultiTaskBatchSampler(BatchSampler):
|
| 51 |
+
"""
|
| 52 |
+
datasets: class the class of the Dataset
|
| 53 |
+
batch_size: int
|
| 54 |
+
mix_opt: int mix_opt ==0 shuffle all_task; mix_opt ==1 shuffle extra_task
|
| 55 |
+
extra_task_ratio(float, optional): the rate between task one and extra task
|
| 56 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
| 57 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
| 58 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
| 59 |
+
will be smaller. (default: ``True``)
|
| 60 |
+
"""
|
| 61 |
+
def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,shuffle = True):
|
| 62 |
+
if num_replicas is None:
|
| 63 |
+
if not dist.is_available():
|
| 64 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 65 |
+
num_replicas = dist.get_world_size()
|
| 66 |
+
if rank is None:
|
| 67 |
+
if not dist.is_available():
|
| 68 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 69 |
+
rank = dist.get_rank()
|
| 70 |
+
if rank >= num_replicas or rank < 0:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
"Invalid rank {}, rank should be in the interval"
|
| 73 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
| 74 |
+
self.num_replicas = num_replicas
|
| 75 |
+
self.rank = rank
|
| 76 |
+
self.epoch = 0
|
| 77 |
+
assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
|
| 78 |
+
assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0'
|
| 79 |
+
self._datasets = datasets
|
| 80 |
+
self._batch_size = batch_size
|
| 81 |
+
self._mix_opt = mix_opt
|
| 82 |
+
self._extra_task_ratio = extra_task_ratio
|
| 83 |
+
self._drop_last = drop_last
|
| 84 |
+
train_data_list = []
|
| 85 |
+
self.shuffle = shuffle
|
| 86 |
+
for dataset in datasets:
|
| 87 |
+
print(len(dataset))
|
| 88 |
+
train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
|
| 89 |
+
|
| 90 |
+
######### 一个列表里存n个dataset的数据,数据也以列表形式存在,一个dataset的列表里面把数据划分成了不同的batch的index
|
| 91 |
+
self._train_data_list = train_data_list
|
| 92 |
+
self.total_len = sum(len(train_data) for train_data in self._train_data_list)
|
| 93 |
+
|
| 94 |
+
######### DDP ######################
|
| 95 |
+
if self._drop_last and self.total_len % self.num_replicas != 0: # type: ignore[arg-type]
|
| 96 |
+
self.num_samples = math.ceil(
|
| 97 |
+
(self.total_len - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
| 98 |
+
)
|
| 99 |
+
else:
|
| 100 |
+
self.num_samples = math.ceil(self.total_len / self.num_replicas) # type: ignore[arg-type]
|
| 101 |
+
|
| 102 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 103 |
+
self.epoch = 0
|
| 104 |
+
self.seed = 0
|
| 105 |
+
|
| 106 |
+
def set_epoch(self, epoch):
|
| 107 |
+
self.epoch = epoch
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def _get_index_batches(dataset_len, batch_size, drop_last):
|
| 111 |
+
# index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
|
| 112 |
+
index = list(range(dataset_len))
|
| 113 |
+
if drop_last and dataset_len % batch_size:
|
| 114 |
+
del index[-(dataset_len % batch_size):]
|
| 115 |
+
index_batches = [index[i:i+batch_size] for i in range(0, len(index), batch_size)]
|
| 116 |
+
return index_batches
|
| 117 |
+
|
| 118 |
+
def __len__(self):
|
| 119 |
+
# return sum(len(train_data) for train_data in self._train_data_list)
|
| 120 |
+
return self.num_samples
|
| 121 |
+
|
| 122 |
+
def __iter__(self):
|
| 123 |
+
all_iters = [iter(item) for item in self._train_data_list]
|
| 124 |
+
all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
|
| 125 |
+
|
| 126 |
+
######### DDP ######################
|
| 127 |
+
random.shuffle(all_indices)
|
| 128 |
+
all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
|
| 129 |
+
assert len(all_indices) == self.num_samples
|
| 130 |
+
|
| 131 |
+
for local_task_idx in all_indices:
|
| 132 |
+
# task_id = self._datasets[local_task_idx].get_task_id()
|
| 133 |
+
batch = next(all_iters[local_task_idx])
|
| 134 |
+
# batch = batch[self.rank:len(batch):self.num_replicas]
|
| 135 |
+
# print(local_task_idx)
|
| 136 |
+
yield [(local_task_idx, sample_id) for sample_id in batch]
|
| 137 |
+
# yield iter(batch)
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
|
| 141 |
+
|
| 142 |
+
########## accoding to the number of models ###########
|
| 143 |
+
all_indices = []
|
| 144 |
+
for i in range(len(train_data_list)):
|
| 145 |
+
all_indices += [i] * len(train_data_list[i])
|
| 146 |
+
# print(all_indices)
|
| 147 |
+
return all_indices
|
| 148 |
+
# def set_epoch(self, epoch)
|
| 149 |
+
# r"""
|
| 150 |
+
# Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
| 151 |
+
# use a different random ordering for each epoch. Otherwise, the next iteration of this
|
| 152 |
+
# sampler will yield the same ordering.
|
| 153 |
+
|
| 154 |
+
# Args:
|
| 155 |
+
# epoch (int): Epoch number.
|
| 156 |
+
# """
|
| 157 |
+
# self.epoch = epoch
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MaskGenerator:
|
| 161 |
+
def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
|
| 162 |
+
self.input_size = input_size
|
| 163 |
+
self.mask_patch_size = mask_patch_size
|
| 164 |
+
self.model_patch_size = model_patch_size
|
| 165 |
+
self.mask_ratio = mask_ratio
|
| 166 |
+
|
| 167 |
+
assert self.input_size % self.mask_patch_size == 0
|
| 168 |
+
assert self.mask_patch_size % self.model_patch_size == 0
|
| 169 |
+
|
| 170 |
+
self.rand_size = self.input_size // self.mask_patch_size
|
| 171 |
+
self.scale = self.mask_patch_size // self.model_patch_size
|
| 172 |
+
|
| 173 |
+
self.token_count = self.rand_size ** 2
|
| 174 |
+
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
|
| 175 |
+
|
| 176 |
+
def __call__(self):
|
| 177 |
+
mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
|
| 178 |
+
mask = np.zeros(self.token_count, dtype=int)
|
| 179 |
+
mask[mask_idx] = 1
|
| 180 |
+
|
| 181 |
+
mask = mask.reshape((self.rand_size, self.rand_size))
|
| 182 |
+
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
|
| 183 |
+
|
| 184 |
+
return mask
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class ZeroOneNormalize(object):
|
| 188 |
+
def __call__(self, img):
|
| 189 |
+
return img.float().div(255)
|
| 190 |
+
|
| 191 |
+
class SimMIMTransform:
|
| 192 |
+
def __init__(self, config, NORM, SCALE):
|
| 193 |
+
self.transform_img = T.Compose([
|
| 194 |
+
# T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 195 |
+
# T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
|
| 196 |
+
# T.RandomHorizontalFlip(),
|
| 197 |
+
# T.ToTensor(),
|
| 198 |
+
# T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
|
| 199 |
+
|
| 200 |
+
T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=SCALE, ratio=(3. / 4., 4. / 3.)),
|
| 201 |
+
T.RandomHorizontalFlip(),
|
| 202 |
+
ZeroOneNormalize(),
|
| 203 |
+
T.Normalize(mean=torch.tensor(NORM[0]),std=torch.tensor(NORM[1])),
|
| 204 |
+
])
|
| 205 |
+
|
| 206 |
+
if config.MODEL.TYPE in ['swin', 'swinv2']:
|
| 207 |
+
model_patch_size=config.MODEL.SWIN.PATCH_SIZE
|
| 208 |
+
else:
|
| 209 |
+
raise NotImplementedError
|
| 210 |
+
|
| 211 |
+
self.mask_generator = MaskGenerator(
|
| 212 |
+
input_size=config.DATA.IMG_SIZE,
|
| 213 |
+
mask_patch_size=config.DATA.MASK_PATCH_SIZE,
|
| 214 |
+
model_patch_size=model_patch_size,
|
| 215 |
+
mask_ratio=config.DATA.MASK_RATIO,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def __call__(self, img):
|
| 219 |
+
img = self.transform_img(img)
|
| 220 |
+
mask = self.mask_generator()
|
| 221 |
+
|
| 222 |
+
return img, mask
|
| 223 |
+
|
| 224 |
+
def collate_fn(batch):
|
| 225 |
+
# print(len(batch))
|
| 226 |
+
# print('*'*10)
|
| 227 |
+
# print(batch[0][0])
|
| 228 |
+
# print('#'*10)
|
| 229 |
+
# print(batch[0][1])
|
| 230 |
+
# batch = list(filter(lambda x: x[0][0] is not None, batch))
|
| 231 |
+
# if len(batch) == 0: return torch.Tensor()
|
| 232 |
+
|
| 233 |
+
if not isinstance(batch[0][0], tuple):
|
| 234 |
+
return default_collate(batch)
|
| 235 |
+
else:
|
| 236 |
+
batch_num = len(batch)
|
| 237 |
+
ret = []
|
| 238 |
+
for item_idx in range(len(batch[0][0])):
|
| 239 |
+
if batch[0][0][item_idx] is None:
|
| 240 |
+
ret.append(None)
|
| 241 |
+
else:
|
| 242 |
+
ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
|
| 243 |
+
ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
|
| 244 |
+
return ret
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def build_loader_simmim(config):
|
| 248 |
+
############ single model #####################
|
| 249 |
+
# transform = SimMIMTransform(config)
|
| 250 |
+
# dataset = ImageFolder(config.DATA.DATA_PATH, transform)
|
| 251 |
+
# sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
|
| 252 |
+
# dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
| 253 |
+
|
| 254 |
+
############## multi model ####################
|
| 255 |
+
datasets = []
|
| 256 |
+
### 数据增强 ######
|
| 257 |
+
model_paths = config.DATA.TYPE_PATH[0]
|
| 258 |
+
for i in model_paths.keys():
|
| 259 |
+
a = config.DATA.SCALE[0][i].split(',')
|
| 260 |
+
scale_model = (float(a[0].split('(')[1]),float(a[1].split(')')[0]))
|
| 261 |
+
transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
|
| 262 |
+
dataset = CachedImageFolder(model_paths[i], transform = transform, model = i)
|
| 263 |
+
datasets.append(dataset)
|
| 264 |
+
multi_task_train_dataset = MultiTaskDataset(datasets)
|
| 265 |
+
print(len(datasets))
|
| 266 |
+
multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True,shuffle =True)
|
| 267 |
+
dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
|
| 268 |
+
# dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
|
| 269 |
+
print(len(dataloader))
|
| 270 |
+
|
| 271 |
+
return dataloader
|
datasets/datasets.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
datasets = {}
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def register(name):
|
| 8 |
+
def decorator(cls):
|
| 9 |
+
datasets[name] = cls
|
| 10 |
+
return cls
|
| 11 |
+
return decorator
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make(dataset_spec, args=None):
|
| 15 |
+
if args is not None:
|
| 16 |
+
dataset_args = copy.deepcopy(dataset_spec['args'])
|
| 17 |
+
dataset_args.update(args)
|
| 18 |
+
else:
|
| 19 |
+
dataset_args = dataset_spec['args']
|
| 20 |
+
dataset = datasets[dataset_spec['name']](**dataset_args)
|
| 21 |
+
return dataset
|
datasets/image_folder.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
import pickle
|
| 6 |
+
import imageio
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
import random
|
| 12 |
+
from datasets import register
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from torch.utils.data import BatchSampler
|
| 17 |
+
|
| 18 |
+
from torch.utils.data._utils.collate import default_collate
|
| 19 |
+
|
| 20 |
+
@register('image-folder')
|
| 21 |
+
class ImageFolder(Dataset):
|
| 22 |
+
def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
|
| 23 |
+
repeat=1, cache='none', mask=False):
|
| 24 |
+
self.repeat = repeat
|
| 25 |
+
self.cache = cache
|
| 26 |
+
self.path = path
|
| 27 |
+
self.Train = False
|
| 28 |
+
self.split_key = split_key
|
| 29 |
+
|
| 30 |
+
self.size = size
|
| 31 |
+
self.mask = mask
|
| 32 |
+
if self.mask:
|
| 33 |
+
self.img_transform = transforms.Compose([
|
| 34 |
+
transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
])
|
| 37 |
+
else:
|
| 38 |
+
self.img_transform = transforms.Compose([
|
| 39 |
+
transforms.Resize((self.size, self.size)),
|
| 40 |
+
transforms.ToTensor(),
|
| 41 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 42 |
+
std=[0.229, 0.224, 0.225])
|
| 43 |
+
])
|
| 44 |
+
|
| 45 |
+
if split_file is None:
|
| 46 |
+
filenames = sorted(os.listdir(path))
|
| 47 |
+
else:
|
| 48 |
+
with open(split_file, 'r') as f:
|
| 49 |
+
filenames = json.load(f)[split_key]
|
| 50 |
+
if first_k is not None:
|
| 51 |
+
filenames = filenames[:first_k]
|
| 52 |
+
|
| 53 |
+
self.files = []
|
| 54 |
+
|
| 55 |
+
for filename in filenames:
|
| 56 |
+
file = os.path.join(path, filename)
|
| 57 |
+
self.append_file(file)
|
| 58 |
+
|
| 59 |
+
def append_file(self, file):
|
| 60 |
+
if self.cache == 'none':
|
| 61 |
+
self.files.append(file)
|
| 62 |
+
elif self.cache == 'in_memory':
|
| 63 |
+
self.files.append(self.img_process(file))
|
| 64 |
+
|
| 65 |
+
def __len__(self):
|
| 66 |
+
return len(self.files) * self.repeat
|
| 67 |
+
|
| 68 |
+
def __getitem__(self, idx):
|
| 69 |
+
x = self.files[idx % len(self.files)]
|
| 70 |
+
|
| 71 |
+
if self.cache == 'none':
|
| 72 |
+
return self.img_process(x)
|
| 73 |
+
elif self.cache == 'in_memory':
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
def img_process(self, file):
|
| 77 |
+
if self.mask:
|
| 78 |
+
# return Image.open(file).convert('L')
|
| 79 |
+
return file
|
| 80 |
+
else:
|
| 81 |
+
return Image.open(file).convert('RGB')
|
| 82 |
+
|
| 83 |
+
@register('paired-image-folders')
|
| 84 |
+
class PairedImageFolders(Dataset):
|
| 85 |
+
|
| 86 |
+
def __init__(self, root_path_1, root_path_2, **kwargs):
|
| 87 |
+
self.dataset_1 = ImageFolder(root_path_1, **kwargs)
|
| 88 |
+
self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True)
|
| 89 |
+
|
| 90 |
+
def __len__(self):
|
| 91 |
+
return len(self.dataset_1)
|
| 92 |
+
|
| 93 |
+
def __getitem__(self, idx):
|
| 94 |
+
return self.dataset_1[idx], self.dataset_2[idx]
|
| 95 |
+
|
| 96 |
+
class ImageFolder_multi_task(Dataset):
|
| 97 |
+
def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
|
| 98 |
+
repeat=1, cache='none', mask=False):
|
| 99 |
+
self.repeat = repeat
|
| 100 |
+
self.cache = cache
|
| 101 |
+
self.path = path
|
| 102 |
+
self.Train = False
|
| 103 |
+
self.split_key = split_key
|
| 104 |
+
|
| 105 |
+
self.size = size
|
| 106 |
+
self.mask = mask
|
| 107 |
+
if self.mask:
|
| 108 |
+
self.img_transform = transforms.Compose([
|
| 109 |
+
transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
|
| 110 |
+
transforms.ToTensor(),
|
| 111 |
+
])
|
| 112 |
+
else:
|
| 113 |
+
self.img_transform = transforms.Compose([
|
| 114 |
+
transforms.Resize((self.size, self.size)),
|
| 115 |
+
transforms.ToTensor(),
|
| 116 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 117 |
+
std=[0.229, 0.224, 0.225])
|
| 118 |
+
])
|
| 119 |
+
|
| 120 |
+
if split_file is None:
|
| 121 |
+
filenames = sorted(os.listdir(path))
|
| 122 |
+
else:
|
| 123 |
+
with open(split_file, 'r') as f:
|
| 124 |
+
filenames = json.load(f)[split_key]
|
| 125 |
+
if first_k is not None:
|
| 126 |
+
filenames = filenames[:first_k]
|
| 127 |
+
|
| 128 |
+
self.files = []
|
| 129 |
+
|
| 130 |
+
for filename in filenames:
|
| 131 |
+
file = os.path.join(path, filename)
|
| 132 |
+
self.append_file(file)
|
| 133 |
+
|
| 134 |
+
def append_file(self, file):
|
| 135 |
+
if self.cache == 'none':
|
| 136 |
+
self.files.append(file)
|
| 137 |
+
elif self.cache == 'in_memory':
|
| 138 |
+
self.files.append(self.img_process(file))
|
| 139 |
+
|
| 140 |
+
def __len__(self):
|
| 141 |
+
return len(self.files) * self.repeat
|
| 142 |
+
|
| 143 |
+
def __getitem__(self, idx):
|
| 144 |
+
x = self.files[idx % len(self.files)]
|
| 145 |
+
|
| 146 |
+
if self.cache == 'none':
|
| 147 |
+
return self.img_process(x)
|
| 148 |
+
elif self.cache == 'in_memory':
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
def img_process(self, file):
|
| 152 |
+
if self.mask:
|
| 153 |
+
# return Image.open(file).convert('L')
|
| 154 |
+
return file
|
| 155 |
+
else:
|
| 156 |
+
return Image.open(file).convert('RGB')
|
| 157 |
+
|
| 158 |
+
@register('paired-image-folders-multi-task')
|
| 159 |
+
class PairedImageFolders_multi_task(Dataset):
|
| 160 |
+
|
| 161 |
+
def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
|
| 162 |
+
|
| 163 |
+
self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs)
|
| 164 |
+
self.dataset_2 = ImageFolder_multi_task(root_path_2, **kwargs, mask=True)
|
| 165 |
+
|
| 166 |
+
def __len__(self):
|
| 167 |
+
return len(self.dataset_1)
|
| 168 |
+
|
| 169 |
+
def __getitem__(self, idx):
|
| 170 |
+
return self.dataset_1[idx], self.dataset_2[idx]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# class MultiTaskDataset(Dataset):
|
| 176 |
+
# """
|
| 177 |
+
# useage example:
|
| 178 |
+
# train_datasets = [SemData_Single(), SemData_Single()]
|
| 179 |
+
# multi_task_train_dataset = MultiTaskDataset(train_datasets)
|
| 180 |
+
# multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
|
| 181 |
+
# multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
|
| 182 |
+
# for i, (task_id, input, target) in enumerate(multi_task_train_data):
|
| 183 |
+
# pre = model(input)
|
| 184 |
+
# """
|
| 185 |
+
# def __init__(self, datasets_image, datasets_gt):
|
| 186 |
+
# self._datasets = datasets_image
|
| 187 |
+
# task_id_2_image_set_dic = {}
|
| 188 |
+
# for i, dataset in enumerate(datasets_image):
|
| 189 |
+
# task_id = i
|
| 190 |
+
# assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id
|
| 191 |
+
# task_id_2_image_set_dic[task_id] = dataset
|
| 192 |
+
# self.datasets_1 = task_id_2_image_set_dic
|
| 193 |
+
#
|
| 194 |
+
# task_id_2_gt_set_dic = {}
|
| 195 |
+
# for i, dataset in enumerate(datasets_gt):
|
| 196 |
+
# task_id = i
|
| 197 |
+
# assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id
|
| 198 |
+
# task_id_2_gt_set_dic[task_id] = dataset
|
| 199 |
+
# self.dataset_2 = task_id_2_gt_set_dic
|
| 200 |
+
#
|
| 201 |
+
#
|
| 202 |
+
# def __len__(self):
|
| 203 |
+
# return sum(len(dataset) for dataset in self._datasets)
|
| 204 |
+
#
|
| 205 |
+
# def __getitem__(self, idx):
|
| 206 |
+
# task_id, sample_id = idx
|
| 207 |
+
# # return self._task_id_2_data_set_dic[task_id][sample_id]
|
| 208 |
+
# return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id]
|
| 209 |
+
|
| 210 |
+
class MultiTaskDataset(Dataset):
|
| 211 |
+
"""
|
| 212 |
+
useage example:
|
| 213 |
+
train_datasets = [SemData_Single(), SemData_Single()]
|
| 214 |
+
multi_task_train_dataset = MultiTaskDataset(train_datasets)
|
| 215 |
+
multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
|
| 216 |
+
multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
|
| 217 |
+
for i, (task_id, input, target) in enumerate(multi_task_train_data):
|
| 218 |
+
pre = model(input)
|
| 219 |
+
"""
|
| 220 |
+
def __init__(self, datasets):
|
| 221 |
+
self._datasets = datasets
|
| 222 |
+
task_id_2_data_set_dic = {}
|
| 223 |
+
for i, dataset in enumerate(datasets):
|
| 224 |
+
task_id = i
|
| 225 |
+
assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
|
| 226 |
+
task_id_2_data_set_dic[task_id] = dataset
|
| 227 |
+
|
| 228 |
+
self._task_id_2_data_set_dic = task_id_2_data_set_dic
|
| 229 |
+
|
| 230 |
+
def __len__(self):
|
| 231 |
+
return sum(len(dataset) for dataset in self._datasets)
|
| 232 |
+
|
| 233 |
+
def __getitem__(self, idx):
|
| 234 |
+
task_id, sample_id = idx
|
| 235 |
+
# print('----', idx, task_id, sample_id)
|
| 236 |
+
return self._task_id_2_data_set_dic[task_id][sample_id]
|
| 237 |
+
|
| 238 |
+
def collate_fn(batch):
|
| 239 |
+
# print(len(batch))
|
| 240 |
+
# print('*'*10)
|
| 241 |
+
# print(batch[0][0])
|
| 242 |
+
# print('#'*10)
|
| 243 |
+
# print(batch[0][1])
|
| 244 |
+
# batch = list(filter(lambda x: x[0][0] is not None, batch))
|
| 245 |
+
# if len(batch) == 0: return torch.Tensor()
|
| 246 |
+
print('******------',batch)
|
| 247 |
+
if not isinstance(batch[0][0], tuple):
|
| 248 |
+
return default_collate(batch)
|
| 249 |
+
else:
|
| 250 |
+
batch_num = len(batch)
|
| 251 |
+
ret = []
|
| 252 |
+
for item_idx in range(len(batch[0][0])):
|
| 253 |
+
if batch[0][0][item_idx] is None:
|
| 254 |
+
ret.append(None)
|
| 255 |
+
else:
|
| 256 |
+
ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
|
| 257 |
+
ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
|
| 258 |
+
return ret
|
| 259 |
+
|
| 260 |
+
class DistrubutedMultiTaskBatchSampler(BatchSampler):
|
| 261 |
+
"""
|
| 262 |
+
datasets: class the class of the Dataset
|
| 263 |
+
batch_size: int
|
| 264 |
+
mix_opt: int mix_opt ==0 shuffle all_task; mix_opt ==1 shuffle extra_task
|
| 265 |
+
extra_task_ratio(float, optional): the rate between task one and extra task
|
| 266 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
| 267 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
| 268 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
| 269 |
+
will be smaller. (default: ``True``)
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,
|
| 273 |
+
shuffle=True):
|
| 274 |
+
if num_replicas is None:
|
| 275 |
+
if not dist.is_available():
|
| 276 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 277 |
+
num_replicas = dist.get_world_size()
|
| 278 |
+
if rank is None:
|
| 279 |
+
if not dist.is_available():
|
| 280 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 281 |
+
rank = dist.get_rank()
|
| 282 |
+
if rank >= num_replicas or rank < 0:
|
| 283 |
+
raise ValueError(
|
| 284 |
+
"Invalid rank {}, rank should be in the interval"
|
| 285 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
| 286 |
+
self.num_replicas = num_replicas
|
| 287 |
+
self.rank = rank
|
| 288 |
+
self.epoch = 0
|
| 289 |
+
assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
|
| 290 |
+
assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0'
|
| 291 |
+
# self._datasets = datasets
|
| 292 |
+
self._batch_size = batch_size
|
| 293 |
+
self._mix_opt = mix_opt
|
| 294 |
+
self._extra_task_ratio = extra_task_ratio
|
| 295 |
+
self._drop_last = drop_last
|
| 296 |
+
train_data_list = []
|
| 297 |
+
self.shuffle = shuffle
|
| 298 |
+
for dataset in datasets:
|
| 299 |
+
print(len(dataset))
|
| 300 |
+
train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
|
| 301 |
+
|
| 302 |
+
######### 一个列表里存n个dataset的数据,数据也以列表形式存在,一个dataset的列表里面把数据划分成了不同的batch的index
|
| 303 |
+
self._train_data_list = train_data_list
|
| 304 |
+
self.total_len = sum(len(train_data) for train_data in self._train_data_list)
|
| 305 |
+
|
| 306 |
+
######### DDP ######################
|
| 307 |
+
if self._drop_last and self.total_len % self.num_replicas != 0: # type: ignore[arg-type]
|
| 308 |
+
self.num_samples = math.ceil(
|
| 309 |
+
(self.total_len - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
| 310 |
+
)
|
| 311 |
+
else:
|
| 312 |
+
self.num_samples = math.ceil(self.total_len / self.num_replicas) # type: ignore[arg-type]
|
| 313 |
+
|
| 314 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 315 |
+
self.epoch = 0
|
| 316 |
+
self.seed = 0
|
| 317 |
+
|
| 318 |
+
def set_epoch(self, epoch):
|
| 319 |
+
# print('&&&&****')
|
| 320 |
+
self.epoch = epoch
|
| 321 |
+
|
| 322 |
+
@staticmethod
|
| 323 |
+
def _get_index_batches(dataset_len, batch_size, drop_last):
|
| 324 |
+
# index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
|
| 325 |
+
index = list(range(dataset_len))
|
| 326 |
+
if drop_last and dataset_len % batch_size:
|
| 327 |
+
del index[-(dataset_len % batch_size):]
|
| 328 |
+
index_batches = [index[i:i + batch_size] for i in range(0, len(index), batch_size)]
|
| 329 |
+
return index_batches
|
| 330 |
+
|
| 331 |
+
def __len__(self):
|
| 332 |
+
# return sum(len(train_data) for train_data in self._train_data_list)
|
| 333 |
+
return self.num_samples
|
| 334 |
+
|
| 335 |
+
def __iter__(self):
|
| 336 |
+
all_iters = [iter(item) for item in self._train_data_list]
|
| 337 |
+
all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
|
| 338 |
+
|
| 339 |
+
######### DDP ######################
|
| 340 |
+
random.shuffle(all_indices)
|
| 341 |
+
all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
|
| 342 |
+
assert len(all_indices) == self.num_samples
|
| 343 |
+
|
| 344 |
+
for local_task_idx in all_indices:
|
| 345 |
+
# task_id = self._datasets[local_task_idx].get_task_id()
|
| 346 |
+
batch = next(all_iters[local_task_idx])
|
| 347 |
+
# batch = batch[self.rank:len(batch):self.num_replicas]
|
| 348 |
+
# print(local_task_idx)
|
| 349 |
+
yield [(local_task_idx, sample_id) for sample_id in batch]
|
| 350 |
+
# yield iter(batch)
|
| 351 |
+
|
| 352 |
+
@staticmethod
|
| 353 |
+
def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
|
| 354 |
+
|
| 355 |
+
########## accoding to the number of models ###########
|
| 356 |
+
all_indices = []
|
| 357 |
+
for i in range(len(train_data_list)):
|
| 358 |
+
all_indices += [i] * len(train_data_list[i])
|
| 359 |
+
# print(all_indices)
|
| 360 |
+
return all_indices
|
| 361 |
+
# def set_epoch(self, epoch)
|
| 362 |
+
# r"""
|
| 363 |
+
# Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
| 364 |
+
# use a different random ordering for each epoch. Otherwise, the next iteration of this
|
| 365 |
+
# sampler will yield the same ordering.
|
| 366 |
+
|
| 367 |
+
# Args:
|
| 368 |
+
# epoch (int): Epoch number.
|
| 369 |
+
# """
|
| 370 |
+
# self.epoch = epoch
|
datasets/wrappers.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import functools
|
| 3 |
+
import random
|
| 4 |
+
import math
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
import torchvision
|
| 13 |
+
|
| 14 |
+
from datasets import register
|
| 15 |
+
import cv2
|
| 16 |
+
from math import pi
|
| 17 |
+
from torchvision.transforms import InterpolationMode
|
| 18 |
+
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
def to_mask(mask):
|
| 21 |
+
return transforms.ToTensor()(
|
| 22 |
+
transforms.Grayscale(num_output_channels=1)(
|
| 23 |
+
transforms.ToPILImage()(mask)))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def resize_fn(img, size):
|
| 27 |
+
return transforms.ToTensor()(
|
| 28 |
+
transforms.Resize(size)(
|
| 29 |
+
transforms.ToPILImage()(img)))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@register('val')
|
| 33 |
+
class ValDataset(Dataset):
|
| 34 |
+
def __init__(self, dataset, inp_size=None, augment=False):
|
| 35 |
+
self.dataset = dataset
|
| 36 |
+
self.inp_size = inp_size
|
| 37 |
+
self.augment = augment
|
| 38 |
+
|
| 39 |
+
self.img_transform = transforms.Compose([
|
| 40 |
+
# transforms.Resize((inp_size, inp_size)),
|
| 41 |
+
transforms.ToTensor(),
|
| 42 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 43 |
+
std=[0.229, 0.224, 0.225])
|
| 44 |
+
])
|
| 45 |
+
self.mask_transform = transforms.Compose([
|
| 46 |
+
transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
|
| 47 |
+
transforms.ToTensor(),
|
| 48 |
+
])
|
| 49 |
+
|
| 50 |
+
def __len__(self):
|
| 51 |
+
return len(self.dataset)
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, idx):
|
| 54 |
+
img, mask = self.dataset[idx]
|
| 55 |
+
mask_name = mask
|
| 56 |
+
a = self.img_transform(img)
|
| 57 |
+
# b = self.mask_transform(mask)
|
| 58 |
+
|
| 59 |
+
# print(idx, mask.filename)
|
| 60 |
+
# b = cv2.imread(mask.filename,cv2.IMREAD_UNCHANGED)
|
| 61 |
+
b = cv2.imread(mask,cv2.IMREAD_UNCHANGED)
|
| 62 |
+
return {
|
| 63 |
+
'inp': self.img_transform(img),
|
| 64 |
+
'gt': torch.tensor(b),
|
| 65 |
+
'name': mask_name,
|
| 66 |
+
'filp': False
|
| 67 |
+
# 'idx': idx
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@register('train')
|
| 72 |
+
class TrainDataset(Dataset):
|
| 73 |
+
def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
|
| 74 |
+
augment=False, gt_resize=None):
|
| 75 |
+
self.dataset = dataset
|
| 76 |
+
self.size_min = size_min
|
| 77 |
+
if size_max is None:
|
| 78 |
+
size_max = size_min
|
| 79 |
+
self.size_max = size_max
|
| 80 |
+
self.augment = augment
|
| 81 |
+
self.gt_resize = gt_resize
|
| 82 |
+
|
| 83 |
+
self.inp_size = inp_size
|
| 84 |
+
self.img_transform = transforms.Compose([
|
| 85 |
+
transforms.Resize((self.inp_size, self.inp_size)),
|
| 86 |
+
transforms.ToTensor(),
|
| 87 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 88 |
+
std=[0.229, 0.224, 0.225])
|
| 89 |
+
])
|
| 90 |
+
self.inverse_transform = transforms.Compose([
|
| 91 |
+
transforms.Normalize(mean=[0., 0., 0.],
|
| 92 |
+
std=[1/0.229, 1/0.224, 1/0.225]),
|
| 93 |
+
transforms.Normalize(mean=[-0.485, -0.456, -0.406],
|
| 94 |
+
std=[1, 1, 1])
|
| 95 |
+
])
|
| 96 |
+
self.mask_transform = transforms.Compose([
|
| 97 |
+
transforms.Resize((self.inp_size, self.inp_size)),
|
| 98 |
+
transforms.ToTensor(),
|
| 99 |
+
])
|
| 100 |
+
|
| 101 |
+
def __len__(self):
|
| 102 |
+
return len(self.dataset)
|
| 103 |
+
|
| 104 |
+
def __getitem__(self, idx):
|
| 105 |
+
# print('lodd****',idx,self.dataset[idx])
|
| 106 |
+
img, mask = self.dataset[idx]
|
| 107 |
+
mask_name = mask
|
| 108 |
+
# print('befor mask', mask)
|
| 109 |
+
#new add
|
| 110 |
+
# print(idx, mask.filename, img.size)
|
| 111 |
+
|
| 112 |
+
# mask = cv2.imread(mask.filename, cv2.IMREAD_UNCHANGED)
|
| 113 |
+
mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
|
| 114 |
+
# print('befor mask', mask)
|
| 115 |
+
# random filp
|
| 116 |
+
if random.random() < 0.5:
|
| 117 |
+
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
| 118 |
+
# mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
| 119 |
+
mask = cv2.flip(mask, 1)
|
| 120 |
+
|
| 121 |
+
img = transforms.Resize((self.inp_size, self.inp_size))(img)
|
| 122 |
+
# mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask)
|
| 123 |
+
mask = torch.from_numpy(mask)
|
| 124 |
+
# print('behind mask', mask)
|
| 125 |
+
return {
|
| 126 |
+
'inp': self.img_transform(img),
|
| 127 |
+
# 'gt': self.mask_transform(mask)
|
| 128 |
+
'gt': mask,
|
| 129 |
+
'name': mask_name,
|
| 130 |
+
# 'idx': idx
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
@register('train_multi_task')
|
| 134 |
+
class TrainDataset(Dataset):
|
| 135 |
+
def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
|
| 136 |
+
augment=False, gt_resize=None):
|
| 137 |
+
self.dataset = dataset
|
| 138 |
+
self.size_min = size_min
|
| 139 |
+
if size_max is None:
|
| 140 |
+
size_max = size_min
|
| 141 |
+
self.size_max = size_max
|
| 142 |
+
self.augment = augment
|
| 143 |
+
self.gt_resize = gt_resize
|
| 144 |
+
|
| 145 |
+
self.inp_size = inp_size
|
| 146 |
+
self.img_transform = transforms.Compose([
|
| 147 |
+
transforms.Resize((self.inp_size, self.inp_size)),
|
| 148 |
+
transforms.ToTensor(),
|
| 149 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 150 |
+
std=[0.229, 0.224, 0.225])
|
| 151 |
+
])
|
| 152 |
+
self.inverse_transform = transforms.Compose([
|
| 153 |
+
transforms.Normalize(mean=[0., 0., 0.],
|
| 154 |
+
std=[1/0.229, 1/0.224, 1/0.225]),
|
| 155 |
+
transforms.Normalize(mean=[-0.485, -0.456, -0.406],
|
| 156 |
+
std=[1, 1, 1])
|
| 157 |
+
])
|
| 158 |
+
self.mask_transform = transforms.Compose([
|
| 159 |
+
transforms.Resize((self.inp_size, self.inp_size)),
|
| 160 |
+
transforms.ToTensor(),
|
| 161 |
+
])
|
| 162 |
+
|
| 163 |
+
def __len__(self):
|
| 164 |
+
return len(self.dataset)
|
| 165 |
+
# return sum(len(dataset) for dataset in self.datasets)
|
| 166 |
+
|
| 167 |
+
def __getitem__(self, idx):
|
| 168 |
+
# print('lodd****',idx,self.dataset[idx])
|
| 169 |
+
# print('+++++',idx)
|
| 170 |
+
img, mask = self.dataset[idx]
|
| 171 |
+
# print('befor mask', mask)
|
| 172 |
+
#new add
|
| 173 |
+
# print('****',idx, mask)
|
| 174 |
+
mask_name = mask
|
| 175 |
+
mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
|
| 176 |
+
|
| 177 |
+
# print('****',mask)
|
| 178 |
+
# print('befor mask', mask)
|
| 179 |
+
# random filp
|
| 180 |
+
if random.random() < 0.5:
|
| 181 |
+
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
| 182 |
+
# mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
| 183 |
+
mask = cv2.flip(mask, 1)
|
| 184 |
+
|
| 185 |
+
img = transforms.Resize((self.inp_size, self.inp_size))(img)
|
| 186 |
+
# mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask)
|
| 187 |
+
mask = torch.from_numpy(mask)
|
| 188 |
+
# print('behind mask', mask)
|
| 189 |
+
return {
|
| 190 |
+
'inp': self.img_transform(img),
|
| 191 |
+
# 'gt': self.mask_transform(mask)
|
| 192 |
+
'gt': mask,
|
| 193 |
+
'name': mask_name
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@register('val_multi_task')
|
| 198 |
+
class ValDataset(Dataset):
|
| 199 |
+
def __init__(self, dataset, inp_size=None, augment=False):
|
| 200 |
+
self.dataset = dataset
|
| 201 |
+
self.inp_size = inp_size
|
| 202 |
+
self.augment = augment
|
| 203 |
+
|
| 204 |
+
self.img_transform = transforms.Compose([
|
| 205 |
+
transforms.Resize((inp_size, inp_size)),
|
| 206 |
+
transforms.ToTensor(),
|
| 207 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 208 |
+
std=[0.229, 0.224, 0.225])
|
| 209 |
+
])
|
| 210 |
+
self.mask_transform = transforms.Compose([
|
| 211 |
+
transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
|
| 212 |
+
transforms.ToTensor(),
|
| 213 |
+
])
|
| 214 |
+
|
| 215 |
+
def __len__(self):
|
| 216 |
+
return len(self.dataset)
|
| 217 |
+
|
| 218 |
+
def __getitem__(self, idx):
|
| 219 |
+
img, mask = self.dataset[idx]
|
| 220 |
+
a = self.img_transform(img)
|
| 221 |
+
# b = self.mask_transform(mask)
|
| 222 |
+
mask_name = mask
|
| 223 |
+
# print(idx, mask.filename)
|
| 224 |
+
# b = cv2.imread(mask.filename,cv2.IMREAD_UNCHANGED)
|
| 225 |
+
b = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
|
| 226 |
+
return {
|
| 227 |
+
'inp': self.img_transform(img),
|
| 228 |
+
'gt': torch.tensor(b),
|
| 229 |
+
'name': mask_name
|
| 230 |
+
# 'idx': idx
|
| 231 |
+
}
|
models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import register, make
|
| 2 |
+
from . import sam
|
| 3 |
+
from . import sam_single
|
| 4 |
+
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (250 Bytes). View file
|
|
|
models/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (244 Bytes). View file
|
|
|
models/__pycache__/iou_loss.cpython-37.pyc
ADDED
|
Binary file (938 Bytes). View file
|
|
|
models/__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (723 Bytes). View file
|
|
|
models/__pycache__/models.cpython-37.pyc
ADDED
|
Binary file (698 Bytes). View file
|
|
|
models/__pycache__/sam.cpython-310.pyc
ADDED
|
Binary file (9.78 kB). View file
|
|
|
models/__pycache__/sam.cpython-37.pyc
ADDED
|
Binary file (9.75 kB). View file
|
|
|
models/__pycache__/sam_single.cpython-37.pyc
ADDED
|
Binary file (9.5 kB). View file
|
|
|
models/__pycache__/utils_prompt.cpython-37.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
models/block.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MergeAndConv(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(self, ic, oc, inner=32):
|
| 12 |
+
super().__init__()
|
| 13 |
+
|
| 14 |
+
self.conv1 = nn.Conv2d(ic, inner, kernel_size=3, stride=1, padding=1)
|
| 15 |
+
self.bn = nn.BatchNorm2d(inner)
|
| 16 |
+
self.relu = nn.ReLU(inplace=True)
|
| 17 |
+
self.conv2 = nn.Conv2d(inner, oc, kernel_size=3, stride=1, padding=1)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
x = self.conv2(self.bn(self.relu(self.conv1(x))))
|
| 21 |
+
x = torch.sigmoid(x)
|
| 22 |
+
return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SideClassifer(nn.Module):
|
| 26 |
+
def __init__(self, ic, n_class=1, M=2, kernel_size=1):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
sides = []
|
| 30 |
+
for i in range(M):
|
| 31 |
+
sides.append(nn.Conv2d(ic, n_class, kernel_size=kernel_size))
|
| 32 |
+
|
| 33 |
+
self.sides = nn.ModuleList(sides)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
return [fn(x) for fn in self.sides]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class UpsampleSKConv(nn.Module):
|
| 40 |
+
"""docstring for UpsampleSKConvPlus"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, ic, oc, reduce=4):
|
| 43 |
+
super(UpsampleSKConv, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.relu = nn.ReLU(inplace=True)
|
| 46 |
+
self.prev = nn.Conv2d(ic, ic // reduce, kernel_size=3, stride=1, padding=1)
|
| 47 |
+
self.bn = nn.BatchNorm2d(ic // reduce)
|
| 48 |
+
|
| 49 |
+
self.next = nn.Conv2d(ic // reduce, oc, kernel_size=1, stride=1)
|
| 50 |
+
self.bn2 = nn.BatchNorm2d(oc)
|
| 51 |
+
|
| 52 |
+
self.sk = SKSPP(ic // reduce, ic // reduce, M=4)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
x = F.interpolate(x, scale_factor=2)
|
| 56 |
+
|
| 57 |
+
x = self.bn(self.relu(self.prev(x)))
|
| 58 |
+
|
| 59 |
+
x = self.sk(x)
|
| 60 |
+
|
| 61 |
+
x = self.bn2(self.relu(self.next(x)))
|
| 62 |
+
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class SKSPP(nn.Module):
|
| 67 |
+
def __init__(self, features, WH, M=2, G=1, r=16, stride=1, L=32):
|
| 68 |
+
""" Constructor
|
| 69 |
+
Args:
|
| 70 |
+
features: input channel dimensionality.
|
| 71 |
+
WH: input spatial dimensionality, used for GAP kernel size.
|
| 72 |
+
M: the number of branchs.
|
| 73 |
+
G: num of convolution groups.
|
| 74 |
+
r: the radio for compute d, the length of z.
|
| 75 |
+
stride: stride, default 1.
|
| 76 |
+
L: the minimum dim of the vector z in paper, default 32.
|
| 77 |
+
"""
|
| 78 |
+
super(SKSPP, self).__init__()
|
| 79 |
+
d = max(int(features / r), L)
|
| 80 |
+
self.M = M # original
|
| 81 |
+
self.features = features
|
| 82 |
+
self.convs = nn.ModuleList([])
|
| 83 |
+
|
| 84 |
+
# 1,3,5,7 padding:[0,1,2,3]
|
| 85 |
+
for i in range(1, M):
|
| 86 |
+
self.convs.append(nn.Sequential(
|
| 87 |
+
nn.Conv2d(features, features, kernel_size=1 + i * 2, dilation=1 + i * 2, stride=stride,
|
| 88 |
+
padding=((1 + i * 2) * (i * 2) + 1) // 2, groups=G),
|
| 89 |
+
nn.BatchNorm2d(features),
|
| 90 |
+
nn.ReLU(inplace=False)
|
| 91 |
+
))
|
| 92 |
+
# self.gap = nn.AvgPool2d(int(WH/stride))
|
| 93 |
+
self.fc = nn.Linear(features, d)
|
| 94 |
+
self.fcs = nn.ModuleList([])
|
| 95 |
+
for i in range(M):
|
| 96 |
+
self.fcs.append(
|
| 97 |
+
nn.Linear(d, features)
|
| 98 |
+
)
|
| 99 |
+
self.softmax = nn.Softmax(dim=1)
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
|
| 103 |
+
feas = torch.unsqueeze(x, dim=1)
|
| 104 |
+
|
| 105 |
+
# F->conv1x1->conv3x3->conv5x5->conv7x7
|
| 106 |
+
|
| 107 |
+
for i, conv in enumerate(self.convs):
|
| 108 |
+
x = conv(x)
|
| 109 |
+
# if i == 0:
|
| 110 |
+
# feas = fea
|
| 111 |
+
# else:
|
| 112 |
+
feas = torch.cat([feas, torch.unsqueeze(x, dim=1)], dim=1)
|
| 113 |
+
|
| 114 |
+
fea_U = torch.sum(feas, dim=1)
|
| 115 |
+
fea_s = fea_U.mean(-1).mean(-1)
|
| 116 |
+
fea_z = self.fc(fea_s)
|
| 117 |
+
|
| 118 |
+
for i, fc in enumerate(self.fcs):
|
| 119 |
+
vector = fc(fea_z).unsqueeze_(dim=1)
|
| 120 |
+
if i == 0:
|
| 121 |
+
attention_vectors = vector
|
| 122 |
+
else:
|
| 123 |
+
attention_vectors = torch.cat([attention_vectors, vector], dim=1)
|
| 124 |
+
|
| 125 |
+
attention_vectors = self.softmax(attention_vectors)
|
| 126 |
+
attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
|
| 127 |
+
fea_v = (feas * attention_vectors).sum(dim=1)
|
| 128 |
+
return fea_v
|
models/bn_helper.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
if torch.__version__.startswith('0'):
|
| 5 |
+
from .sync_bn.inplace_abn.bn import InPlaceABNSync
|
| 6 |
+
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
|
| 7 |
+
BatchNorm2d_class = InPlaceABNSync
|
| 8 |
+
relu_inplace = False
|
| 9 |
+
else:
|
| 10 |
+
BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
|
| 11 |
+
relu_inplace = True
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
BatchNorm2d = torch.nn.BatchNorm2d
|
| 15 |
+
BatchNorm2d_class = BatchNorm2d
|
| 16 |
+
relu_inplace = False
|
models/iou_loss.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
###################################################################
|
| 6 |
+
# ########################## iou loss #############################
|
| 7 |
+
###################################################################
|
| 8 |
+
class IOU(torch.nn.Module):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super(IOU, self).__init__()
|
| 11 |
+
|
| 12 |
+
def _iou(self, pred, target):
|
| 13 |
+
pred = torch.sigmoid(pred)
|
| 14 |
+
inter = (pred * target).sum(dim=(2, 3))
|
| 15 |
+
union = (pred + target).sum(dim=(2, 3)) - inter
|
| 16 |
+
iou = 1 - (inter / union)
|
| 17 |
+
|
| 18 |
+
return iou.mean()
|
| 19 |
+
|
| 20 |
+
def forward(self, pred, target):
|
| 21 |
+
return self._iou(pred, target)
|
models/mmseg/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import mmcv
|
| 2 |
+
|
| 3 |
+
from .version import __version__, version_info
|
| 4 |
+
|
| 5 |
+
# MMCV_MIN = '1.1.4'
|
| 6 |
+
# MMCV_MAX = '1.3.0'
|
| 7 |
+
|
| 8 |
+
MMCV_MIN = '1.1.4'
|
| 9 |
+
MMCV_MAX = '1.7.0'
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def digit_version(version_str):
|
| 13 |
+
digit_version = []
|
| 14 |
+
for x in version_str.split('.'):
|
| 15 |
+
if x.isdigit():
|
| 16 |
+
digit_version.append(int(x))
|
| 17 |
+
elif x.find('rc') != -1:
|
| 18 |
+
patch_version = x.split('rc')
|
| 19 |
+
digit_version.append(int(patch_version[0]) - 1)
|
| 20 |
+
digit_version.append(int(patch_version[1]))
|
| 21 |
+
return digit_version
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
mmcv_min_version = digit_version(MMCV_MIN)
|
| 25 |
+
mmcv_max_version = digit_version(MMCV_MAX)
|
| 26 |
+
mmcv_version = digit_version(mmcv.__version__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
|
| 30 |
+
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
| 31 |
+
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
| 32 |
+
|
| 33 |
+
__all__ = ['__version__', 'version_info']
|
models/mmseg/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (841 Bytes). View file
|
|
|
models/mmseg/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (839 Bytes). View file
|
|
|
models/mmseg/__pycache__/version.cpython-310.pyc
ADDED
|
Binary file (521 Bytes). View file
|
|
|
models/mmseg/__pycache__/version.cpython-37.pyc
ADDED
|
Binary file (513 Bytes). View file
|
|
|
models/mmseg/apis/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
|
| 2 |
+
from .test import multi_gpu_test, single_gpu_test
|
| 3 |
+
from .train import get_root_logger, set_random_seed, train_segmentor
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
| 7 |
+
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
|
| 8 |
+
'show_result_pyplot'
|
| 9 |
+
]
|
models/mmseg/apis/inference.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import mmcv
|
| 3 |
+
import torch
|
| 4 |
+
from mmcv.parallel import collate, scatter
|
| 5 |
+
from mmcv.runner import load_checkpoint
|
| 6 |
+
|
| 7 |
+
from mmseg.datasets.pipelines import Compose
|
| 8 |
+
from mmseg.models import build_segmentor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
| 12 |
+
"""Initialize a segmentor from config file.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
config (str or :obj:`mmcv.Config`): Config file path or the config
|
| 16 |
+
object.
|
| 17 |
+
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
| 18 |
+
will not load any weights.
|
| 19 |
+
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
| 20 |
+
Use 'cpu' for loading model on CPU.
|
| 21 |
+
Returns:
|
| 22 |
+
nn.Module: The constructed segmentor.
|
| 23 |
+
"""
|
| 24 |
+
if isinstance(config, str):
|
| 25 |
+
config = mmcv.Config.fromfile(config)
|
| 26 |
+
elif not isinstance(config, mmcv.Config):
|
| 27 |
+
raise TypeError('config must be a filename or Config object, '
|
| 28 |
+
'but got {}'.format(type(config)))
|
| 29 |
+
config.model.pretrained = None
|
| 30 |
+
config.model.train_cfg = None
|
| 31 |
+
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
|
| 32 |
+
if checkpoint is not None:
|
| 33 |
+
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
| 34 |
+
model.CLASSES = checkpoint['meta']['CLASSES']
|
| 35 |
+
model.PALETTE = checkpoint['meta']['PALETTE']
|
| 36 |
+
model.cfg = config # save the config in the model for convenience
|
| 37 |
+
model.to(device)
|
| 38 |
+
model.eval()
|
| 39 |
+
return model
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LoadImage:
|
| 43 |
+
"""A simple pipeline to load image."""
|
| 44 |
+
|
| 45 |
+
def __call__(self, results):
|
| 46 |
+
"""Call function to load images into results.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
results (dict): A result dict contains the file name
|
| 50 |
+
of the image to be read.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
dict: ``results`` will be returned containing loaded image.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
if isinstance(results['img'], str):
|
| 57 |
+
results['filename'] = results['img']
|
| 58 |
+
results['ori_filename'] = results['img']
|
| 59 |
+
else:
|
| 60 |
+
results['filename'] = None
|
| 61 |
+
results['ori_filename'] = None
|
| 62 |
+
img = mmcv.imread(results['img'])
|
| 63 |
+
results['img'] = img
|
| 64 |
+
results['img_shape'] = img.shape
|
| 65 |
+
results['ori_shape'] = img.shape
|
| 66 |
+
return results
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def inference_segmentor(model, img):
|
| 70 |
+
"""Inference image(s) with the segmentor.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
model (nn.Module): The loaded segmentor.
|
| 74 |
+
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
|
| 75 |
+
images.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
(list[Tensor]): The segmentation result.
|
| 79 |
+
"""
|
| 80 |
+
cfg = model.cfg
|
| 81 |
+
device = next(model.parameters()).device # model device
|
| 82 |
+
# build the data pipeline
|
| 83 |
+
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
| 84 |
+
test_pipeline = Compose(test_pipeline)
|
| 85 |
+
# prepare data
|
| 86 |
+
data = dict(img=img)
|
| 87 |
+
data = test_pipeline(data)
|
| 88 |
+
data = collate([data], samples_per_gpu=1)
|
| 89 |
+
if next(model.parameters()).is_cuda:
|
| 90 |
+
# scatter to specified GPU
|
| 91 |
+
data = scatter(data, [device])[0]
|
| 92 |
+
else:
|
| 93 |
+
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
| 94 |
+
|
| 95 |
+
# forward the model
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
result = model(return_loss=False, rescale=True, **data)
|
| 98 |
+
return result
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
|
| 102 |
+
"""Visualize the segmentation results on the image.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
model (nn.Module): The loaded segmentor.
|
| 106 |
+
img (str or np.ndarray): Image filename or loaded image.
|
| 107 |
+
result (list): The segmentation result.
|
| 108 |
+
palette (list[list[int]]] | None): The palette of segmentation
|
| 109 |
+
map. If None is given, random palette will be generated.
|
| 110 |
+
Default: None
|
| 111 |
+
fig_size (tuple): Figure size of the pyplot figure.
|
| 112 |
+
"""
|
| 113 |
+
if hasattr(model, 'module'):
|
| 114 |
+
model = model.module
|
| 115 |
+
img = model.show_result(img, result, palette=palette, show=False)
|
| 116 |
+
plt.figure(figsize=fig_size)
|
| 117 |
+
plt.imshow(mmcv.bgr2rgb(img))
|
| 118 |
+
plt.show()
|
models/mmseg/apis/test.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import pickle
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
import mmcv
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from mmcv.image import tensor2imgs
|
| 11 |
+
from mmcv.runner import get_dist_info
|
| 12 |
+
from IPython import embed
|
| 13 |
+
from mmseg.ops import resize
|
| 14 |
+
|
| 15 |
+
def np2tmp(array, temp_file_name=None):
|
| 16 |
+
"""Save ndarray to local numpy file.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
array (ndarray): Ndarray to save.
|
| 20 |
+
temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
|
| 21 |
+
function will generate a file name with tempfile.NamedTemporaryFile
|
| 22 |
+
to save ndarray. Default: None.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
str: The numpy file name.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
if temp_file_name is None:
|
| 29 |
+
temp_file_name = tempfile.NamedTemporaryFile(
|
| 30 |
+
suffix='.npy', delete=False).name
|
| 31 |
+
np.save(temp_file_name, array)
|
| 32 |
+
return temp_file_name
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def single_gpu_test(model,
|
| 36 |
+
data_loader,
|
| 37 |
+
show=False,
|
| 38 |
+
out_dir=None,
|
| 39 |
+
efficient_test=False):
|
| 40 |
+
"""Test with single GPU.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
model (nn.Module): Model to be tested.
|
| 44 |
+
data_loader (utils.data.Dataloader): Pytorch data loader.
|
| 45 |
+
show (bool): Whether show results during infernece. Default: False.
|
| 46 |
+
out_dir (str, optional): If specified, the results will be dumped into
|
| 47 |
+
the directory to save output results.
|
| 48 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
| 49 |
+
save CPU memory during evaluation. Default: False.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
list: The prediction results.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
model.eval()
|
| 56 |
+
results = []
|
| 57 |
+
dataset = data_loader.dataset
|
| 58 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
| 59 |
+
for i, data in enumerate(data_loader):
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
result = model(return_loss=False, **data)
|
| 62 |
+
|
| 63 |
+
if show or out_dir:
|
| 64 |
+
img_tensor = data['img'][0]
|
| 65 |
+
img_metas = data['img_metas'][0].data[0]
|
| 66 |
+
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
|
| 67 |
+
assert len(imgs) == len(img_metas)
|
| 68 |
+
|
| 69 |
+
for img, img_meta in zip(imgs, img_metas):
|
| 70 |
+
h, w, _ = img_meta['img_shape']
|
| 71 |
+
img_show = img[:h, :w, :]
|
| 72 |
+
|
| 73 |
+
ori_h, ori_w = img_meta['ori_shape'][:-1]
|
| 74 |
+
img_show = mmcv.imresize(img_show, (ori_w, ori_h))
|
| 75 |
+
|
| 76 |
+
if out_dir:
|
| 77 |
+
out_file = osp.join(out_dir, img_meta['ori_filename'])
|
| 78 |
+
else:
|
| 79 |
+
out_file = None
|
| 80 |
+
|
| 81 |
+
model.module.show_result(
|
| 82 |
+
img_show,
|
| 83 |
+
result,
|
| 84 |
+
palette=dataset.PALETTE,
|
| 85 |
+
show=show,
|
| 86 |
+
out_file=out_file)
|
| 87 |
+
|
| 88 |
+
if isinstance(result, list):
|
| 89 |
+
if efficient_test:
|
| 90 |
+
result = [np2tmp(_) for _ in result]
|
| 91 |
+
results.extend(result)
|
| 92 |
+
else:
|
| 93 |
+
if efficient_test:
|
| 94 |
+
result = np2tmp(result)
|
| 95 |
+
results.append(result)
|
| 96 |
+
|
| 97 |
+
batch_size = data['img'][0].size(0)
|
| 98 |
+
for _ in range(batch_size):
|
| 99 |
+
prog_bar.update()
|
| 100 |
+
return results
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def multi_gpu_test(model,
|
| 104 |
+
data_loader,
|
| 105 |
+
tmpdir=None,
|
| 106 |
+
gpu_collect=False,
|
| 107 |
+
efficient_test=False):
|
| 108 |
+
"""Test model with multiple gpus.
|
| 109 |
+
|
| 110 |
+
This method tests model with multiple gpus and collects the results
|
| 111 |
+
under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
|
| 112 |
+
it encodes results to gpu tensors and use gpu communication for results
|
| 113 |
+
collection. On cpu mode it saves the results on different gpus to 'tmpdir'
|
| 114 |
+
and collects them by the rank 0 worker.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
model (nn.Module): Model to be tested.
|
| 118 |
+
data_loader (utils.data.Dataloader): Pytorch data loader.
|
| 119 |
+
tmpdir (str): Path of directory to save the temporary results from
|
| 120 |
+
different gpus under cpu mode.
|
| 121 |
+
gpu_collect (bool): Option to use either gpu or cpu to collect results.
|
| 122 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
| 123 |
+
save CPU memory during evaluation. Default: False.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
list: The prediction results.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
model.eval()
|
| 130 |
+
results = []
|
| 131 |
+
dataset = data_loader.dataset
|
| 132 |
+
rank, world_size = get_dist_info()
|
| 133 |
+
if rank == 0:
|
| 134 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
| 135 |
+
for i, data in enumerate(data_loader):
|
| 136 |
+
with torch.no_grad():
|
| 137 |
+
result = model(return_loss=False, rescale=True, **data)
|
| 138 |
+
|
| 139 |
+
if isinstance(result, list):
|
| 140 |
+
if efficient_test:
|
| 141 |
+
result = [np2tmp(_) for _ in result]
|
| 142 |
+
results.extend(result)
|
| 143 |
+
else:
|
| 144 |
+
if efficient_test:
|
| 145 |
+
result = np2tmp(result)
|
| 146 |
+
results.append(result)
|
| 147 |
+
|
| 148 |
+
if rank == 0:
|
| 149 |
+
batch_size = data['img'][0].size(0)
|
| 150 |
+
for _ in range(batch_size * world_size):
|
| 151 |
+
prog_bar.update()
|
| 152 |
+
|
| 153 |
+
# collect results from all ranks
|
| 154 |
+
if gpu_collect:
|
| 155 |
+
results = collect_results_gpu(results, len(dataset))
|
| 156 |
+
else:
|
| 157 |
+
results = collect_results_cpu(results, len(dataset), tmpdir)
|
| 158 |
+
return results
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def collect_results_cpu(result_part, size, tmpdir=None):
|
| 162 |
+
"""Collect results with CPU."""
|
| 163 |
+
rank, world_size = get_dist_info()
|
| 164 |
+
# create a tmp dir if it is not specified
|
| 165 |
+
if tmpdir is None:
|
| 166 |
+
MAX_LEN = 512
|
| 167 |
+
# 32 is whitespace
|
| 168 |
+
dir_tensor = torch.full((MAX_LEN, ),
|
| 169 |
+
32,
|
| 170 |
+
dtype=torch.uint8,
|
| 171 |
+
device='cuda')
|
| 172 |
+
if rank == 0:
|
| 173 |
+
tmpdir = tempfile.mkdtemp()
|
| 174 |
+
tmpdir = torch.tensor(
|
| 175 |
+
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
|
| 176 |
+
dir_tensor[:len(tmpdir)] = tmpdir
|
| 177 |
+
dist.broadcast(dir_tensor, 0)
|
| 178 |
+
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
|
| 179 |
+
else:
|
| 180 |
+
mmcv.mkdir_or_exist(tmpdir)
|
| 181 |
+
# dump the part result to the dir
|
| 182 |
+
mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
|
| 183 |
+
dist.barrier()
|
| 184 |
+
# collect all parts
|
| 185 |
+
if rank != 0:
|
| 186 |
+
return None
|
| 187 |
+
else:
|
| 188 |
+
# load results of all parts from tmp dir
|
| 189 |
+
part_list = []
|
| 190 |
+
for i in range(world_size):
|
| 191 |
+
part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
|
| 192 |
+
part_list.append(mmcv.load(part_file))
|
| 193 |
+
# sort the results
|
| 194 |
+
ordered_results = []
|
| 195 |
+
for res in zip(*part_list):
|
| 196 |
+
ordered_results.extend(list(res))
|
| 197 |
+
# the dataloader may pad some samples
|
| 198 |
+
ordered_results = ordered_results[:size]
|
| 199 |
+
# remove tmp dir
|
| 200 |
+
shutil.rmtree(tmpdir)
|
| 201 |
+
return ordered_results
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def collect_results_gpu(result_part, size):
|
| 205 |
+
"""Collect results with GPU."""
|
| 206 |
+
rank, world_size = get_dist_info()
|
| 207 |
+
# dump result part to tensor with pickle
|
| 208 |
+
part_tensor = torch.tensor(
|
| 209 |
+
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
|
| 210 |
+
# gather all result part tensor shape
|
| 211 |
+
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
|
| 212 |
+
shape_list = [shape_tensor.clone() for _ in range(world_size)]
|
| 213 |
+
dist.all_gather(shape_list, shape_tensor)
|
| 214 |
+
# padding result part tensor to max length
|
| 215 |
+
shape_max = torch.tensor(shape_list).max()
|
| 216 |
+
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
|
| 217 |
+
part_send[:shape_tensor[0]] = part_tensor
|
| 218 |
+
part_recv_list = [
|
| 219 |
+
part_tensor.new_zeros(shape_max) for _ in range(world_size)
|
| 220 |
+
]
|
| 221 |
+
# gather all result part
|
| 222 |
+
dist.all_gather(part_recv_list, part_send)
|
| 223 |
+
|
| 224 |
+
if rank == 0:
|
| 225 |
+
part_list = []
|
| 226 |
+
for recv, shape in zip(part_recv_list, shape_list):
|
| 227 |
+
part_list.append(
|
| 228 |
+
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
|
| 229 |
+
# sort the results
|
| 230 |
+
ordered_results = []
|
| 231 |
+
for res in zip(*part_list):
|
| 232 |
+
ordered_results.extend(list(res))
|
| 233 |
+
# the dataloader may pad some samples
|
| 234 |
+
ordered_results = ordered_results[:size]
|
| 235 |
+
return ordered_results
|
models/mmseg/apis/train.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
| 7 |
+
from mmcv.runner import build_optimizer, build_runner
|
| 8 |
+
|
| 9 |
+
from mmseg.core import DistEvalHook, EvalHook
|
| 10 |
+
from mmseg.datasets import build_dataloader, build_dataset
|
| 11 |
+
from mmseg.utils import get_root_logger
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def set_random_seed(seed, deterministic=False):
|
| 15 |
+
"""Set random seed.
|
| 16 |
+
Args:
|
| 17 |
+
seed (int): Seed to be used.
|
| 18 |
+
deterministic (bool): Whether to set the deterministic option for
|
| 19 |
+
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
| 20 |
+
to True and `torch.backends.cudnn.benchmark` to False.
|
| 21 |
+
Default: False.
|
| 22 |
+
"""
|
| 23 |
+
random.seed(seed)
|
| 24 |
+
np.random.seed(seed)
|
| 25 |
+
torch.manual_seed(seed)
|
| 26 |
+
torch.cuda.manual_seed_all(seed)
|
| 27 |
+
if deterministic:
|
| 28 |
+
torch.backends.cudnn.deterministic = True
|
| 29 |
+
torch.backends.cudnn.benchmark = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def train_segmentor(model,
|
| 33 |
+
dataset,
|
| 34 |
+
cfg,
|
| 35 |
+
distributed=False,
|
| 36 |
+
validate=False,
|
| 37 |
+
timestamp=None,
|
| 38 |
+
meta=None):
|
| 39 |
+
"""Launch segmentor training."""
|
| 40 |
+
logger = get_root_logger(cfg.log_level)
|
| 41 |
+
|
| 42 |
+
# prepare data loaders
|
| 43 |
+
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
| 44 |
+
data_loaders = [
|
| 45 |
+
build_dataloader(
|
| 46 |
+
ds,
|
| 47 |
+
cfg.data.samples_per_gpu,
|
| 48 |
+
cfg.data.workers_per_gpu,
|
| 49 |
+
# cfg.gpus will be ignored if distributed
|
| 50 |
+
len(cfg.gpu_ids),
|
| 51 |
+
dist=distributed,
|
| 52 |
+
seed=cfg.seed,
|
| 53 |
+
drop_last=True) for ds in dataset
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
# put model on gpus
|
| 57 |
+
if distributed:
|
| 58 |
+
find_unused_parameters = cfg.get('find_unused_parameters', False)
|
| 59 |
+
# Sets the `find_unused_parameters` parameter in
|
| 60 |
+
# torch.nn.parallel.DistributedDataParallel
|
| 61 |
+
model = MMDistributedDataParallel(
|
| 62 |
+
model.cuda(),
|
| 63 |
+
device_ids=[torch.cuda.current_device()],
|
| 64 |
+
broadcast_buffers=False,
|
| 65 |
+
find_unused_parameters=find_unused_parameters)
|
| 66 |
+
else:
|
| 67 |
+
model = MMDataParallel(
|
| 68 |
+
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
| 69 |
+
|
| 70 |
+
# build runner
|
| 71 |
+
optimizer = build_optimizer(model, cfg.optimizer)
|
| 72 |
+
|
| 73 |
+
if cfg.get('runner') is None:
|
| 74 |
+
cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
|
| 75 |
+
warnings.warn(
|
| 76 |
+
'config is now expected to have a `runner` section, '
|
| 77 |
+
'please set `runner` in your config.', UserWarning)
|
| 78 |
+
|
| 79 |
+
runner = build_runner(
|
| 80 |
+
cfg.runner,
|
| 81 |
+
default_args=dict(
|
| 82 |
+
model=model,
|
| 83 |
+
batch_processor=None,
|
| 84 |
+
optimizer=optimizer,
|
| 85 |
+
work_dir=cfg.work_dir,
|
| 86 |
+
logger=logger,
|
| 87 |
+
meta=meta))
|
| 88 |
+
|
| 89 |
+
# register hooks
|
| 90 |
+
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
|
| 91 |
+
cfg.checkpoint_config, cfg.log_config,
|
| 92 |
+
cfg.get('momentum_config', None))
|
| 93 |
+
|
| 94 |
+
# an ugly walkaround to make the .log and .log.json filenames the same
|
| 95 |
+
runner.timestamp = timestamp
|
| 96 |
+
|
| 97 |
+
# register eval hooks
|
| 98 |
+
if validate:
|
| 99 |
+
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
| 100 |
+
val_dataloader = build_dataloader(
|
| 101 |
+
val_dataset,
|
| 102 |
+
samples_per_gpu=1,
|
| 103 |
+
workers_per_gpu=cfg.data.workers_per_gpu,
|
| 104 |
+
dist=distributed,
|
| 105 |
+
shuffle=False)
|
| 106 |
+
eval_cfg = cfg.get('evaluation', {})
|
| 107 |
+
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
| 108 |
+
eval_hook = DistEvalHook if distributed else EvalHook
|
| 109 |
+
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
| 110 |
+
|
| 111 |
+
if cfg.resume_from:
|
| 112 |
+
runner.resume(cfg.resume_from)
|
| 113 |
+
elif cfg.load_from:
|
| 114 |
+
runner.load_checkpoint(cfg.load_from)
|
| 115 |
+
runner.run(data_loaders, cfg.workflow)
|
models/mmseg/core/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluation import * # noqa: F401, F403
|
| 2 |
+
from .seg import * # noqa: F401, F403
|
| 3 |
+
from .utils import * # noqa: F401, F403
|
models/mmseg/core/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .class_names import get_classes, get_palette
|
| 2 |
+
from .eval_hooks import DistEvalHook, EvalHook
|
| 3 |
+
from .metrics import eval_metrics, mean_dice, mean_iou
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
|
| 7 |
+
'get_classes', 'get_palette'
|
| 8 |
+
]
|
models/mmseg/core/evaluation/class_names.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import mmcv
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def cityscapes_classes():
|
| 5 |
+
"""Cityscapes class names for external use."""
|
| 6 |
+
return [
|
| 7 |
+
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
| 8 |
+
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
| 9 |
+
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
| 10 |
+
'bicycle'
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def ade_classes():
|
| 15 |
+
"""ADE20K class names for external use."""
|
| 16 |
+
return [
|
| 17 |
+
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
| 18 |
+
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
| 19 |
+
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
| 20 |
+
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
| 21 |
+
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
| 22 |
+
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
| 23 |
+
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
| 24 |
+
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
| 25 |
+
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
| 26 |
+
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
| 27 |
+
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
| 28 |
+
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
| 29 |
+
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
| 30 |
+
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
| 31 |
+
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
| 32 |
+
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
| 33 |
+
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
| 34 |
+
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
| 35 |
+
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
| 36 |
+
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
| 37 |
+
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
| 38 |
+
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
| 39 |
+
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
| 40 |
+
'clock', 'flag'
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def voc_classes():
|
| 45 |
+
"""Pascal VOC class names for external use."""
|
| 46 |
+
return [
|
| 47 |
+
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
|
| 48 |
+
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
|
| 49 |
+
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
|
| 50 |
+
'tvmonitor'
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def cityscapes_palette():
|
| 55 |
+
"""Cityscapes palette for external use."""
|
| 56 |
+
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
| 57 |
+
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
| 58 |
+
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
| 59 |
+
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
|
| 60 |
+
[0, 0, 230], [119, 11, 32]]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def ade_palette():
|
| 64 |
+
"""ADE20K palette for external use."""
|
| 65 |
+
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
| 66 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
| 67 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
| 68 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
| 69 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
| 70 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
| 71 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
| 72 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
| 73 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
| 74 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
| 75 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
| 76 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
| 77 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
| 78 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
| 79 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
| 80 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
| 81 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
| 82 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
| 83 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
| 84 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
| 85 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
| 86 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
| 87 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
| 88 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
| 89 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
| 90 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
| 91 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
| 92 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
| 93 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
| 94 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
| 95 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
| 96 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
| 97 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
| 98 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
| 99 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
| 100 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
| 101 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
| 102 |
+
[102, 255, 0], [92, 0, 255]]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def voc_palette():
|
| 106 |
+
"""Pascal VOC palette for external use."""
|
| 107 |
+
return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
|
| 108 |
+
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
|
| 109 |
+
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
|
| 110 |
+
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
|
| 111 |
+
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
dataset_aliases = {
|
| 115 |
+
'cityscapes': ['cityscapes'],
|
| 116 |
+
'ade': ['ade', 'ade20k'],
|
| 117 |
+
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_classes(dataset):
|
| 122 |
+
"""Get class names of a dataset."""
|
| 123 |
+
alias2name = {}
|
| 124 |
+
for name, aliases in dataset_aliases.items():
|
| 125 |
+
for alias in aliases:
|
| 126 |
+
alias2name[alias] = name
|
| 127 |
+
|
| 128 |
+
if mmcv.is_str(dataset):
|
| 129 |
+
if dataset in alias2name:
|
| 130 |
+
labels = eval(alias2name[dataset] + '_classes()')
|
| 131 |
+
else:
|
| 132 |
+
raise ValueError(f'Unrecognized dataset: {dataset}')
|
| 133 |
+
else:
|
| 134 |
+
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
| 135 |
+
return labels
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_palette(dataset):
|
| 139 |
+
"""Get class palette (RGB) of a dataset."""
|
| 140 |
+
alias2name = {}
|
| 141 |
+
for name, aliases in dataset_aliases.items():
|
| 142 |
+
for alias in aliases:
|
| 143 |
+
alias2name[alias] = name
|
| 144 |
+
|
| 145 |
+
if mmcv.is_str(dataset):
|
| 146 |
+
if dataset in alias2name:
|
| 147 |
+
labels = eval(alias2name[dataset] + '_palette()')
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f'Unrecognized dataset: {dataset}')
|
| 150 |
+
else:
|
| 151 |
+
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
| 152 |
+
return labels
|
models/mmseg/core/evaluation/eval_hooks.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
|
| 3 |
+
from mmcv.runner import Hook
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EvalHook(Hook):
|
| 8 |
+
"""Evaluation hook.
|
| 9 |
+
|
| 10 |
+
Attributes:
|
| 11 |
+
dataloader (DataLoader): A PyTorch dataloader.
|
| 12 |
+
interval (int): Evaluation interval (by epochs). Default: 1.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
|
| 16 |
+
if not isinstance(dataloader, DataLoader):
|
| 17 |
+
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
| 18 |
+
f'{type(dataloader)}')
|
| 19 |
+
self.dataloader = dataloader
|
| 20 |
+
self.interval = interval
|
| 21 |
+
self.by_epoch = by_epoch
|
| 22 |
+
self.eval_kwargs = eval_kwargs
|
| 23 |
+
|
| 24 |
+
def after_train_iter(self, runner):
|
| 25 |
+
"""After train epoch hook."""
|
| 26 |
+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
| 27 |
+
return
|
| 28 |
+
from mmseg.apis import single_gpu_test
|
| 29 |
+
runner.log_buffer.clear()
|
| 30 |
+
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
| 31 |
+
self.evaluate(runner, results)
|
| 32 |
+
|
| 33 |
+
def after_train_epoch(self, runner):
|
| 34 |
+
"""After train epoch hook."""
|
| 35 |
+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
| 36 |
+
return
|
| 37 |
+
from mmseg.apis import single_gpu_test
|
| 38 |
+
runner.log_buffer.clear()
|
| 39 |
+
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
| 40 |
+
self.evaluate(runner, results)
|
| 41 |
+
|
| 42 |
+
def evaluate(self, runner, results):
|
| 43 |
+
"""Call evaluate function of dataset."""
|
| 44 |
+
eval_res = self.dataloader.dataset.evaluate(
|
| 45 |
+
results, logger=runner.logger, **self.eval_kwargs)
|
| 46 |
+
for name, val in eval_res.items():
|
| 47 |
+
runner.log_buffer.output[name] = val
|
| 48 |
+
runner.log_buffer.ready = True
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DistEvalHook(EvalHook):
|
| 52 |
+
"""Distributed evaluation hook.
|
| 53 |
+
|
| 54 |
+
Attributes:
|
| 55 |
+
dataloader (DataLoader): A PyTorch dataloader.
|
| 56 |
+
interval (int): Evaluation interval (by epochs). Default: 1.
|
| 57 |
+
tmpdir (str | None): Temporary directory to save the results of all
|
| 58 |
+
processes. Default: None.
|
| 59 |
+
gpu_collect (bool): Whether to use gpu or cpu to collect results.
|
| 60 |
+
Default: False.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self,
|
| 64 |
+
dataloader,
|
| 65 |
+
interval=1,
|
| 66 |
+
gpu_collect=False,
|
| 67 |
+
by_epoch=False,
|
| 68 |
+
**eval_kwargs):
|
| 69 |
+
if not isinstance(dataloader, DataLoader):
|
| 70 |
+
raise TypeError(
|
| 71 |
+
'dataloader must be a pytorch DataLoader, but got {}'.format(
|
| 72 |
+
type(dataloader)))
|
| 73 |
+
self.dataloader = dataloader
|
| 74 |
+
self.interval = interval
|
| 75 |
+
self.gpu_collect = gpu_collect
|
| 76 |
+
self.by_epoch = by_epoch
|
| 77 |
+
self.eval_kwargs = eval_kwargs
|
| 78 |
+
|
| 79 |
+
def after_train_iter(self, runner):
|
| 80 |
+
"""After train epoch hook."""
|
| 81 |
+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
| 82 |
+
return
|
| 83 |
+
from mmseg.apis import multi_gpu_test
|
| 84 |
+
runner.log_buffer.clear()
|
| 85 |
+
results = multi_gpu_test(
|
| 86 |
+
runner.model,
|
| 87 |
+
self.dataloader,
|
| 88 |
+
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
| 89 |
+
gpu_collect=self.gpu_collect)
|
| 90 |
+
if runner.rank == 0:
|
| 91 |
+
print('\n')
|
| 92 |
+
self.evaluate(runner, results)
|
| 93 |
+
|
| 94 |
+
def after_train_epoch(self, runner):
|
| 95 |
+
"""After train epoch hook."""
|
| 96 |
+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
| 97 |
+
return
|
| 98 |
+
from mmseg.apis import multi_gpu_test
|
| 99 |
+
runner.log_buffer.clear()
|
| 100 |
+
results = multi_gpu_test(
|
| 101 |
+
runner.model,
|
| 102 |
+
self.dataloader,
|
| 103 |
+
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
| 104 |
+
gpu_collect=self.gpu_collect)
|
| 105 |
+
if runner.rank == 0:
|
| 106 |
+
print('\n')
|
| 107 |
+
self.evaluate(runner, results)
|
models/mmseg/core/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import mmcv
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def intersect_and_union(pred_label,
|
| 6 |
+
label,
|
| 7 |
+
num_classes,
|
| 8 |
+
ignore_index,
|
| 9 |
+
label_map=dict(),
|
| 10 |
+
reduce_zero_label=False):
|
| 11 |
+
"""Calculate intersection and Union.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
pred_label (ndarray): Prediction segmentation map.
|
| 15 |
+
label (ndarray): Ground truth segmentation map.
|
| 16 |
+
num_classes (int): Number of categories.
|
| 17 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
| 18 |
+
label_map (dict): Mapping old labels to new labels. The parameter will
|
| 19 |
+
work only when label is str. Default: dict().
|
| 20 |
+
reduce_zero_label (bool): Wether ignore zero label. The parameter will
|
| 21 |
+
work only when label is str. Default: False.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
ndarray: The intersection of prediction and ground truth histogram
|
| 25 |
+
on all classes.
|
| 26 |
+
ndarray: The union of prediction and ground truth histogram on all
|
| 27 |
+
classes.
|
| 28 |
+
ndarray: The prediction histogram on all classes.
|
| 29 |
+
ndarray: The ground truth histogram on all classes.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
if isinstance(pred_label, str):
|
| 33 |
+
pred_label = np.load(pred_label)
|
| 34 |
+
|
| 35 |
+
if isinstance(label, str):
|
| 36 |
+
label = mmcv.imread(label, flag='unchanged', backend='pillow')
|
| 37 |
+
# modify if custom classes
|
| 38 |
+
if label_map is not None:
|
| 39 |
+
for old_id, new_id in label_map.items():
|
| 40 |
+
label[label == old_id] = new_id
|
| 41 |
+
if reduce_zero_label:
|
| 42 |
+
# avoid using underflow conversion
|
| 43 |
+
label[label == 0] = 255
|
| 44 |
+
label = label - 1
|
| 45 |
+
label[label == 254] = 255
|
| 46 |
+
|
| 47 |
+
mask = (label != ignore_index)
|
| 48 |
+
pred_label = pred_label[mask]
|
| 49 |
+
label = label[mask]
|
| 50 |
+
|
| 51 |
+
intersect = pred_label[pred_label == label]
|
| 52 |
+
area_intersect, _ = np.histogram(
|
| 53 |
+
intersect, bins=np.arange(num_classes + 1))
|
| 54 |
+
area_pred_label, _ = np.histogram(
|
| 55 |
+
pred_label, bins=np.arange(num_classes + 1))
|
| 56 |
+
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
|
| 57 |
+
area_union = area_pred_label + area_label - area_intersect
|
| 58 |
+
|
| 59 |
+
return area_intersect, area_union, area_pred_label, area_label
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def total_intersect_and_union(results,
|
| 63 |
+
gt_seg_maps,
|
| 64 |
+
num_classes,
|
| 65 |
+
ignore_index,
|
| 66 |
+
label_map=dict(),
|
| 67 |
+
reduce_zero_label=False):
|
| 68 |
+
"""Calculate Total Intersection and Union.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
results (list[ndarray]): List of prediction segmentation maps.
|
| 72 |
+
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
| 73 |
+
num_classes (int): Number of categories.
|
| 74 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
| 75 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
| 76 |
+
reduce_zero_label (bool): Wether ignore zero label. Default: False.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
ndarray: The intersection of prediction and ground truth histogram
|
| 80 |
+
on all classes.
|
| 81 |
+
ndarray: The union of prediction and ground truth histogram on all
|
| 82 |
+
classes.
|
| 83 |
+
ndarray: The prediction histogram on all classes.
|
| 84 |
+
ndarray: The ground truth histogram on all classes.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
num_imgs = len(results)
|
| 88 |
+
assert len(gt_seg_maps) == num_imgs
|
| 89 |
+
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
|
| 90 |
+
total_area_union = np.zeros((num_classes, ), dtype=np.float)
|
| 91 |
+
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
|
| 92 |
+
total_area_label = np.zeros((num_classes, ), dtype=np.float)
|
| 93 |
+
for i in range(num_imgs):
|
| 94 |
+
area_intersect, area_union, area_pred_label, area_label = \
|
| 95 |
+
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
|
| 96 |
+
ignore_index, label_map, reduce_zero_label)
|
| 97 |
+
total_area_intersect += area_intersect
|
| 98 |
+
total_area_union += area_union
|
| 99 |
+
total_area_pred_label += area_pred_label
|
| 100 |
+
total_area_label += area_label
|
| 101 |
+
return total_area_intersect, total_area_union, \
|
| 102 |
+
total_area_pred_label, total_area_label
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def mean_iou(results,
|
| 106 |
+
gt_seg_maps,
|
| 107 |
+
num_classes,
|
| 108 |
+
ignore_index,
|
| 109 |
+
nan_to_num=None,
|
| 110 |
+
label_map=dict(),
|
| 111 |
+
reduce_zero_label=False):
|
| 112 |
+
"""Calculate Mean Intersection and Union (mIoU)
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
results (list[ndarray]): List of prediction segmentation maps.
|
| 116 |
+
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
| 117 |
+
num_classes (int): Number of categories.
|
| 118 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
| 119 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
| 120 |
+
by the numbers defined by the user. Default: None.
|
| 121 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
| 122 |
+
reduce_zero_label (bool): Wether ignore zero label. Default: False.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
float: Overall accuracy on all images.
|
| 126 |
+
ndarray: Per category accuracy, shape (num_classes, ).
|
| 127 |
+
ndarray: Per category IoU, shape (num_classes, ).
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
all_acc, acc, iou = eval_metrics(
|
| 131 |
+
results=results,
|
| 132 |
+
gt_seg_maps=gt_seg_maps,
|
| 133 |
+
num_classes=num_classes,
|
| 134 |
+
ignore_index=ignore_index,
|
| 135 |
+
metrics=['mIoU'],
|
| 136 |
+
nan_to_num=nan_to_num,
|
| 137 |
+
label_map=label_map,
|
| 138 |
+
reduce_zero_label=reduce_zero_label)
|
| 139 |
+
return all_acc, acc, iou
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def mean_dice(results,
|
| 143 |
+
gt_seg_maps,
|
| 144 |
+
num_classes,
|
| 145 |
+
ignore_index,
|
| 146 |
+
nan_to_num=None,
|
| 147 |
+
label_map=dict(),
|
| 148 |
+
reduce_zero_label=False):
|
| 149 |
+
"""Calculate Mean Dice (mDice)
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
results (list[ndarray]): List of prediction segmentation maps.
|
| 153 |
+
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
| 154 |
+
num_classes (int): Number of categories.
|
| 155 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
| 156 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
| 157 |
+
by the numbers defined by the user. Default: None.
|
| 158 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
| 159 |
+
reduce_zero_label (bool): Wether ignore zero label. Default: False.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
float: Overall accuracy on all images.
|
| 163 |
+
ndarray: Per category accuracy, shape (num_classes, ).
|
| 164 |
+
ndarray: Per category dice, shape (num_classes, ).
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
all_acc, acc, dice = eval_metrics(
|
| 168 |
+
results=results,
|
| 169 |
+
gt_seg_maps=gt_seg_maps,
|
| 170 |
+
num_classes=num_classes,
|
| 171 |
+
ignore_index=ignore_index,
|
| 172 |
+
metrics=['mDice'],
|
| 173 |
+
nan_to_num=nan_to_num,
|
| 174 |
+
label_map=label_map,
|
| 175 |
+
reduce_zero_label=reduce_zero_label)
|
| 176 |
+
return all_acc, acc, dice
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def eval_metrics(results,
|
| 180 |
+
gt_seg_maps,
|
| 181 |
+
num_classes,
|
| 182 |
+
ignore_index,
|
| 183 |
+
metrics=['mIoU'],
|
| 184 |
+
nan_to_num=None,
|
| 185 |
+
label_map=dict(),
|
| 186 |
+
reduce_zero_label=False):
|
| 187 |
+
"""Calculate evaluation metrics
|
| 188 |
+
Args:
|
| 189 |
+
results (list[ndarray]): List of prediction segmentation maps.
|
| 190 |
+
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
| 191 |
+
num_classes (int): Number of categories.
|
| 192 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
| 193 |
+
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
|
| 194 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
| 195 |
+
by the numbers defined by the user. Default: None.
|
| 196 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
| 197 |
+
reduce_zero_label (bool): Wether ignore zero label. Default: False.
|
| 198 |
+
Returns:
|
| 199 |
+
float: Overall accuracy on all images.
|
| 200 |
+
ndarray: Per category accuracy, shape (num_classes, ).
|
| 201 |
+
ndarray: Per category evalution metrics, shape (num_classes, ).
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
if isinstance(metrics, str):
|
| 205 |
+
metrics = [metrics]
|
| 206 |
+
allowed_metrics = ['mIoU', 'mDice']
|
| 207 |
+
if not set(metrics).issubset(set(allowed_metrics)):
|
| 208 |
+
raise KeyError('metrics {} is not supported'.format(metrics))
|
| 209 |
+
total_area_intersect, total_area_union, total_area_pred_label, \
|
| 210 |
+
total_area_label = total_intersect_and_union(results, gt_seg_maps,
|
| 211 |
+
num_classes, ignore_index,
|
| 212 |
+
label_map,
|
| 213 |
+
reduce_zero_label)
|
| 214 |
+
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
| 215 |
+
acc = total_area_intersect / total_area_label
|
| 216 |
+
ret_metrics = [all_acc, acc]
|
| 217 |
+
for metric in metrics:
|
| 218 |
+
if metric == 'mIoU':
|
| 219 |
+
iou = total_area_intersect / total_area_union
|
| 220 |
+
ret_metrics.append(iou)
|
| 221 |
+
elif metric == 'mDice':
|
| 222 |
+
dice = 2 * total_area_intersect / (
|
| 223 |
+
total_area_pred_label + total_area_label)
|
| 224 |
+
ret_metrics.append(dice)
|
| 225 |
+
if nan_to_num is not None:
|
| 226 |
+
ret_metrics = [
|
| 227 |
+
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
|
| 228 |
+
]
|
| 229 |
+
return ret_metrics
|
models/mmseg/core/seg/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import build_pixel_sampler
|
| 2 |
+
from .sampler import BasePixelSampler, OHEMPixelSampler
|
| 3 |
+
|
| 4 |
+
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
models/mmseg/core/seg/builder.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mmcv.utils import Registry, build_from_cfg
|
| 2 |
+
|
| 3 |
+
PIXEL_SAMPLERS = Registry('pixel sampler')
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def build_pixel_sampler(cfg, **default_args):
|
| 7 |
+
"""Build pixel sampler for segmentation map."""
|
| 8 |
+
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
|
models/mmseg/core/seg/sampler/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_pixel_sampler import BasePixelSampler
|
| 2 |
+
from .ohem_pixel_sampler import OHEMPixelSampler
|
| 3 |
+
|
| 4 |
+
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
|
models/mmseg/core/seg/sampler/base_pixel_sampler.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABCMeta, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BasePixelSampler(metaclass=ABCMeta):
|
| 5 |
+
"""Base class of pixel sampler."""
|
| 6 |
+
|
| 7 |
+
def __init__(self, **kwargs):
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def sample(self, seg_logit, seg_label):
|
| 12 |
+
"""Placeholder for sample function."""
|
| 13 |
+
pass
|
models/mmseg/core/seg/sampler/ohem_pixel_sampler.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from ..builder import PIXEL_SAMPLERS
|
| 5 |
+
from .base_pixel_sampler import BasePixelSampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@PIXEL_SAMPLERS.register_module()
|
| 9 |
+
class OHEMPixelSampler(BasePixelSampler):
|
| 10 |
+
"""Online Hard Example Mining Sampler for segmentation.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
context (nn.Module): The context of sampler, subclass of
|
| 14 |
+
:obj:`BaseDecodeHead`.
|
| 15 |
+
thresh (float, optional): The threshold for hard example selection.
|
| 16 |
+
Below which, are prediction with low confidence. If not
|
| 17 |
+
specified, the hard examples will be pixels of top ``min_kept``
|
| 18 |
+
loss. Default: None.
|
| 19 |
+
min_kept (int, optional): The minimum number of predictions to keep.
|
| 20 |
+
Default: 100000.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, context, thresh=None, min_kept=100000):
|
| 24 |
+
super(OHEMPixelSampler, self).__init__()
|
| 25 |
+
self.context = context
|
| 26 |
+
assert min_kept > 1
|
| 27 |
+
self.thresh = thresh
|
| 28 |
+
self.min_kept = min_kept
|
| 29 |
+
|
| 30 |
+
def sample(self, seg_logit, seg_label):
|
| 31 |
+
"""Sample pixels that have high loss or with low prediction confidence.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
|
| 35 |
+
seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
torch.Tensor: segmentation weight, shape (N, H, W)
|
| 39 |
+
"""
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
assert seg_logit.shape[2:] == seg_label.shape[2:]
|
| 42 |
+
assert seg_label.shape[1] == 1
|
| 43 |
+
seg_label = seg_label.squeeze(1).long()
|
| 44 |
+
batch_kept = self.min_kept * seg_label.size(0)
|
| 45 |
+
valid_mask = seg_label != self.context.ignore_index
|
| 46 |
+
seg_weight = seg_logit.new_zeros(size=seg_label.size())
|
| 47 |
+
valid_seg_weight = seg_weight[valid_mask]
|
| 48 |
+
if self.thresh is not None:
|
| 49 |
+
seg_prob = F.softmax(seg_logit, dim=1)
|
| 50 |
+
|
| 51 |
+
tmp_seg_label = seg_label.clone().unsqueeze(1)
|
| 52 |
+
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
|
| 53 |
+
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
|
| 54 |
+
sort_prob, sort_indices = seg_prob[valid_mask].sort()
|
| 55 |
+
|
| 56 |
+
if sort_prob.numel() > 0:
|
| 57 |
+
min_threshold = sort_prob[min(batch_kept,
|
| 58 |
+
sort_prob.numel() - 1)]
|
| 59 |
+
else:
|
| 60 |
+
min_threshold = 0.0
|
| 61 |
+
threshold = max(min_threshold, self.thresh)
|
| 62 |
+
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
|
| 63 |
+
else:
|
| 64 |
+
losses = self.context.loss_decode(
|
| 65 |
+
seg_logit,
|
| 66 |
+
seg_label,
|
| 67 |
+
weight=None,
|
| 68 |
+
ignore_index=self.context.ignore_index,
|
| 69 |
+
reduction_override='none')
|
| 70 |
+
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
|
| 71 |
+
_, sort_indices = losses[valid_mask].sort(descending=True)
|
| 72 |
+
valid_seg_weight[sort_indices[:batch_kept]] = 1.
|
| 73 |
+
|
| 74 |
+
seg_weight[valid_mask] = valid_seg_weight
|
| 75 |
+
|
| 76 |
+
return seg_weight
|
models/mmseg/core/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .misc import add_prefix
|
| 2 |
+
|
| 3 |
+
__all__ = ['add_prefix']
|
models/mmseg/core/utils/misc.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def add_prefix(inputs, prefix):
|
| 2 |
+
"""Add prefix for dict.
|
| 3 |
+
|
| 4 |
+
Args:
|
| 5 |
+
inputs (dict): The input dict with str keys.
|
| 6 |
+
prefix (str): The prefix to add.
|
| 7 |
+
|
| 8 |
+
Returns:
|
| 9 |
+
|
| 10 |
+
dict: The dict with keys updated with ``prefix``.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
outputs = dict()
|
| 14 |
+
for name, value in inputs.items():
|
| 15 |
+
outputs[f'{prefix}.{name}'] = value
|
| 16 |
+
|
| 17 |
+
return outputs
|