AI-Cyber committed
Commit 8d7921b · Parent: f0d1cb5

Upload 123 files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. configs/multi_mo_multi_task.yaml +156 -0
  2. configs/multi_mo_multi_task_sar_prompt.yaml +174 -0
  3. datasets/__init__.py +3 -0
  4. datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  5. datasets/__pycache__/__init__.cpython-37.pyc +0 -0
  6. datasets/__pycache__/datasets.cpython-310.pyc +0 -0
  7. datasets/__pycache__/datasets.cpython-37.pyc +0 -0
  8. datasets/__pycache__/image_folder.cpython-310.pyc +0 -0
  9. datasets/__pycache__/image_folder.cpython-37.pyc +0 -0
  10. datasets/__pycache__/wrappers.cpython-310.pyc +0 -0
  11. datasets/__pycache__/wrappers.cpython-37.pyc +0 -0
  12. datasets/data_loader_multi_tasks.py +26 -0
  13. datasets/data_simmim_pt.py +271 -0
  14. datasets/datasets.py +21 -0
  15. datasets/image_folder.py +370 -0
  16. datasets/wrappers.py +231 -0
  17. models/__init__.py +4 -0
  18. models/__pycache__/__init__.cpython-310.pyc +0 -0
  19. models/__pycache__/__init__.cpython-37.pyc +0 -0
  20. models/__pycache__/iou_loss.cpython-37.pyc +0 -0
  21. models/__pycache__/models.cpython-310.pyc +0 -0
  22. models/__pycache__/models.cpython-37.pyc +0 -0
  23. models/__pycache__/sam.cpython-310.pyc +0 -0
  24. models/__pycache__/sam.cpython-37.pyc +0 -0
  25. models/__pycache__/sam_single.cpython-37.pyc +0 -0
  26. models/__pycache__/utils_prompt.cpython-37.pyc +0 -0
  27. models/block.py +128 -0
  28. models/bn_helper.py +16 -0
  29. models/iou_loss.py +21 -0
  30. models/mmseg/__init__.py +33 -0
  31. models/mmseg/__pycache__/__init__.cpython-310.pyc +0 -0
  32. models/mmseg/__pycache__/__init__.cpython-37.pyc +0 -0
  33. models/mmseg/__pycache__/version.cpython-310.pyc +0 -0
  34. models/mmseg/__pycache__/version.cpython-37.pyc +0 -0
  35. models/mmseg/apis/__init__.py +9 -0
  36. models/mmseg/apis/inference.py +118 -0
  37. models/mmseg/apis/test.py +235 -0
  38. models/mmseg/apis/train.py +115 -0
  39. models/mmseg/core/__init__.py +3 -0
  40. models/mmseg/core/evaluation/__init__.py +8 -0
  41. models/mmseg/core/evaluation/class_names.py +152 -0
  42. models/mmseg/core/evaluation/eval_hooks.py +107 -0
  43. models/mmseg/core/evaluation/metrics.py +229 -0
  44. models/mmseg/core/seg/__init__.py +4 -0
  45. models/mmseg/core/seg/builder.py +8 -0
  46. models/mmseg/core/seg/sampler/__init__.py +4 -0
  47. models/mmseg/core/seg/sampler/base_pixel_sampler.py +13 -0
  48. models/mmseg/core/seg/sampler/ohem_pixel_sampler.py +76 -0
  49. models/mmseg/core/utils/__init__.py +3 -0
  50. models/mmseg/core/utils/misc.py +17 -0
configs/multi_mo_multi_task.yaml ADDED
@@ -0,0 +1,156 @@
+ train_dataset:
+   dataset:
+     name: paired-image-folders-multi-task
+     args:
+       # root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
+       # root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
+       root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
+       # root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
+       # root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
+       root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
+       cache: none
+       split_key: train
+   wrapper:
+     name: train_multi_task
+     args:
+       inp_size: 1024
+       augment: false
+   # batch_size: 2
+   batch_size: 2
+
+ val_dataset:
+   dataset:
+     name: paired-image-folders-multi-task
+     args:
+       # root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
+       # root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
+       root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
+       # root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
+       # root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
+       root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
+       cache: none
+       split_key: test
+   wrapper:
+     name: val_multi_task
+     args:
+       inp_size: 1024
+   # batch_size: 2
+   batch_size: 1
+
+ test_dataset:
+   dataset:
+     name: paired-image-folders
+     args:
+
+       # root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/SAR2020/split_image_ov500
+       # root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
+       # root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
+       # root_path_1: ./SAM_DATA_UNIFY4/GAOFEN/split_image
+       # root_path_1: ./SAM_DATA_UNIFY4/Vaihingen/image1
+       # root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
+       # root_path_1: ./SAM_DATA_UNIFY4/Potsdam/image1
+       # root_path_1: ./SAM_DATA_UNIFY4/whu-opt-sar/image_sar
+       root_path_1: /workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/image
+
+       # root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1
+       # root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2
+       # root_path_2: ./SAM_DATA_UNIFY3/SAR2020/gt_decoder2
+       # root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
+       # root_path_2: ./SAM_DATA_UNIFY4/GAOFEN/gt_decoder1_update
+       # root_path_2: ./SAM_DATA_UNIFY4/Vaihingen/gt2
+       # root_path_2: ./SAM_DATA_UNIFY4/Potsdam/gt1
+       # root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
+       root_path_2: /workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/gt
+       # root_path_2: ./SAM_DATA_UNIFY4/whu-opt-sar/gt_sar
+       cache: none
+       split_key: test
+   wrapper:
+     name: val
+     args:
+       # inp_size: 1024
+       inp_size: 1024
+   batch_size: 1
+
+ #eval_type: cod
+ eval_type: f1
+ #sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth
+ sam_checkpoint: sam_vit_h_4b8939.pth
+ data_norm:
+   inp:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+   gt:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+   gt_rgb:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+ model:
+   name: sam
+   args:
+     inp_size: 1024
+     # loss: iou
+     loss: cr
+     encoder_mode:
+       name: sam
+       img_size: 1024
+       mlp_ratio: 4
+       patch_size: 16
+       qkv_bias: true
+       use_rel_pos: true
+       window_size: 14
+       out_chans: 256
+       scale_factor: 32
+       input_type: fft
+       freq_nums: 0.25
+       prompt_type: highpass
+       prompt_embed_dim: 256
+       tuning_stage: 1234
+       handcrafted_tune: true
+       embedding_tune: true
+       adaptor: adaptor
+       embed_dim: 1280
+       depth: 32
+       num_heads: 16
+       global_attn_indexes:
+         - 7
+         - 15
+         - 23
+         - 31
+ optimizer:
+   name: adamw
+   args:
+     # lr: 0.0002
+     # lr: 0.00002
+     lr: 0.00008
+ lr_min: 1.0e-8
+ #epoch_max: 20
+ epoch_max: 100
+
+ multi_step_lr:
+   milestones:
+     - 1
+   gamma: 0.1
+ epoch_val: 100
+ epoch_save: 1
+
+ #resume: 60
+ #start_epoch: 60
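The active root_path_1/root_path_2 entries above are one-element YAML flow sequences holding a mapping from decoder name to directory. A minimal sketch of reading them with PyYAML; the loading code itself is an assumption, since the training script that consumes these configs is not among the 50 files shown:

import yaml

with open('configs/multi_mo_multi_task.yaml') as f:
    cfg = yaml.safe_load(f)

args = cfg['train_dataset']['dataset']['args']
for decoder, image_dir in args['root_path_1'][0].items():   # [0]: the list holds a single dict
    print(decoder, '->', image_dir)   # e.g. Decoder1 -> /workspace/SAM_DATA_UNIFY4/Decoder1/image/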
configs/multi_mo_multi_task_sar_prompt.yaml ADDED
@@ -0,0 +1,174 @@
+ train_dataset:
+   dataset:
+     name: paired-image-folders
+     args:
+       # root_path_1: ./ISAID/train/trainprompt/sub_images
+       # root_path_1: ./ISAID/train/trainprompt/images
+       root_path_1: ./SAR_prompt/image
+       # root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
+       # root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
+       # root_path_2: ./ISAID/train/trainprompt/sub_gt
+       root_path_2: ./SAR_prompt/gt
+       # root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
+       # root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
+       cache: none
+       split_key: train
+   wrapper:
+     name: train
+     args:
+       inp_size: 1024
+       augment: false
+   # batch_size: 2
+   batch_size: 1
+
+ val_dataset:
+   dataset:
+     name: paired-image-folders
+     args:
+       # root_path_1: ./ISAID/train/trainprompt/images
+       root_path_1: ./SAR_prompt/image
+       # root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}]
+       # root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}]
+       # root_path_2: ./ISAID/train/trainprompt/gt
+       root_path_2: ./SAR_prompt/gt
+       # root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}]
+       # root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}]
+       cache: none
+       split_key: test
+   wrapper:
+     name: val
+     args:
+       inp_size: 1024
+   # batch_size: 2
+   batch_size: 1
+
+ test_dataset:
+   dataset:
+     name: paired-image-folders
+     args:
+       # root_path_1: ./ISAID/train/trainprompt/images
+       # root_path_1: ./ISAID/train/trainprompt/sub_images
+       root_path_1: ./save/SAR_prompt/image
+       # root_path_1: ./SAM_DATA_UNIFY/Vaihingen/split_image
+       # root_path_1: ./SAM_DATA_UNIFY/SAR2020/split_image_ov500
+       # root_path_1: ./SAM_DATA_UNIFY/POLARIS_SAR/split_image
+       # root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/whu-sar-test/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/WHU-SAR/split_image
+       # root_path_1: ./SAM_DATA_UNIFY2/WHU_ALL/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/WHU_SAR/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/WHU_OPT/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image
+       # root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image
+       # root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500
+
+       # root_path_2: ./ISAID/train/trainprompt/gt
+       # root_path_2: ./ISAID/train/trainprompt/sub_gt
+       root_path_2: ./save/SAR_prompt/gt
+       # root_path_2: ./SAM_DATA_UNIFY/Vaihingen/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY/POLARIS_SAR/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/whu-sar-test/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/WHU-SAR/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY2/WHU_ALL/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY3/WHU_SAR/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY3/WHU_OPT/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt
+       # root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1
+       # root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2
+       # root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2
+       cache: none
+       split_key: test
+   wrapper:
+     name: val
+     args:
+       # inp_size: 1024
+       inp_size: 1024
+   batch_size: 1
+
+ #eval_type: cod
+ eval_type: f1
+ #sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth
+ #sam_checkpoint: sam_vit_h_4b8939.pth
+ sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth
+ #sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth
+ data_norm:
+   inp:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+   gt:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+   gt_rgb:
+     sub:
+       - 0.5
+     div:
+       - 0.5
+ model:
+   name: sam
+   args:
+     inp_size: 1024
+     # loss: iou
+     loss: cr
+     encoder_mode:
+       name: sam
+       img_size: 1024
+       mlp_ratio: 4
+       patch_size: 16
+       qkv_bias: true
+       use_rel_pos: true
+       window_size: 14
+       out_chans: 256
+       scale_factor: 32
+       input_type: fft
+       freq_nums: 0.25
+       prompt_type: highpass
+       prompt_embed_dim: 256
+       tuning_stage: 1234
+       handcrafted_tune: true
+       embedding_tune: true
+       adaptor: adaptor
+       embed_dim: 1280
+       depth: 32
+       num_heads: 16
+       global_attn_indexes:
+         - 7
+         - 15
+         - 23
+         - 31
+ optimizer:
+   name: adamw
+   args:
+     # lr: 0.0002
+     # lr: 0.00002
+     # lr: 0.00004
+     # lr: 0.00008
+     lr: 0.0002
+ lr_min: 1.0e-8
+ #epoch_max: 20
+ epoch_max: 200
+
+ multi_step_lr:
+   milestones:
+     - 1
+   gamma: 0.1
+ epoch_val: 200
+ epoch_save: 1
+
+ #resume: 60
+ #start_epoch: 60
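This second config resumes from the multi-task checkpoint saved by the first run (the active sam_checkpoint above) rather than the raw SAM weights. A hedged sketch for inspecting that file; whether it stores a bare state_dict or a wrapper dict is an assumption, since the saving code is not in this 50-file view:

import torch

ckpt = torch.load('./save/_multi_mo_multi_task_0626/model_epoch_last.pth', map_location='cpu')
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])   # top-level keys reveal whether a 'model'/'state_dict' wrapper is used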
datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .datasets import register, make
+ from . import image_folder
+ from . import wrappers
datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (261 Bytes).

datasets/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (255 Bytes).

datasets/__pycache__/datasets.cpython-310.pyc ADDED
Binary file (683 Bytes).

datasets/__pycache__/datasets.cpython-37.pyc ADDED
Binary file (656 Bytes).

datasets/__pycache__/image_folder.cpython-310.pyc ADDED
Binary file (10.3 kB).

datasets/__pycache__/image_folder.cpython-37.pyc ADDED
Binary file (11.2 kB).

datasets/__pycache__/wrappers.cpython-310.pyc ADDED
Binary file (4.36 kB).

datasets/__pycache__/wrappers.cpython-37.pyc ADDED
Binary file (5.45 kB).
datasets/data_loader_multi_tasks.py ADDED
@@ -0,0 +1,26 @@
+
+ def build_loader_simmim(config):
+     ############ single model #####################
+     # transform = SimMIMTransform(config)
+     # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
+     # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
+     # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
+
+     ############## multi model ####################
+     datasets = []
+     ### data augmentation ######
+     model_paths = config.DATA.TYPE_PATH[0]
+     for i in model_paths.keys():
+         a = config.DATA.SCALE[0][i].split(',')
+         scale_model = (float(a[0].split('(')[1]), float(a[1].split(')')[0]))
+         transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
+         dataset = CachedImageFolder(model_paths[i], transform=transform, model=i)
+         datasets.append(dataset)
+     multi_task_train_dataset = MultiTaskDataset(datasets)
+     print(len(datasets))
+     multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
+     dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
+     # dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
+     print(len(dataloader))
+
+     return dataloader
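This module carries no import block of its own: build_loader_simmim relies on names that live elsewhere. A sketch of the imports it would need to run standalone (SimMIMTransform, MultiTaskDataset, DistrubutedMultiTaskBatchSampler and collate_fn are defined in datasets/data_simmim_pt.py below; cached_image_folder is not part of this commit, so that import is an assumption). The hand-rolled split of config.DATA.SCALE also assumes the exact string form "(lo, hi)"; ast.literal_eval is a sturdier alternative, shown here as an editorial sketch rather than code from the repo:

import ast
import torch.distributed as dist
from torch.utils.data import DataLoader

from .data_simmim_pt import (SimMIMTransform, MultiTaskDataset,
                             DistrubutedMultiTaskBatchSampler, collate_fn)
from .cached_image_folder import CachedImageFolder   # module not included in this commit

# instead of splitting on '(' and ',' by hand:
scale_model = ast.literal_eval("(0.67, 1.0)")   # example string; the real value comes from config.DATA.SCALE
assert scale_model == (0.67, 1.0)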
datasets/data_simmim_pt.py ADDED
@@ -0,0 +1,271 @@
+ # --------------------------------------------------------
+ # SimMIM
+ # Copyright (c) 2021 Microsoft
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Zhenda Xie
+ # --------------------------------------------------------
+
+ import math
+ import random
+ import numpy as np
+
+ import torch
+ import torch.distributed as dist
+ import torchvision.transforms as T
+ from torch.utils.data import DataLoader, DistributedSampler
+ from torch.utils.data._utils.collate import default_collate
+ from torchvision.datasets import ImageFolder
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from torch.utils.data import Dataset, BatchSampler
+ from torchvision.io import read_image
+ from .cached_image_folder import CachedImageFolder
+
+ class MultiTaskDataset(Dataset):
+     """
+     usage example:
+     train_datasets = [SemData_Single(), SemData_Single()]
+     multi_task_train_dataset = MultiTaskDataset(train_datasets)
+     multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
+     multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
+     for i, (task_id, input, target) in enumerate(multi_task_train_data):
+         pre = model(input)
+     """
+     def __init__(self, datasets):
+         self._datasets = datasets
+         task_id_2_data_set_dic = {}
+         for i, dataset in enumerate(datasets):
+             task_id = i
+             assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
+             task_id_2_data_set_dic[task_id] = dataset
+
+         self._task_id_2_data_set_dic = task_id_2_data_set_dic
+
+     def __len__(self):
+         return sum(len(dataset) for dataset in self._datasets)
+
+     def __getitem__(self, idx):
+         task_id, sample_id = idx
+         return self._task_id_2_data_set_dic[task_id][sample_id]
+
+ class DistrubutedMultiTaskBatchSampler(BatchSampler):
+     """
+     datasets: the Dataset instances to sample from
+     batch_size: int
+     mix_opt: int; mix_opt == 0 shuffles all tasks, mix_opt == 1 shuffles only the extra tasks
+     extra_task_ratio (float, optional): the ratio between task one and the extra tasks
+     drop_last (bool, optional): set to ``True`` to drop the last incomplete batch
+         if the dataset size is not divisible by the batch size. If ``False`` and
+         the size of the dataset is not divisible by the batch size, the last batch
+         will be smaller. (default: ``True``)
+     """
+     def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         if rank >= num_replicas or rank < 0:
+             raise ValueError(
+                 "Invalid rank {}, rank should be in the interval"
+                 " [0, {}]".format(rank, num_replicas - 1))
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
+         assert extra_task_ratio >= 0, 'extra_task_ratio must be >= 0'
+         self._datasets = datasets
+         self._batch_size = batch_size
+         self._mix_opt = mix_opt
+         self._extra_task_ratio = extra_task_ratio
+         self._drop_last = drop_last
+         train_data_list = []
+         self.shuffle = shuffle
+         for dataset in datasets:
+             print(len(dataset))
+             train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
+
+         ######### one list holds the data of n datasets; each dataset's entry is itself a list that splits its samples into per-batch index lists
+         self._train_data_list = train_data_list
+         self.total_len = sum(len(train_data) for train_data in self._train_data_list)
+
+         ######### DDP ######################
+         if self._drop_last and self.total_len % self.num_replicas != 0:  # type: ignore[arg-type]
+             self.num_samples = math.ceil(
+                 (self.total_len - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+             )
+         else:
+             self.num_samples = math.ceil(self.total_len / self.num_replicas)  # type: ignore[arg-type]
+
+         self.total_size = self.num_samples * self.num_replicas
+         self.epoch = 0
+         self.seed = 0
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     @staticmethod
+     def _get_index_batches(dataset_len, batch_size, drop_last):
+         # index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
+         index = list(range(dataset_len))
+         if drop_last and dataset_len % batch_size:
+             del index[-(dataset_len % batch_size):]
+         index_batches = [index[i:i+batch_size] for i in range(0, len(index), batch_size)]
+         return index_batches
+
+     def __len__(self):
+         # return sum(len(train_data) for train_data in self._train_data_list)
+         return self.num_samples
+
+     def __iter__(self):
+         all_iters = [iter(item) for item in self._train_data_list]
+         all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
+
+         ######### DDP ######################
+         random.shuffle(all_indices)
+         all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
+         assert len(all_indices) == self.num_samples
+
+         for local_task_idx in all_indices:
+             # task_id = self._datasets[local_task_idx].get_task_id()
+             batch = next(all_iters[local_task_idx])
+             # batch = batch[self.rank:len(batch):self.num_replicas]
+             # print(local_task_idx)
+             yield [(local_task_idx, sample_id) for sample_id in batch]
+             # yield iter(batch)
+
+     @staticmethod
+     def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
+
+         ########## according to the number of models ###########
+         all_indices = []
+         for i in range(len(train_data_list)):
+             all_indices += [i] * len(train_data_list[i])
+         # print(all_indices)
+         return all_indices
+     # def set_epoch(self, epoch):
+     #     r"""
+     #     Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+     #     use a different random ordering for each epoch. Otherwise, the next iteration of this
+     #     sampler will yield the same ordering.
+
+     #     Args:
+     #         epoch (int): Epoch number.
+     #     """
+     #     self.epoch = epoch
+
+
+ class MaskGenerator:
+     def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
+         self.input_size = input_size
+         self.mask_patch_size = mask_patch_size
+         self.model_patch_size = model_patch_size
+         self.mask_ratio = mask_ratio
+
+         assert self.input_size % self.mask_patch_size == 0
+         assert self.mask_patch_size % self.model_patch_size == 0
+
+         self.rand_size = self.input_size // self.mask_patch_size
+         self.scale = self.mask_patch_size // self.model_patch_size
+
+         self.token_count = self.rand_size ** 2
+         self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
+
+     def __call__(self):
+         mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
+         mask = np.zeros(self.token_count, dtype=int)
+         mask[mask_idx] = 1
+
+         mask = mask.reshape((self.rand_size, self.rand_size))
+         mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+
+         return mask
+
+
+ class ZeroOneNormalize(object):
+     def __call__(self, img):
+         return img.float().div(255)
+
+ class SimMIMTransform:
+     def __init__(self, config, NORM, SCALE):
+         self.transform_img = T.Compose([
+             # T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+             # T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
+             # T.RandomHorizontalFlip(),
+             # T.ToTensor(),
+             # T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD)),
+
+             T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=SCALE, ratio=(3. / 4., 4. / 3.)),
+             T.RandomHorizontalFlip(),
+             ZeroOneNormalize(),
+             T.Normalize(mean=torch.tensor(NORM[0]), std=torch.tensor(NORM[1])),
+         ])
+
+         if config.MODEL.TYPE in ['swin', 'swinv2']:
+             model_patch_size = config.MODEL.SWIN.PATCH_SIZE
+         else:
+             raise NotImplementedError
+
+         self.mask_generator = MaskGenerator(
+             input_size=config.DATA.IMG_SIZE,
+             mask_patch_size=config.DATA.MASK_PATCH_SIZE,
+             model_patch_size=model_patch_size,
+             mask_ratio=config.DATA.MASK_RATIO,
+         )
+
+     def __call__(self, img):
+         img = self.transform_img(img)
+         mask = self.mask_generator()
+
+         return img, mask
+
+ def collate_fn(batch):
+     # print(len(batch))
+     # print('*'*10)
+     # print(batch[0][0])
+     # print('#'*10)
+     # print(batch[0][1])
+     # batch = list(filter(lambda x: x[0][0] is not None, batch))
+     # if len(batch) == 0: return torch.Tensor()
+
+     if not isinstance(batch[0][0], tuple):
+         return default_collate(batch)
+     else:
+         batch_num = len(batch)
+         ret = []
+         for item_idx in range(len(batch[0][0])):
+             if batch[0][0][item_idx] is None:
+                 ret.append(None)
+             else:
+                 ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
+         ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
+         return ret
+
+
+ def build_loader_simmim(config):
+     ############ single model #####################
+     # transform = SimMIMTransform(config)
+     # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
+     # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
+     # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
+
+     ############## multi model ####################
+     datasets = []
+     ### data augmentation ######
+     model_paths = config.DATA.TYPE_PATH[0]
+     for i in model_paths.keys():
+         a = config.DATA.SCALE[0][i].split(',')
+         scale_model = (float(a[0].split('(')[1]), float(a[1].split(')')[0]))
+         transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
+         dataset = CachedImageFolder(model_paths[i], transform=transform, model=i)
+         datasets.append(dataset)
+     multi_task_train_dataset = MultiTaskDataset(datasets)
+     print(len(datasets))
+     multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
+     dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
+     # dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
+     print(len(dataloader))
+
+     return dataloader
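MaskGenerator above draws mask_ratio of the coarse mask_patch_size cells at random and then expands each cell onto the model patch grid. A quick check with the class defaults (assuming MaskGenerator from this module is in scope; importing the module itself also pulls in .cached_image_folder, which is not part of this commit):

gen = MaskGenerator(input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6)
mask = gen()
# 192/32 = 6 cells per side, each repeated 32/4 = 8 times -> a 48x48 array of 0/1
print(mask.shape)   # (48, 48)
print(mask.mean())  # ~0.61, i.e. ceil(36 * 0.6) = 22 of 36 cells masked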
datasets/datasets.py ADDED
@@ -0,0 +1,21 @@
+ import copy
+
+
+ datasets = {}
+
+
+ def register(name):
+     def decorator(cls):
+         datasets[name] = cls
+         return cls
+     return decorator
+
+
+ def make(dataset_spec, args=None):
+     if args is not None:
+         dataset_args = copy.deepcopy(dataset_spec['args'])
+         dataset_args.update(args)
+     else:
+         dataset_args = dataset_spec['args']
+     dataset = datasets[dataset_spec['name']](**dataset_args)
+     return dataset
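register/make implement a small name-keyed factory: register stores a class under a string, and make instantiates it from a {'name': ..., 'args': ...} spec, deep-copying the spec's args before applying overrides. A toy round trip (the Dummy class is illustrative, not from the repo; running it assumes the repo's dependencies are installed, since importing the datasets package also imports image_folder and wrappers):

from datasets.datasets import register, make

@register('dummy')
class Dummy:
    def __init__(self, n):
        self.n = n

spec = {'name': 'dummy', 'args': {'n': 3}}
obj = make(spec, args={'n': 5})   # explicit args override the spec's args
assert obj.n == 5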
datasets/image_folder.py ADDED
@@ -0,0 +1,370 @@
+ import os
+ import json
+ from PIL import Image
+
+ import pickle
+ import imageio
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ import random
+ from datasets import register
+
+ import math
+ import torch.distributed as dist
+ from torch.utils.data import BatchSampler
+
+ from torch.utils.data._utils.collate import default_collate
+
+ @register('image-folder')
+ class ImageFolder(Dataset):
+     def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
+                  repeat=1, cache='none', mask=False):
+         self.repeat = repeat
+         self.cache = cache
+         self.path = path
+         self.Train = False
+         self.split_key = split_key
+
+         self.size = size
+         self.mask = mask
+         if self.mask:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
+                 transforms.ToTensor(),
+             ])
+         else:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((self.size, self.size)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+
+         if split_file is None:
+             filenames = sorted(os.listdir(path))
+         else:
+             with open(split_file, 'r') as f:
+                 filenames = json.load(f)[split_key]
+         if first_k is not None:
+             filenames = filenames[:first_k]
+
+         self.files = []
+
+         for filename in filenames:
+             file = os.path.join(path, filename)
+             self.append_file(file)
+
+     def append_file(self, file):
+         if self.cache == 'none':
+             self.files.append(file)
+         elif self.cache == 'in_memory':
+             self.files.append(self.img_process(file))
+
+     def __len__(self):
+         return len(self.files) * self.repeat
+
+     def __getitem__(self, idx):
+         x = self.files[idx % len(self.files)]
+
+         if self.cache == 'none':
+             return self.img_process(x)
+         elif self.cache == 'in_memory':
+             return x
+
+     def img_process(self, file):
+         if self.mask:
+             # return Image.open(file).convert('L')
+             return file
+         else:
+             return Image.open(file).convert('RGB')
+
+ @register('paired-image-folders')
+ class PairedImageFolders(Dataset):
+
+     def __init__(self, root_path_1, root_path_2, **kwargs):
+         self.dataset_1 = ImageFolder(root_path_1, **kwargs)
+         self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True)
+
+     def __len__(self):
+         return len(self.dataset_1)
+
+     def __getitem__(self, idx):
+         return self.dataset_1[idx], self.dataset_2[idx]
+
+ class ImageFolder_multi_task(Dataset):
+     def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
+                  repeat=1, cache='none', mask=False):
+         self.repeat = repeat
+         self.cache = cache
+         self.path = path
+         self.Train = False
+         self.split_key = split_key
+
+         self.size = size
+         self.mask = mask
+         if self.mask:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
+                 transforms.ToTensor(),
+             ])
+         else:
+             self.img_transform = transforms.Compose([
+                 transforms.Resize((self.size, self.size)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+             ])
+
+         if split_file is None:
+             filenames = sorted(os.listdir(path))
+         else:
+             with open(split_file, 'r') as f:
+                 filenames = json.load(f)[split_key]
+         if first_k is not None:
+             filenames = filenames[:first_k]
+
+         self.files = []
+
+         for filename in filenames:
+             file = os.path.join(path, filename)
+             self.append_file(file)
+
+     def append_file(self, file):
+         if self.cache == 'none':
+             self.files.append(file)
+         elif self.cache == 'in_memory':
+             self.files.append(self.img_process(file))
+
+     def __len__(self):
+         return len(self.files) * self.repeat
+
+     def __getitem__(self, idx):
+         x = self.files[idx % len(self.files)]
+
+         if self.cache == 'none':
+             return self.img_process(x)
+         elif self.cache == 'in_memory':
+             return x
+
+     def img_process(self, file):
+         if self.mask:
+             # return Image.open(file).convert('L')
+             return file
+         else:
+             return Image.open(file).convert('RGB')
+
+ @register('paired-image-folders-multi-task')
+ class PairedImageFolders_multi_task(Dataset):
+
+     def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
+
+         self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs)
+         self.dataset_2 = ImageFolder_multi_task(root_path_2, **kwargs, mask=True)
+
+     def __len__(self):
+         return len(self.dataset_1)
+
+     def __getitem__(self, idx):
+         return self.dataset_1[idx], self.dataset_2[idx]
+
+
+
+
+ # class MultiTaskDataset(Dataset):
+ #     """
+ #     usage example:
+ #     train_datasets = [SemData_Single(), SemData_Single()]
+ #     multi_task_train_dataset = MultiTaskDataset(train_datasets)
+ #     multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
+ #     multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
+ #     for i, (task_id, input, target) in enumerate(multi_task_train_data):
+ #         pre = model(input)
+ #     """
+ #     def __init__(self, datasets_image, datasets_gt):
+ #         self._datasets = datasets_image
+ #         task_id_2_image_set_dic = {}
+ #         for i, dataset in enumerate(datasets_image):
+ #             task_id = i
+ #             assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id
+ #             task_id_2_image_set_dic[task_id] = dataset
+ #         self.datasets_1 = task_id_2_image_set_dic
+ #
+ #         task_id_2_gt_set_dic = {}
+ #         for i, dataset in enumerate(datasets_gt):
+ #             task_id = i
+ #             assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id
+ #             task_id_2_gt_set_dic[task_id] = dataset
+ #         self.dataset_2 = task_id_2_gt_set_dic
+ #
+ #
+ #     def __len__(self):
+ #         return sum(len(dataset) for dataset in self._datasets)
+ #
+ #     def __getitem__(self, idx):
+ #         task_id, sample_id = idx
+ #         # return self._task_id_2_data_set_dic[task_id][sample_id]
+ #         return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id]
+
+ class MultiTaskDataset(Dataset):
+     """
+     usage example:
+     train_datasets = [SemData_Single(), SemData_Single()]
+     multi_task_train_dataset = MultiTaskDataset(train_datasets)
+     multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
+     multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
+     for i, (task_id, input, target) in enumerate(multi_task_train_data):
+         pre = model(input)
+     """
+     def __init__(self, datasets):
+         self._datasets = datasets
+         task_id_2_data_set_dic = {}
+         for i, dataset in enumerate(datasets):
+             task_id = i
+             assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
+             task_id_2_data_set_dic[task_id] = dataset
+
+         self._task_id_2_data_set_dic = task_id_2_data_set_dic
+
+     def __len__(self):
+         return sum(len(dataset) for dataset in self._datasets)
+
+     def __getitem__(self, idx):
+         task_id, sample_id = idx
+         # print('----', idx, task_id, sample_id)
+         return self._task_id_2_data_set_dic[task_id][sample_id]
+
+ def collate_fn(batch):
+     # print(len(batch))
+     # print('*'*10)
+     # print(batch[0][0])
+     # print('#'*10)
+     # print(batch[0][1])
+     # batch = list(filter(lambda x: x[0][0] is not None, batch))
+     # if len(batch) == 0: return torch.Tensor()
+     print('******------', batch)
+     if not isinstance(batch[0][0], tuple):
+         return default_collate(batch)
+     else:
+         batch_num = len(batch)
+         ret = []
+         for item_idx in range(len(batch[0][0])):
+             if batch[0][0][item_idx] is None:
+                 ret.append(None)
+             else:
+                 ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
+         ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
+         return ret
+
+ class DistrubutedMultiTaskBatchSampler(BatchSampler):
+     """
+     datasets: the Dataset instances to sample from
+     batch_size: int
+     mix_opt: int; mix_opt == 0 shuffles all tasks, mix_opt == 1 shuffles only the extra tasks
+     extra_task_ratio (float, optional): the ratio between task one and the extra tasks
+     drop_last (bool, optional): set to ``True`` to drop the last incomplete batch
+         if the dataset size is not divisible by the batch size. If ``False`` and
+         the size of the dataset is not divisible by the batch size, the last batch
+         will be smaller. (default: ``True``)
+     """
+
+     def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,
+                  shuffle=True):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         if rank >= num_replicas or rank < 0:
+             raise ValueError(
+                 "Invalid rank {}, rank should be in the interval"
+                 " [0, {}]".format(rank, num_replicas - 1))
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
+         assert extra_task_ratio >= 0, 'extra_task_ratio must be >= 0'
+         # self._datasets = datasets
+         self._batch_size = batch_size
+         self._mix_opt = mix_opt
+         self._extra_task_ratio = extra_task_ratio
+         self._drop_last = drop_last
+         train_data_list = []
+         self.shuffle = shuffle
+         for dataset in datasets:
+             print(len(dataset))
+             train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
+
+         ######### one list holds the data of n datasets; each dataset's entry is itself a list that splits its samples into per-batch index lists
+         self._train_data_list = train_data_list
+         self.total_len = sum(len(train_data) for train_data in self._train_data_list)
+
+         ######### DDP ######################
+         if self._drop_last and self.total_len % self.num_replicas != 0:  # type: ignore[arg-type]
+             self.num_samples = math.ceil(
+                 (self.total_len - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+             )
+         else:
+             self.num_samples = math.ceil(self.total_len / self.num_replicas)  # type: ignore[arg-type]
+
+         self.total_size = self.num_samples * self.num_replicas
+         self.epoch = 0
+         self.seed = 0
+
+     def set_epoch(self, epoch):
+         # print('&&&&****')
+         self.epoch = epoch
+
+     @staticmethod
+     def _get_index_batches(dataset_len, batch_size, drop_last):
+         # index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
+         index = list(range(dataset_len))
+         if drop_last and dataset_len % batch_size:
+             del index[-(dataset_len % batch_size):]
+         index_batches = [index[i:i + batch_size] for i in range(0, len(index), batch_size)]
+         return index_batches
+
+     def __len__(self):
+         # return sum(len(train_data) for train_data in self._train_data_list)
+         return self.num_samples
+
+     def __iter__(self):
+         all_iters = [iter(item) for item in self._train_data_list]
+         all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
+
+         ######### DDP ######################
+         random.shuffle(all_indices)
+         all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
+         assert len(all_indices) == self.num_samples
+
+         for local_task_idx in all_indices:
+             # task_id = self._datasets[local_task_idx].get_task_id()
+             batch = next(all_iters[local_task_idx])
+             # batch = batch[self.rank:len(batch):self.num_replicas]
+             # print(local_task_idx)
+             yield [(local_task_idx, sample_id) for sample_id in batch]
+             # yield iter(batch)
+
+     @staticmethod
+     def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
+
+         ########## according to the number of models ###########
+         all_indices = []
+         for i in range(len(train_data_list)):
+             all_indices += [i] * len(train_data_list[i])
+         # print(all_indices)
+         return all_indices
+     # def set_epoch(self, epoch):
+     #     r"""
+     #     Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+     #     use a different random ordering for each epoch. Otherwise, the next iteration of this
+     #     sampler will yield the same ordering.
+
+     #     Args:
+     #         epoch (int): Epoch number.
+     #     """
+     #     self.epoch = epoch
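Unlike a plain Dataset, MultiTaskDataset is indexed with a (task_id, sample_id) pair, which is exactly the shape of the index lists DistrubutedMultiTaskBatchSampler yields. A sketch with two toy lists standing in for the paired datasets (assuming MultiTaskDataset from this module is in scope):

mtd = MultiTaskDataset([['a0', 'a1'], ['b0']])   # any sized, indexable objects work
print(len(mtd))      # 3 = 2 + 1
print(mtd[(0, 1)])   # 'a1': sample 1 of task 0
print(mtd[(1, 0)])   # 'b0': sample 0 of task 1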
datasets/wrappers.py ADDED
@@ -0,0 +1,231 @@
+
+ import functools
+ import random
+ import math
+ from PIL import Image
+ import cv2
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ import torchvision
+
+ from datasets import register
+ from math import pi
+ from torchvision.transforms import InterpolationMode
+
+ import torch.nn.functional as F
+
+
+ def to_mask(mask):
+     return transforms.ToTensor()(
+         transforms.Grayscale(num_output_channels=1)(
+             transforms.ToPILImage()(mask)))
+
+
+ def resize_fn(img, size):
+     return transforms.ToTensor()(
+         transforms.Resize(size)(
+             transforms.ToPILImage()(img)))
+
+
+ @register('val')
+ class ValDataset(Dataset):
+     def __init__(self, dataset, inp_size=None, augment=False):
+         self.dataset = dataset
+         self.inp_size = inp_size
+         self.augment = augment
+
+         self.img_transform = transforms.Compose([
+             # transforms.Resize((inp_size, inp_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         self.mask_transform = transforms.Compose([
+             transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         img, mask = self.dataset[idx]
+         mask_name = mask
+         a = self.img_transform(img)
+         # b = self.mask_transform(mask)
+
+         # print(idx, mask.filename)
+         # b = cv2.imread(mask.filename, cv2.IMREAD_UNCHANGED)
+         b = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
+         return {
+             'inp': self.img_transform(img),
+             'gt': torch.tensor(b),
+             'name': mask_name,
+             'filp': False
+             # 'idx': idx
+         }
+
+
+ @register('train')
+ class TrainDataset(Dataset):
+     def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
+                  augment=False, gt_resize=None):
+         self.dataset = dataset
+         self.size_min = size_min
+         if size_max is None:
+             size_max = size_min
+         self.size_max = size_max
+         self.augment = augment
+         self.gt_resize = gt_resize
+
+         self.inp_size = inp_size
+         self.img_transform = transforms.Compose([
+             transforms.Resize((self.inp_size, self.inp_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         self.inverse_transform = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1/0.229, 1/0.224, 1/0.225]),
+             transforms.Normalize(mean=[-0.485, -0.456, -0.406],
+                                  std=[1, 1, 1])
+         ])
+         self.mask_transform = transforms.Compose([
+             transforms.Resize((self.inp_size, self.inp_size)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         # print('lodd****', idx, self.dataset[idx])
+         img, mask = self.dataset[idx]
+         mask_name = mask
+         # print('befor mask', mask)
+         # new add
+         # print(idx, mask.filename, img.size)
+
+         # mask = cv2.imread(mask.filename, cv2.IMREAD_UNCHANGED)
+         mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
+         # print('befor mask', mask)
+         # random flip
+         if random.random() < 0.5:
+             img = img.transpose(Image.FLIP_LEFT_RIGHT)
+             # mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+             mask = cv2.flip(mask, 1)
+
+         img = transforms.Resize((self.inp_size, self.inp_size))(img)
+         # mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask)
+         mask = torch.from_numpy(mask)
+         # print('behind mask', mask)
+         return {
+             'inp': self.img_transform(img),
+             # 'gt': self.mask_transform(mask)
+             'gt': mask,
+             'name': mask_name,
+             # 'idx': idx
+         }
+
+ @register('train_multi_task')
+ class TrainDataset(Dataset):
+     def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
+                  augment=False, gt_resize=None):
+         self.dataset = dataset
+         self.size_min = size_min
+         if size_max is None:
+             size_max = size_min
+         self.size_max = size_max
+         self.augment = augment
+         self.gt_resize = gt_resize
+
+         self.inp_size = inp_size
+         self.img_transform = transforms.Compose([
+             transforms.Resize((self.inp_size, self.inp_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         self.inverse_transform = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1/0.229, 1/0.224, 1/0.225]),
+             transforms.Normalize(mean=[-0.485, -0.456, -0.406],
+                                  std=[1, 1, 1])
+         ])
+         self.mask_transform = transforms.Compose([
+             transforms.Resize((self.inp_size, self.inp_size)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.dataset)
+         # return sum(len(dataset) for dataset in self.datasets)
+
+     def __getitem__(self, idx):
+         # print('lodd****', idx, self.dataset[idx])
+         # print('+++++', idx)
+         img, mask = self.dataset[idx]
+         # print('befor mask', mask)
+         # new add
+         # print('****', idx, mask)
+         mask_name = mask
+         mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
+
+         # print('****', mask)
+         # print('befor mask', mask)
+         # random flip
+         if random.random() < 0.5:
+             img = img.transpose(Image.FLIP_LEFT_RIGHT)
+             # mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+             mask = cv2.flip(mask, 1)
+
+         img = transforms.Resize((self.inp_size, self.inp_size))(img)
+         # mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask)
+         mask = torch.from_numpy(mask)
+         # print('behind mask', mask)
+         return {
+             'inp': self.img_transform(img),
+             # 'gt': self.mask_transform(mask)
+             'gt': mask,
+             'name': mask_name
+         }
+
+
+ @register('val_multi_task')
+ class ValDataset(Dataset):
+     def __init__(self, dataset, inp_size=None, augment=False):
+         self.dataset = dataset
+         self.inp_size = inp_size
+         self.augment = augment
+
+         self.img_transform = transforms.Compose([
+             transforms.Resize((inp_size, inp_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         self.mask_transform = transforms.Compose([
+             transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         img, mask = self.dataset[idx]
+         a = self.img_transform(img)
+         # b = self.mask_transform(mask)
+         mask_name = mask
+         # print(idx, mask.filename)
+         # b = cv2.imread(mask.filename, cv2.IMREAD_UNCHANGED)
+         b = cv2.imread(mask, cv2.IMREAD_UNCHANGED)
+         return {
+             'inp': self.img_transform(img),
+             'gt': torch.tensor(b),
+             'name': mask_name
+             # 'idx': idx
+         }
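TrainDataset.inverse_transform undoes img_transform's Normalize in two steps: multiply by std (expressed as dividing by 1/std), then add the mean back. A quick round-trip check of that identity:

import torch
from torchvision import transforms

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
denorm = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
])
x = torch.rand(3, 8, 8)
assert torch.allclose(denorm(norm(x)), x, atol=1e-6)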
models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .models import register, make
+ from . import sam
+ from . import sam_single
+
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (250 Bytes).

models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (244 Bytes).

models/__pycache__/iou_loss.cpython-37.pyc ADDED
Binary file (938 Bytes).

models/__pycache__/models.cpython-310.pyc ADDED
Binary file (723 Bytes).

models/__pycache__/models.cpython-37.pyc ADDED
Binary file (698 Bytes).

models/__pycache__/sam.cpython-310.pyc ADDED
Binary file (9.78 kB).

models/__pycache__/sam.cpython-37.pyc ADDED
Binary file (9.75 kB).

models/__pycache__/sam_single.cpython-37.pyc ADDED
Binary file (9.5 kB).

models/__pycache__/utils_prompt.cpython-37.pyc ADDED
Binary file (2.2 kB).
models/block.py ADDED
@@ -0,0 +1,128 @@
1
+ from __future__ import print_function
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
+
+
+ class MergeAndConv(nn.Module):
+
+     def __init__(self, ic, oc, inner=32):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(ic, inner, kernel_size=3, stride=1, padding=1)
+         self.bn = nn.BatchNorm2d(inner)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(inner, oc, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         x = self.conv2(self.bn(self.relu(self.conv1(x))))
+         x = torch.sigmoid(x)
+         return x
+
+
+ class SideClassifer(nn.Module):
+     def __init__(self, ic, n_class=1, M=2, kernel_size=1):
+         super().__init__()
+
+         sides = []
+         for i in range(M):
+             sides.append(nn.Conv2d(ic, n_class, kernel_size=kernel_size))
+
+         self.sides = nn.ModuleList(sides)
+
+     def forward(self, x):
+         return [fn(x) for fn in self.sides]
+
+
+ class UpsampleSKConv(nn.Module):
+     """Upsample by 2x, then refine with a selective-kernel SPP block."""
+
+     def __init__(self, ic, oc, reduce=4):
+         super(UpsampleSKConv, self).__init__()
+
+         self.relu = nn.ReLU(inplace=True)
+         self.prev = nn.Conv2d(ic, ic // reduce, kernel_size=3, stride=1, padding=1)
+         self.bn = nn.BatchNorm2d(ic // reduce)
+
+         self.next = nn.Conv2d(ic // reduce, oc, kernel_size=1, stride=1)
+         self.bn2 = nn.BatchNorm2d(oc)
+
+         self.sk = SKSPP(ic // reduce, ic // reduce, M=4)
+
+     def forward(self, x):
+         x = F.interpolate(x, scale_factor=2)
+         x = self.bn(self.relu(self.prev(x)))
+         x = self.sk(x)
+         x = self.bn2(self.relu(self.next(x)))
+         return x
+
+
+ class SKSPP(nn.Module):
+     def __init__(self, features, WH, M=2, G=1, r=16, stride=1, L=32):
+         """Constructor
+         Args:
+             features: input channel dimensionality.
+             WH: input spatial dimensionality, used for GAP kernel size.
+             M: the number of branches.
+             G: number of convolution groups.
+             r: the ratio for computing d, the length of z.
+             stride: stride, default 1.
+             L: the minimum dim of the vector z in the paper, default 32.
+         """
+         super(SKSPP, self).__init__()
+         d = max(int(features / r), L)
+         self.M = M
+         self.features = features
+         self.convs = nn.ModuleList([])
+
+         # M - 1 dilated branches with kernels 3, 5, 7, ... (branch 0 is the
+         # identity); padding is chosen so each branch preserves spatial size
+         for i in range(1, M):
+             self.convs.append(nn.Sequential(
+                 nn.Conv2d(features, features, kernel_size=1 + i * 2, dilation=1 + i * 2, stride=stride,
+                           padding=((1 + i * 2) * (i * 2) + 1) // 2, groups=G),
+                 nn.BatchNorm2d(features),
+                 nn.ReLU(inplace=False)
+             ))
+         self.fc = nn.Linear(features, d)
+         self.fcs = nn.ModuleList([])
+         for i in range(M):
+             self.fcs.append(
+                 nn.Linear(d, features)
+             )
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x):
+         # branch 0 is the input itself; later branches are cascaded dilated convs
+         feas = torch.unsqueeze(x, dim=1)
+         for i, conv in enumerate(self.convs):
+             x = conv(x)
+             feas = torch.cat([feas, torch.unsqueeze(x, dim=1)], dim=1)
+
+         # fuse the branches, squeeze to a channel descriptor, project to z
+         fea_U = torch.sum(feas, dim=1)
+         fea_s = fea_U.mean(-1).mean(-1)
+         fea_z = self.fc(fea_s)
+
+         # one attention vector per branch, softmax-normalised across branches
+         for i, fc in enumerate(self.fcs):
+             vector = fc(fea_z).unsqueeze_(dim=1)
+             if i == 0:
+                 attention_vectors = vector
+             else:
+                 attention_vectors = torch.cat([attention_vectors, vector], dim=1)
+
+         attention_vectors = self.softmax(attention_vectors)
+         attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
+         fea_v = (feas * attention_vectors).sum(dim=1)
+         return fea_v
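
Example usage for the blocks above (a minimal sketch; it assumes the repository root is on the import path, and the tensor sizes are illustrative only):

    import torch
    from models.block import MergeAndConv, SKSPP, UpsampleSKConv

    x = torch.randn(2, 64, 32, 32)        # (N, C, H, W)
    sk = SKSPP(features=64, WH=32, M=4)   # identity branch + 3 dilated branches
    print(sk(x).shape)                    # torch.Size([2, 64, 32, 32])

    up = UpsampleSKConv(ic=64, oc=32)     # 2x upsample, channels 64 -> 32
    y = up(x)
    print(y.shape)                        # torch.Size([2, 32, 64, 64])

    head = MergeAndConv(ic=32, oc=1)      # sigmoid output in [0, 1]
    print(head(y).shape)                  # torch.Size([2, 1, 64, 64])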
models/bn_helper.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import functools
+
+ if torch.__version__.startswith('0'):
+     from .sync_bn.inplace_abn.bn import InPlaceABNSync
+     BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+     BatchNorm2d_class = InPlaceABNSync
+     relu_inplace = False
+ else:
+     BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
+     relu_inplace = True
+
+ # NOTE: the block below unconditionally overrides the choices above,
+ # pinning the helper to plain (non-synchronised) BatchNorm2d.
+ BatchNorm2d = torch.nn.BatchNorm2d
+ BatchNorm2d_class = BatchNorm2d
+ relu_inplace = False
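
How the helper is meant to be consumed (a minimal sketch; note that, as committed, the trailing override applies regardless of the version check):

    import torch
    from models.bn_helper import BatchNorm2d, BatchNorm2d_class, relu_inplace

    norm = BatchNorm2d(64)                             # plain torch.nn.BatchNorm2d here
    act = torch.nn.ReLU(inplace=relu_inplace)          # relu_inplace is False after the override
    print(BatchNorm2d_class is torch.nn.BatchNorm2d)   # True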
models/iou_loss.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ ###################################################################
+ # ########################## iou loss #############################
+ ###################################################################
+ class IOU(torch.nn.Module):
+     """Soft IoU loss: 1 - IoU between sigmoid(pred) and a binary target."""
+
+     def __init__(self):
+         super(IOU, self).__init__()
+
+     def _iou(self, pred, target):
+         pred = torch.sigmoid(pred)
+         inter = (pred * target).sum(dim=(2, 3))
+         union = (pred + target).sum(dim=(2, 3)) - inter
+         iou = 1 - (inter / union)
+         return iou.mean()
+
+     def forward(self, pred, target):
+         return self._iou(pred, target)
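
Example usage (a sketch; the loss expects raw logits, since it applies the sigmoid itself):

    import torch
    from models.iou_loss import IOU

    criterion = IOU()
    logits = torch.randn(4, 1, 64, 64, requires_grad=True)
    target = (torch.rand(4, 1, 64, 64) > 0.5).float()   # binary mask
    loss = criterion(logits, target)                    # scalar; lower means more overlap
    loss.backward()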
models/mmseg/__init__.py ADDED
@@ -0,0 +1,33 @@
+ import mmcv
+
+ from .version import __version__, version_info
+
+ # MMCV_MIN = '1.1.4'
+ # MMCV_MAX = '1.3.0'
+
+ MMCV_MIN = '1.1.4'
+ MMCV_MAX = '1.7.0'
+
+
+ def digit_version(version_str):
+     digit_version = []
+     for x in version_str.split('.'):
+         if x.isdigit():
+             digit_version.append(int(x))
+         elif x.find('rc') != -1:
+             patch_version = x.split('rc')
+             digit_version.append(int(patch_version[0]) - 1)
+             digit_version.append(int(patch_version[1]))
+     return digit_version
+
+
+ mmcv_min_version = digit_version(MMCV_MIN)
+ mmcv_max_version = digit_version(MMCV_MAX)
+ mmcv_version = digit_version(mmcv.__version__)
+
+
+ assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
+     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+     f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
+
+ __all__ = ['__version__', 'version_info']
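
`digit_version` maps a release-candidate tag below its final release so that plain list comparison orders versions correctly; a couple of worked cases using the function above:

    assert digit_version('1.3.0') == [1, 3, 0]
    assert digit_version('1.3.0rc1') == [1, 3, -1, 1]   # sorts below [1, 3, 0]
    assert digit_version('1.1.4') <= digit_version('1.3.0rc1') <= digit_version('1.7.0')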
models/mmseg/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (841 Bytes)
models/mmseg/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (839 Bytes)
models/mmseg/__pycache__/version.cpython-310.pyc ADDED
Binary file (521 Bytes)
models/mmseg/__pycache__/version.cpython-37.pyc ADDED
Binary file (513 Bytes)
models/mmseg/apis/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+ from .test import multi_gpu_test, single_gpu_test
+ from .train import get_root_logger, set_random_seed, train_segmentor
+
+ __all__ = [
+     'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+     'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+     'show_result_pyplot'
+ ]
models/mmseg/apis/inference.py ADDED
@@ -0,0 +1,118 @@
+ import matplotlib.pyplot as plt
+ import mmcv
+ import torch
+ from mmcv.parallel import collate, scatter
+ from mmcv.runner import load_checkpoint
+
+ from mmseg.datasets.pipelines import Compose
+ from mmseg.models import build_segmentor
+
+
+ def init_segmentor(config, checkpoint=None, device='cuda:0'):
+     """Initialize a segmentor from config file.
+
+     Args:
+         config (str or :obj:`mmcv.Config`): Config file path or the config
+             object.
+         checkpoint (str, optional): Checkpoint path. If left as None, the model
+             will not load any weights.
+         device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
+             Use 'cpu' for loading the model on CPU.
+     Returns:
+         nn.Module: The constructed segmentor.
+     """
+     if isinstance(config, str):
+         config = mmcv.Config.fromfile(config)
+     elif not isinstance(config, mmcv.Config):
+         raise TypeError('config must be a filename or Config object, '
+                         'but got {}'.format(type(config)))
+     config.model.pretrained = None
+     config.model.train_cfg = None
+     model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+     if checkpoint is not None:
+         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+         model.CLASSES = checkpoint['meta']['CLASSES']
+         model.PALETTE = checkpoint['meta']['PALETTE']
+     model.cfg = config  # save the config in the model for convenience
+     model.to(device)
+     model.eval()
+     return model
+
+
+ class LoadImage:
+     """A simple pipeline to load image."""
+
+     def __call__(self, results):
+         """Call function to load images into results.
+
+         Args:
+             results (dict): A result dict containing the file name
+                 of the image to be read.
+
+         Returns:
+             dict: ``results`` will be returned containing the loaded image.
+         """
+
+         if isinstance(results['img'], str):
+             results['filename'] = results['img']
+             results['ori_filename'] = results['img']
+         else:
+             results['filename'] = None
+             results['ori_filename'] = None
+         img = mmcv.imread(results['img'])
+         results['img'] = img
+         results['img_shape'] = img.shape
+         results['ori_shape'] = img.shape
+         return results
+
+
+ def inference_segmentor(model, img):
+     """Run inference on an image with the segmentor.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         img (str or ndarray): Either an image file or a loaded image.
+
+     Returns:
+         (list[Tensor]): The segmentation result.
+     """
+     cfg = model.cfg
+     device = next(model.parameters()).device  # model device
+     # build the data pipeline
+     test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+     test_pipeline = Compose(test_pipeline)
+     # prepare data
+     data = dict(img=img)
+     data = test_pipeline(data)
+     data = collate([data], samples_per_gpu=1)
+     if next(model.parameters()).is_cuda:
+         # scatter to specified GPU
+         data = scatter(data, [device])[0]
+     else:
+         data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+     # forward the model
+     with torch.no_grad():
+         result = model(return_loss=False, rescale=True, **data)
+     return result
+
+
+ def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
+     """Visualize the segmentation results on the image.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         img (str or np.ndarray): Image filename or loaded image.
+         result (list): The segmentation result.
+         palette (list[list[int]] | None): The palette of the segmentation
+             map. If None is given, a random palette will be generated.
+             Default: None
+         fig_size (tuple): Figure size of the pyplot figure.
+     """
+     if hasattr(model, 'module'):
+         model = model.module
+     img = model.show_result(img, result, palette=palette, show=False)
+     plt.figure(figsize=fig_size)
+     plt.imshow(mmcv.bgr2rgb(img))
+     plt.show()
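
Typical end-to-end use of the three helpers above (a sketch; the config, checkpoint, and image paths are placeholders):

    from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot

    config_file = 'configs/example_config.py'       # placeholder
    checkpoint_file = 'checkpoints/example.pth'     # placeholder

    model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
    result = inference_segmentor(model, 'demo/example.png')   # list with one (H, W) label map
    show_result_pyplot(model, 'demo/example.png', result, fig_size=(12, 8))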
models/mmseg/apis/test.py ADDED
@@ -0,0 +1,235 @@
+ import os.path as osp
+ import pickle
+ import shutil
+ import tempfile
+
+ import mmcv
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ from mmcv.image import tensor2imgs
+ from mmcv.runner import get_dist_info
+
+
+ def np2tmp(array, temp_file_name=None):
+     """Save ndarray to local numpy file.
+
+     Args:
+         array (ndarray): Ndarray to save.
+         temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+             function will generate a file name with tempfile.NamedTemporaryFile
+             to save the ndarray. Default: None.
+
+     Returns:
+         str: The numpy file name.
+     """
+
+     if temp_file_name is None:
+         temp_file_name = tempfile.NamedTemporaryFile(
+             suffix='.npy', delete=False).name
+     np.save(temp_file_name, array)
+     return temp_file_name
+
+
+ def single_gpu_test(model,
+                     data_loader,
+                     show=False,
+                     out_dir=None,
+                     efficient_test=False):
+     """Test with single GPU.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (utils.data.Dataloader): Pytorch data loader.
+         show (bool): Whether to show results during inference. Default: False.
+         out_dir (str, optional): If specified, the results will be dumped into
+             the directory to save output results.
+         efficient_test (bool): Whether to save the results as local numpy files
+             to save CPU memory during evaluation. Default: False.
+
+     Returns:
+         list: The prediction results.
+     """
+
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     prog_bar = mmcv.ProgressBar(len(dataset))
+     for i, data in enumerate(data_loader):
+         with torch.no_grad():
+             result = model(return_loss=False, **data)
+
+         if show or out_dir:
+             img_tensor = data['img'][0]
+             img_metas = data['img_metas'][0].data[0]
+             imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+             assert len(imgs) == len(img_metas)
+
+             for img, img_meta in zip(imgs, img_metas):
+                 h, w, _ = img_meta['img_shape']
+                 img_show = img[:h, :w, :]
+
+                 ori_h, ori_w = img_meta['ori_shape'][:-1]
+                 img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+                 if out_dir:
+                     out_file = osp.join(out_dir, img_meta['ori_filename'])
+                 else:
+                     out_file = None
+
+                 model.module.show_result(
+                     img_show,
+                     result,
+                     palette=dataset.PALETTE,
+                     show=show,
+                     out_file=out_file)
+
+         if isinstance(result, list):
+             if efficient_test:
+                 result = [np2tmp(_) for _ in result]
+             results.extend(result)
+         else:
+             if efficient_test:
+                 result = np2tmp(result)
+             results.append(result)
+
+         batch_size = data['img'][0].size(0)
+         for _ in range(batch_size):
+             prog_bar.update()
+     return results
+
+
+ def multi_gpu_test(model,
+                    data_loader,
+                    tmpdir=None,
+                    gpu_collect=False,
+                    efficient_test=False):
+     """Test model with multiple gpus.
+
+     This method tests a model with multiple gpus and collects the results
+     under two different modes: gpu and cpu. By setting 'gpu_collect=True'
+     it encodes results to gpu tensors and uses gpu communication for results
+     collection. In cpu mode it saves the results on different gpus to 'tmpdir'
+     and collects them by the rank 0 worker.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (utils.data.Dataloader): Pytorch data loader.
+         tmpdir (str): Path of directory to save the temporary results from
+             different gpus under cpu mode.
+         gpu_collect (bool): Option to use either gpu or cpu to collect results.
+         efficient_test (bool): Whether to save the results as local numpy files
+             to save CPU memory during evaluation. Default: False.
+
+     Returns:
+         list: The prediction results.
+     """
+
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     rank, world_size = get_dist_info()
+     if rank == 0:
+         prog_bar = mmcv.ProgressBar(len(dataset))
+     for i, data in enumerate(data_loader):
+         with torch.no_grad():
+             result = model(return_loss=False, rescale=True, **data)
+
+         if isinstance(result, list):
+             if efficient_test:
+                 result = [np2tmp(_) for _ in result]
+             results.extend(result)
+         else:
+             if efficient_test:
+                 result = np2tmp(result)
+             results.append(result)
+
+         if rank == 0:
+             batch_size = data['img'][0].size(0)
+             for _ in range(batch_size * world_size):
+                 prog_bar.update()
+
+     # collect results from all ranks
+     if gpu_collect:
+         results = collect_results_gpu(results, len(dataset))
+     else:
+         results = collect_results_cpu(results, len(dataset), tmpdir)
+     return results
+
+
+ def collect_results_cpu(result_part, size, tmpdir=None):
+     """Collect results with CPU."""
+     rank, world_size = get_dist_info()
+     # create a tmp dir if it is not specified
+     if tmpdir is None:
+         MAX_LEN = 512
+         # 32 is whitespace
+         dir_tensor = torch.full((MAX_LEN, ),
+                                 32,
+                                 dtype=torch.uint8,
+                                 device='cuda')
+         if rank == 0:
+             tmpdir = tempfile.mkdtemp()
+             tmpdir = torch.tensor(
+                 bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+             dir_tensor[:len(tmpdir)] = tmpdir
+         dist.broadcast(dir_tensor, 0)
+         tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+     else:
+         mmcv.mkdir_or_exist(tmpdir)
+     # dump the part result to the dir
+     mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+     dist.barrier()
+     # collect all parts
+     if rank != 0:
+         return None
+     else:
+         # load results of all parts from tmp dir
+         part_list = []
+         for i in range(world_size):
+             part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+             part_list.append(mmcv.load(part_file))
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         # remove tmp dir
+         shutil.rmtree(tmpdir)
+         return ordered_results
+
+
+ def collect_results_gpu(result_part, size):
+     """Collect results with GPU."""
+     rank, world_size = get_dist_info()
+     # dump result part to tensor with pickle
+     part_tensor = torch.tensor(
+         bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+     # gather all result part tensor shape
+     shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+     shape_list = [shape_tensor.clone() for _ in range(world_size)]
+     dist.all_gather(shape_list, shape_tensor)
+     # padding result part tensor to max length
+     shape_max = torch.tensor(shape_list).max()
+     part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+     part_send[:shape_tensor[0]] = part_tensor
+     part_recv_list = [
+         part_tensor.new_zeros(shape_max) for _ in range(world_size)
+     ]
+     # gather all result part
+     dist.all_gather(part_recv_list, part_send)
+
+     if rank == 0:
+         part_list = []
+         for recv, shape in zip(part_recv_list, shape_list):
+             part_list.append(
+                 pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         return ordered_results
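
What `efficient_test` buys: each prediction is flushed to a temporary `.npy` file and only the path stays in memory; the evaluation code later reloads it. A minimal round trip (names are illustrative, assuming the package is importable as `mmseg`):

    import numpy as np
    from mmseg.apis.test import np2tmp

    pred = np.random.randint(0, 19, size=(512, 512)).astype(np.uint8)  # fake label map
    path = np2tmp(pred)                    # e.g. /tmp/tmpXXXXXXXX.npy
    assert (np.load(path) == pred).all()   # reloads exactly as saved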
models/mmseg/apis/train.py ADDED
@@ -0,0 +1,115 @@
+ import random
+ import warnings
+
+ import numpy as np
+ import torch
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+ from mmcv.runner import build_optimizer, build_runner
+
+ from mmseg.core import DistEvalHook, EvalHook
+ from mmseg.datasets import build_dataloader, build_dataset
+ from mmseg.utils import get_root_logger
+
+
+ def set_random_seed(seed, deterministic=False):
+     """Set random seed.
+
+     Args:
+         seed (int): Seed to be used.
+         deterministic (bool): Whether to set the deterministic option for
+             the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+             to True and `torch.backends.cudnn.benchmark` to False.
+             Default: False.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     if deterministic:
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+
+ def train_segmentor(model,
+                     dataset,
+                     cfg,
+                     distributed=False,
+                     validate=False,
+                     timestamp=None,
+                     meta=None):
+     """Launch segmentor training."""
+     logger = get_root_logger(cfg.log_level)
+
+     # prepare data loaders
+     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+     data_loaders = [
+         build_dataloader(
+             ds,
+             cfg.data.samples_per_gpu,
+             cfg.data.workers_per_gpu,
+             # cfg.gpus will be ignored if distributed
+             len(cfg.gpu_ids),
+             dist=distributed,
+             seed=cfg.seed,
+             drop_last=True) for ds in dataset
+     ]
+
+     # put model on gpus
+     if distributed:
+         find_unused_parameters = cfg.get('find_unused_parameters', False)
+         # Sets the `find_unused_parameters` parameter in
+         # torch.nn.parallel.DistributedDataParallel
+         model = MMDistributedDataParallel(
+             model.cuda(),
+             device_ids=[torch.cuda.current_device()],
+             broadcast_buffers=False,
+             find_unused_parameters=find_unused_parameters)
+     else:
+         model = MMDataParallel(
+             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+     # build runner
+     optimizer = build_optimizer(model, cfg.optimizer)
+
+     if cfg.get('runner') is None:
+         cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+         warnings.warn(
+             'config is now expected to have a `runner` section, '
+             'please set `runner` in your config.', UserWarning)
+
+     runner = build_runner(
+         cfg.runner,
+         default_args=dict(
+             model=model,
+             batch_processor=None,
+             optimizer=optimizer,
+             work_dir=cfg.work_dir,
+             logger=logger,
+             meta=meta))
+
+     # register hooks
+     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                    cfg.checkpoint_config, cfg.log_config,
+                                    cfg.get('momentum_config', None))
+
+     # an ugly workaround to make the .log and .log.json filenames the same
+     runner.timestamp = timestamp
+
+     # register eval hooks
+     if validate:
+         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+         val_dataloader = build_dataloader(
+             val_dataset,
+             samples_per_gpu=1,
+             workers_per_gpu=cfg.data.workers_per_gpu,
+             dist=distributed,
+             shuffle=False)
+         eval_cfg = cfg.get('evaluation', {})
+         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+         eval_hook = DistEvalHook if distributed else EvalHook
+         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+     if cfg.resume_from:
+         runner.resume(cfg.resume_from)
+     elif cfg.load_from:
+         runner.load_checkpoint(cfg.load_from)
+     runner.run(data_loaders, cfg.workflow)
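
For reproducible runs, the seeding helper above is typically called once before datasets and models are built (a minimal sketch):

    from mmseg.apis import set_random_seed

    set_random_seed(0, deterministic=True)  # seeds random/numpy/torch and pins cuDNN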
models/mmseg/core/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .evaluation import *  # noqa: F401, F403
+ from .seg import *  # noqa: F401, F403
+ from .utils import *  # noqa: F401, F403
models/mmseg/core/evaluation/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .class_names import get_classes, get_palette
+ from .eval_hooks import DistEvalHook, EvalHook
+ from .metrics import eval_metrics, mean_dice, mean_iou
+
+ __all__ = [
+     'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+     'get_classes', 'get_palette'
+ ]
models/mmseg/core/evaluation/class_names.py ADDED
@@ -0,0 +1,152 @@
+ import mmcv
+
+
+ def cityscapes_classes():
+     """Cityscapes class names for external use."""
+     return [
+         'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+         'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+         'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+         'bicycle'
+     ]
+
+
+ def ade_classes():
+     """ADE20K class names for external use."""
+     return [
+         'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+         'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+         'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+         'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+         'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+         'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+         'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+         'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+         'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+         'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+         'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+         'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+         'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+         'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+         'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+         'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+         'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+         'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+         'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+         'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+         'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+         'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+         'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+         'clock', 'flag'
+     ]
+
+
+ def voc_classes():
+     """Pascal VOC class names for external use."""
+     return [
+         'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+         'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+         'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+         'tvmonitor'
+     ]
+
+
+ def cityscapes_palette():
+     """Cityscapes palette for external use."""
+     return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+             [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+             [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+             [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+             [0, 0, 230], [119, 11, 32]]
+
+
+ def ade_palette():
+     """ADE20K palette for external use."""
+     return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+             [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+             [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+             [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+             [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+             [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+             [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+             [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+             [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+             [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+             [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+             [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+             [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+             [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+             [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+             [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+             [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+             [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+             [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+             [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+             [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+             [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+             [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+             [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+             [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+             [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+             [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+             [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+             [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+             [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+             [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+             [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+             [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+             [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+             [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+             [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+             [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+             [102, 255, 0], [92, 0, 255]]
+
+
+ def voc_palette():
+     """Pascal VOC palette for external use."""
+     return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+             [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+             [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+             [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+             [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+ dataset_aliases = {
+     'cityscapes': ['cityscapes'],
+     'ade': ['ade', 'ade20k'],
+     'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+ }
+
+
+ def get_classes(dataset):
+     """Get class names of a dataset."""
+     alias2name = {}
+     for name, aliases in dataset_aliases.items():
+         for alias in aliases:
+             alias2name[alias] = name
+
+     if mmcv.is_str(dataset):
+         if dataset in alias2name:
+             labels = eval(alias2name[dataset] + '_classes()')
+         else:
+             raise ValueError(f'Unrecognized dataset: {dataset}')
+     else:
+         raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+     return labels
+
+
+ def get_palette(dataset):
+     """Get class palette (RGB) of a dataset."""
+     alias2name = {}
+     for name, aliases in dataset_aliases.items():
+         for alias in aliases:
+             alias2name[alias] = name
+
+     if mmcv.is_str(dataset):
+         if dataset in alias2name:
+             labels = eval(alias2name[dataset] + '_palette()')
+         else:
+             raise ValueError(f'Unrecognized dataset: {dataset}')
+     else:
+         raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+     return labels
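
Example usage of the two lookups above; aliases resolve through `dataset_aliases`:

    from mmseg.core.evaluation import get_classes, get_palette

    names = get_classes('voc12aug')   # alias for the 21 Pascal VOC classes
    colors = get_palette('voc')       # matching RGB palette
    print(len(names), len(colors))    # 21 21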
models/mmseg/core/evaluation/eval_hooks.py ADDED
@@ -0,0 +1,107 @@
+ import os.path as osp
+
+ from mmcv.runner import Hook
+ from torch.utils.data import DataLoader
+
+
+ class EvalHook(Hook):
+     """Evaluation hook.
+
+     Attributes:
+         dataloader (DataLoader): A PyTorch dataloader.
+         interval (int): Evaluation interval (in epochs if ``by_epoch`` is
+             True, otherwise in iterations). Default: 1.
+     """
+
+     def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
+         if not isinstance(dataloader, DataLoader):
+             raise TypeError('dataloader must be a pytorch DataLoader, but got '
+                             f'{type(dataloader)}')
+         self.dataloader = dataloader
+         self.interval = interval
+         self.by_epoch = by_epoch
+         self.eval_kwargs = eval_kwargs
+
+     def after_train_iter(self, runner):
+         """After train iteration hook."""
+         if self.by_epoch or not self.every_n_iters(runner, self.interval):
+             return
+         from mmseg.apis import single_gpu_test
+         runner.log_buffer.clear()
+         results = single_gpu_test(runner.model, self.dataloader, show=False)
+         self.evaluate(runner, results)
+
+     def after_train_epoch(self, runner):
+         """After train epoch hook."""
+         if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+             return
+         from mmseg.apis import single_gpu_test
+         runner.log_buffer.clear()
+         results = single_gpu_test(runner.model, self.dataloader, show=False)
+         self.evaluate(runner, results)
+
+     def evaluate(self, runner, results):
+         """Call evaluate function of dataset."""
+         eval_res = self.dataloader.dataset.evaluate(
+             results, logger=runner.logger, **self.eval_kwargs)
+         for name, val in eval_res.items():
+             runner.log_buffer.output[name] = val
+         runner.log_buffer.ready = True
+
+
+ class DistEvalHook(EvalHook):
+     """Distributed evaluation hook.
+
+     Attributes:
+         dataloader (DataLoader): A PyTorch dataloader.
+         interval (int): Evaluation interval (in epochs if ``by_epoch`` is
+             True, otherwise in iterations). Default: 1.
+         tmpdir (str | None): Temporary directory to save the results of all
+             processes. Default: None.
+         gpu_collect (bool): Whether to use gpu or cpu to collect results.
+             Default: False.
+     """
+
+     def __init__(self,
+                  dataloader,
+                  interval=1,
+                  gpu_collect=False,
+                  by_epoch=False,
+                  **eval_kwargs):
+         if not isinstance(dataloader, DataLoader):
+             raise TypeError(
+                 'dataloader must be a pytorch DataLoader, but got {}'.format(
+                     type(dataloader)))
+         self.dataloader = dataloader
+         self.interval = interval
+         self.gpu_collect = gpu_collect
+         self.by_epoch = by_epoch
+         self.eval_kwargs = eval_kwargs
+
+     def after_train_iter(self, runner):
+         """After train iteration hook."""
+         if self.by_epoch or not self.every_n_iters(runner, self.interval):
+             return
+         from mmseg.apis import multi_gpu_test
+         runner.log_buffer.clear()
+         results = multi_gpu_test(
+             runner.model,
+             self.dataloader,
+             tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+             gpu_collect=self.gpu_collect)
+         if runner.rank == 0:
+             print('\n')
+             self.evaluate(runner, results)
+
+     def after_train_epoch(self, runner):
+         """After train epoch hook."""
+         if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+             return
+         from mmseg.apis import multi_gpu_test
+         runner.log_buffer.clear()
+         results = multi_gpu_test(
+             runner.model,
+             self.dataloader,
+             tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+             gpu_collect=self.gpu_collect)
+         if runner.rank == 0:
+             print('\n')
+             self.evaluate(runner, results)
models/mmseg/core/evaluation/metrics.py ADDED
@@ -0,0 +1,229 @@
+ import mmcv
+ import numpy as np
+
+
+ def intersect_and_union(pred_label,
+                         label,
+                         num_classes,
+                         ignore_index,
+                         label_map=dict(),
+                         reduce_zero_label=False):
+     """Calculate intersection and union.
+
+     Args:
+         pred_label (ndarray): Prediction segmentation map.
+         label (ndarray): Ground truth segmentation map.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         label_map (dict): Mapping old labels to new labels. The parameter will
+             work only when label is str. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label. The
+             parameter will work only when label is str. Default: False.
+
+     Returns:
+         ndarray: The intersection of prediction and ground truth histogram
+             on all classes.
+         ndarray: The union of prediction and ground truth histogram on all
+             classes.
+         ndarray: The prediction histogram on all classes.
+         ndarray: The ground truth histogram on all classes.
+     """
+
+     if isinstance(pred_label, str):
+         pred_label = np.load(pred_label)
+
+     if isinstance(label, str):
+         label = mmcv.imread(label, flag='unchanged', backend='pillow')
+     # modify if custom classes
+     if label_map is not None:
+         for old_id, new_id in label_map.items():
+             label[label == old_id] = new_id
+     if reduce_zero_label:
+         # avoid using underflow conversion
+         label[label == 0] = 255
+         label = label - 1
+         label[label == 254] = 255
+
+     mask = (label != ignore_index)
+     pred_label = pred_label[mask]
+     label = label[mask]
+
+     intersect = pred_label[pred_label == label]
+     area_intersect, _ = np.histogram(
+         intersect, bins=np.arange(num_classes + 1))
+     area_pred_label, _ = np.histogram(
+         pred_label, bins=np.arange(num_classes + 1))
+     area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
+     area_union = area_pred_label + area_label - area_intersect
+
+     return area_intersect, area_union, area_pred_label, area_label
+
+
+ def total_intersect_and_union(results,
+                               gt_seg_maps,
+                               num_classes,
+                               ignore_index,
+                               label_map=dict(),
+                               reduce_zero_label=False):
+     """Calculate total intersection and union.
+
+     Args:
+         results (list[ndarray]): List of prediction segmentation maps.
+         gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         label_map (dict): Mapping old labels to new labels. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label.
+             Default: False.
+
+     Returns:
+         ndarray: The intersection of prediction and ground truth histogram
+             on all classes.
+         ndarray: The union of prediction and ground truth histogram on all
+             classes.
+         ndarray: The prediction histogram on all classes.
+         ndarray: The ground truth histogram on all classes.
+     """
+
+     num_imgs = len(results)
+     assert len(gt_seg_maps) == num_imgs
+     # np.float64 instead of the deprecated np.float alias (removed in NumPy 1.24)
+     total_area_intersect = np.zeros((num_classes, ), dtype=np.float64)
+     total_area_union = np.zeros((num_classes, ), dtype=np.float64)
+     total_area_pred_label = np.zeros((num_classes, ), dtype=np.float64)
+     total_area_label = np.zeros((num_classes, ), dtype=np.float64)
+     for i in range(num_imgs):
+         area_intersect, area_union, area_pred_label, area_label = \
+             intersect_and_union(results[i], gt_seg_maps[i], num_classes,
+                                 ignore_index, label_map, reduce_zero_label)
+         total_area_intersect += area_intersect
+         total_area_union += area_union
+         total_area_pred_label += area_pred_label
+         total_area_label += area_label
+     return total_area_intersect, total_area_union, \
+         total_area_pred_label, total_area_label
+
+
+ def mean_iou(results,
+              gt_seg_maps,
+              num_classes,
+              ignore_index,
+              nan_to_num=None,
+              label_map=dict(),
+              reduce_zero_label=False):
+     """Calculate Mean Intersection and Union (mIoU).
+
+     Args:
+         results (list[ndarray]): List of prediction segmentation maps.
+         gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         nan_to_num (int, optional): If specified, NaN values will be replaced
+             by the numbers defined by the user. Default: None.
+         label_map (dict): Mapping old labels to new labels. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label.
+             Default: False.
+
+     Returns:
+         float: Overall accuracy on all images.
+         ndarray: Per category accuracy, shape (num_classes, ).
+         ndarray: Per category IoU, shape (num_classes, ).
+     """
+
+     all_acc, acc, iou = eval_metrics(
+         results=results,
+         gt_seg_maps=gt_seg_maps,
+         num_classes=num_classes,
+         ignore_index=ignore_index,
+         metrics=['mIoU'],
+         nan_to_num=nan_to_num,
+         label_map=label_map,
+         reduce_zero_label=reduce_zero_label)
+     return all_acc, acc, iou
+
+
+ def mean_dice(results,
+               gt_seg_maps,
+               num_classes,
+               ignore_index,
+               nan_to_num=None,
+               label_map=dict(),
+               reduce_zero_label=False):
+     """Calculate Mean Dice (mDice).
+
+     Args:
+         results (list[ndarray]): List of prediction segmentation maps.
+         gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         nan_to_num (int, optional): If specified, NaN values will be replaced
+             by the numbers defined by the user. Default: None.
+         label_map (dict): Mapping old labels to new labels. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label.
+             Default: False.
+
+     Returns:
+         float: Overall accuracy on all images.
+         ndarray: Per category accuracy, shape (num_classes, ).
+         ndarray: Per category dice, shape (num_classes, ).
+     """
+
+     all_acc, acc, dice = eval_metrics(
+         results=results,
+         gt_seg_maps=gt_seg_maps,
+         num_classes=num_classes,
+         ignore_index=ignore_index,
+         metrics=['mDice'],
+         nan_to_num=nan_to_num,
+         label_map=label_map,
+         reduce_zero_label=reduce_zero_label)
+     return all_acc, acc, dice
+
+
+ def eval_metrics(results,
+                  gt_seg_maps,
+                  num_classes,
+                  ignore_index,
+                  metrics=['mIoU'],
+                  nan_to_num=None,
+                  label_map=dict(),
+                  reduce_zero_label=False):
+     """Calculate evaluation metrics.
+
+     Args:
+         results (list[ndarray]): List of prediction segmentation maps.
+         gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
+         nan_to_num (int, optional): If specified, NaN values will be replaced
+             by the numbers defined by the user. Default: None.
+         label_map (dict): Mapping old labels to new labels. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label.
+             Default: False.
+     Returns:
+         float: Overall accuracy on all images.
+         ndarray: Per category accuracy, shape (num_classes, ).
+         ndarray: Per category evaluation metrics, shape (num_classes, ).
+     """
+
+     if isinstance(metrics, str):
+         metrics = [metrics]
+     allowed_metrics = ['mIoU', 'mDice']
+     if not set(metrics).issubset(set(allowed_metrics)):
+         raise KeyError('metrics {} is not supported'.format(metrics))
+     total_area_intersect, total_area_union, total_area_pred_label, \
+         total_area_label = total_intersect_and_union(results, gt_seg_maps,
+                                                      num_classes, ignore_index,
+                                                      label_map,
+                                                      reduce_zero_label)
+     all_acc = total_area_intersect.sum() / total_area_label.sum()
+     acc = total_area_intersect / total_area_label
+     ret_metrics = [all_acc, acc]
+     for metric in metrics:
+         if metric == 'mIoU':
+             iou = total_area_intersect / total_area_union
+             ret_metrics.append(iou)
+         elif metric == 'mDice':
+             dice = 2 * total_area_intersect / (
+                 total_area_pred_label + total_area_label)
+             ret_metrics.append(dice)
+     if nan_to_num is not None:
+         ret_metrics = [
+             np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
+         ]
+     return ret_metrics
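
A worked example of `eval_metrics` on a tiny 2x2 map (values chosen by hand):

    import numpy as np
    from mmseg.core import eval_metrics

    pred = np.array([[0, 1], [1, 1]])
    gt = np.array([[0, 1], [0, 1]])

    all_acc, acc, iou = eval_metrics([pred], [gt], num_classes=2, ignore_index=255)
    # class 0: intersect 1, union 2 -> IoU 0.5; class 1: intersect 2, union 3 -> IoU 2/3
    print(all_acc)   # 0.75 (3 of 4 pixels correct)
    print(acc)       # [0.5 1.0]   per-class accuracy = intersect / area_label
    print(iou)       # [0.5 0.66666667]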
models/mmseg/core/seg/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .builder import build_pixel_sampler
+ from .sampler import BasePixelSampler, OHEMPixelSampler
+
+ __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
models/mmseg/core/seg/builder.py ADDED
@@ -0,0 +1,8 @@
+ from mmcv.utils import Registry, build_from_cfg
+
+ PIXEL_SAMPLERS = Registry('pixel sampler')
+
+
+ def build_pixel_sampler(cfg, **default_args):
+     """Build pixel sampler for segmentation map."""
+     return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
models/mmseg/core/seg/sampler/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .base_pixel_sampler import BasePixelSampler
+ from .ohem_pixel_sampler import OHEMPixelSampler
+
+ __all__ = ['BasePixelSampler', 'OHEMPixelSampler']
models/mmseg/core/seg/sampler/base_pixel_sampler.py ADDED
@@ -0,0 +1,13 @@
+ from abc import ABCMeta, abstractmethod
+
+
+ class BasePixelSampler(metaclass=ABCMeta):
+     """Base class of pixel sampler."""
+
+     def __init__(self, **kwargs):
+         pass
+
+     @abstractmethod
+     def sample(self, seg_logit, seg_label):
+         """Placeholder for sample function."""
+         pass
models/mmseg/core/seg/sampler/ohem_pixel_sampler.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn.functional as F
+
+ from ..builder import PIXEL_SAMPLERS
+ from .base_pixel_sampler import BasePixelSampler
+
+
+ @PIXEL_SAMPLERS.register_module()
+ class OHEMPixelSampler(BasePixelSampler):
+     """Online Hard Example Mining Sampler for segmentation.
+
+     Args:
+         context (nn.Module): The context of the sampler, a subclass of
+             :obj:`BaseDecodeHead`.
+         thresh (float, optional): The threshold for hard example selection.
+             Predictions below it are treated as low-confidence hard examples.
+             If not specified, the hard examples are the pixels with the top
+             ``min_kept`` losses. Default: None.
+         min_kept (int, optional): The minimum number of predictions to keep.
+             Default: 100000.
+     """
+
+     def __init__(self, context, thresh=None, min_kept=100000):
+         super(OHEMPixelSampler, self).__init__()
+         self.context = context
+         assert min_kept > 1
+         self.thresh = thresh
+         self.min_kept = min_kept
+
+     def sample(self, seg_logit, seg_label):
+         """Sample pixels that have high loss or low prediction confidence.
+
+         Args:
+             seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+             seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+         Returns:
+             torch.Tensor: segmentation weight, shape (N, H, W)
+         """
+         with torch.no_grad():
+             assert seg_logit.shape[2:] == seg_label.shape[2:]
+             assert seg_label.shape[1] == 1
+             seg_label = seg_label.squeeze(1).long()
+             batch_kept = self.min_kept * seg_label.size(0)
+             valid_mask = seg_label != self.context.ignore_index
+             seg_weight = seg_logit.new_zeros(size=seg_label.size())
+             valid_seg_weight = seg_weight[valid_mask]
+             if self.thresh is not None:
+                 seg_prob = F.softmax(seg_logit, dim=1)
+
+                 tmp_seg_label = seg_label.clone().unsqueeze(1)
+                 tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                 seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                 sort_prob, sort_indices = seg_prob[valid_mask].sort()
+
+                 if sort_prob.numel() > 0:
+                     min_threshold = sort_prob[min(batch_kept,
+                                                   sort_prob.numel() - 1)]
+                 else:
+                     min_threshold = 0.0
+                 threshold = max(min_threshold, self.thresh)
+                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+             else:
+                 losses = self.context.loss_decode(
+                     seg_logit,
+                     seg_label,
+                     weight=None,
+                     ignore_index=self.context.ignore_index,
+                     reduction_override='none')
+                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                 _, sort_indices = losses[valid_mask].sort(descending=True)
+                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
+
+             seg_weight[valid_mask] = valid_seg_weight
+
+             return seg_weight
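
The sampler only reads `ignore_index` (and, in loss mode, `loss_decode`) from its context, so the threshold branch can be exercised with a stub (an illustrative sketch, not repository code):

    import torch
    from mmseg.core import OHEMPixelSampler

    class _Ctx:                    # hypothetical stand-in for a decode head
        ignore_index = 255

    sampler = OHEMPixelSampler(context=_Ctx(), thresh=0.7, min_kept=100)
    seg_logit = torch.randn(2, 19, 32, 32)             # (N, C, H, W)
    seg_label = torch.randint(0, 19, (2, 1, 32, 32))   # (N, 1, H, W)
    weight = sampler.sample(seg_logit, seg_label)      # 1.0 on hard pixels, 0.0 elsewhere
    print(weight.shape)                                # torch.Size([2, 32, 32])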
models/mmseg/core/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .misc import add_prefix
+
+ __all__ = ['add_prefix']
models/mmseg/core/utils/misc.py ADDED
@@ -0,0 +1,17 @@
+ def add_prefix(inputs, prefix):
+     """Add prefix for dict.
+
+     Args:
+         inputs (dict): The input dict with str keys.
+         prefix (str): The prefix to add.
+
+     Returns:
+         dict: The dict with keys updated with ``prefix``.
+     """
+
+     outputs = dict()
+     for name, value in inputs.items():
+         outputs[f'{prefix}.{name}'] = value
+
+     return outputs
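
For example, decode-head losses get namespaced before logging:

    losses = {'loss_ce': 0.3, 'acc_seg': 0.9}
    print(add_prefix(losses, 'decode'))
    # {'decode.loss_ce': 0.3, 'decode.acc_seg': 0.9}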